chainer / chainermn

ChainerMN: Scalable distributed deep learning with Chainer
https://chainer.org
MIT License

Can't pickle Transaction objects #129

Closed: Aixile closed this issue 6 years ago

Aixile commented 6 years ago
  File "train.py", line 132, in main
    train_dataset = chainermn.scatter_dataset(train_dataset, comm)
  File "/home/aixile/anaconda3/lib/python3.6/site-packages/chainermn/datasets/scatter_dataset.py", line 91, in scatter_dataset
    comm.send(subds, dest=i)
  File "MPI/Comm.pyx", line 1175, in mpi4py.MPI.Comm.send (src/mpi4py.MPI.c:106424)
  File "MPI/msgpickle.pxi", line 210, in mpi4py.MPI.PyMPI_send (src/mpi4py.MPI.c:42085)
  File "MPI/msgpickle.pxi", line 112, in mpi4py.MPI.Pickle.dump (src/mpi4py.MPI.c:40704)
TypeError: can't pickle Transaction objects
^C^Z[warn] Epoll ADD(4) on fd 28 failed.  Old events were 0; read change was 0 (none); write change was 1 (add): Bad file descriptor

This happens when my dataset loader tries to load images from an LMDB file.

import cv2
import lmdb
import numpy as np

class lsun_bedroom_train(datasets_base):
    def __init__(self, path, img_size=256):
        self.all_keys = self.read_image_key_file_json(path + '/key_bedroom.json')
        # begin() returns a Transaction object, which is what cannot be pickled.
        self.db = lmdb.open(path + "/bedroom_train_lmdb", readonly=True).begin(write=False)
        super(lsun_bedroom_train, self).__init__(flip=1, resize_to=img_size, crop_to=0)

    def __len__(self):
        return len(self.all_keys)

    def get_example(self, i):
        key = self.all_keys[i]
        val = self.db.get(key.encode())
        # np.frombuffer replaces the deprecated np.fromstring
        img = cv2.imdecode(np.frombuffer(val, dtype=np.uint8), 1)
        img = self.do_augmentation(img)
        img = self.preprocess_image(img)
        return img
keisukefukuda commented 6 years ago

Hi @Aixile,

Yes, as you pointed out yourself, the error happens because self.db is not picklable. Since ChainerMN needs to send the dataset over the network, the data must be serializable.
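
For example, the same TypeError can be reproduced without ChainerMN by pickling a transaction directly (a minimal, untested sketch; the path is just a placeholder):

import pickle
import lmdb

env = lmdb.open("/tmp/example_lmdb")  # placeholder path; creates the env if absent
txn = env.begin(write=False)          # same kind of object as self.db above

try:
    pickle.dumps(txn)
except TypeError as e:
    print(e)  # -> can't pickle Transaction objects (message varies by Python version)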

Although I'm not familiar with lmdb, I guess one way to avoid this problem is to delay the initialization of the lmdb object, which is currently done in __init__.

class lsun_bedroom_train(datasets_base):
    def _init_db(self):    # <-------------
        # Opened lazily on first access, so the dataset is picklable
        # when ChainerMN scatters it.
        self.db = lmdb.open(self.path + "/bedroom_train_lmdb", readonly=True).begin(write=False)

    def __init__(self, path, img_size=256):
        self.all_keys = self.read_image_key_file_json(path + '/key_bedroom.json')
        self.path = path   # remember the path; _init_db needs it later
        self.db = None
        super(lsun_bedroom_train, self).__init__(flip=1, resize_to=img_size, crop_to=0)

    def __len__(self):
        return len(self.all_keys)

    def get_example(self, i):
        if self.db is None:    # <-------------
            self._init_db()
        key = self.all_keys[i]
        val = self.db.get(key.encode())
        img = cv2.imdecode(np.frombuffer(val, dtype=np.uint8), 1)
        img = self.do_augmentation(img)
        img = self.preprocess_image(img)
        return img

I didn't test it but hope it helps :)
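
A related pattern (also just a sketch, not something I tested) is to make the object explicitly picklable by dropping the handle in __getstate__; combined with the lazy re-open above, each MPI process then opens its own LMDB handle after receiving its shard:

class lsun_bedroom_train(datasets_base):
    # __init__, __len__, and get_example as in the lazy-init version above

    def __getstate__(self):
        # Copy the instance dict but drop the unpicklable Transaction.
        state = self.__dict__.copy()
        state['db'] = None
        return state

    def __setstate__(self, state):
        # get_example re-opens the database lazily on the receiving process.
        self.__dict__.update(state)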

Aixile commented 6 years ago

This strategy solves my problem. Thanks!