ソートされた複数のCSVファイルを結合するスクリプト

30分プログラム、その647。ソートされた複数のCSVファイルを結合するスクリプト - pgyの日記に再チャレンジ!
いろいろがんばってみたけど、どうもうまくいかない。やっぱり、iteratorは遅延リストの代わりにはならないよ - みずぴー日記に書いた制約がきびしい。teeを使えば回避できるけど、ちょっと面倒くさい。
というわけで、iteratorをラッパするクラスを書いて、一回分だけ先読みできるようにしてみた。ちょっとおおげさな気がするけど、まあいいや。

使い方

$ python csv-merge.py data1.csv data2.csv data3.csv
0                       data210 data220 data310 data320
1       data111 data121 data211 data221 data311 data321
2       data112 data122                 data312 data322
3       data113 data123
4       data114 data124 data214 data224
5       data115 data125 data215 data225 data315 data325
6       data116 data126 data216 data226 data316 data326
7       data117 data127 data217 data227 data317 data327
8                       data218 data228 data318 data328
9                       data219 data229 data319 data329

ソースコード

#! /usr/bin/python
# -*- mode:python; coding:utf-8 -*-
#
# csv-merge.py -
#
# Copyright(C) 2009 by mzp
# Author: MIZUNO Hiroki / mzpppp at gmail dot com
# http://howdyworld.org
#
# Timestamp: 2009/08/26 21:21:46
#
# This program is free software; you can redistribute it and/or
# modify it under MIT Lincence.
#
import csv
import sys
from itertools import *

class CSV:
    def __init__(self,path):
        self.csv = csv.reader(file(path,'rb'))
        self.cache = None

        self.size = len(self.peek__()) - 1

    def field_count(self):
        return self.size

    def read(self,id):
        if self.is_empty() or int(self.peek__()[0]) != id:
            return list(repeat('',self.field_count()))
        else:
            return self.read__()[1:]

    def peek__(self):
        if self.cache == None:
            self.cache = self.read__()
        return self.cache

    def read__(self):
        if self.cache != None:
            cache = self.cache
            self.cache = None
            return cache
        else:
            csv = self.csv.next()
            return csv

    def is_empty(self):
        try:
            self.peek__()
            return False
        except StopIteration:
            return True

if __name__ == '__main__':
    csv_list = map(lambda path: CSV(path),sys.argv[1:])

    for i in count(0):
        xs = map(lambda c: c.read(i), csv_list)
        print "%d\t" % i,
        print "\t".join(reduce(lambda x,y: x + y, xs))

        if all(map(lambda c: c.is_empty(),csv_list)):
            break