MIC-DKFZ / batchgenerators

A framework for data augmentation for 2D and 3D image classification and segmentation
Apache License 2.0

MultiThreadedAugmenter questions #61

Open arvind1609 opened 4 years ago

arvind1609 commented 4 years ago

Hello,

Thanks for sharing this wonderful package. I'm trying to augment 3D images and randomly sample a batch from my data with segmentation maps.

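# Import paths assumed from the batchgenerators package layout; adjust to the installed version.
import numpy as np
from batchgenerators.dataloading.data_loader import DataLoaderBase
from batchgenerators.dataloading.multi_threaded_augmenter import MultiThreadedAugmenter
from batchgenerators.transforms.spatial_transforms import SpatialTransform

# 20 volumes of 256 x 256 x 24 voxels, with matching segmentation maps
ip_data = np.zeros((20, 256, 256, 24))
ip_seg = np.zeros((20, 256, 256, 24))
data_zip = (ip_data, ip_seg)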

class DataLoader(DataLoaderBase):
    def __init__(self, data, BATCH_SIZE=2, num_batches=None, seed=False):
        super(DataLoader, self).__init__(data, BATCH_SIZE, num_batches, seed)
        # data is now stored in self._data.

    def generate_train_batch(self):
        # Randomly pick BATCH_SIZE distinct volume indices (there are 20 volumes in total).
        random_val = np.random.choice(range(20), self.BATCH_SIZE, replace=False)
        img = self._data[0][random_val]
        label = self._data[1][random_val]

        print(random_val)
        # Each image has only one channel. Our batch layout must be (b, c, x, y, z), so add a channel axis.
        img = np.expand_dims(img, axis=1)
        label = np.expand_dims(label, axis=1)

        # now construct the dictionary and return it. np.float32 cast because most networks take float
        return {'data': img.astype(np.float32),
                'seg': label.astype(np.float32),
                'subjects': random_val}

batchgen = DataLoader(data_zip, 2, None, False)

spatial_transform = SpatialTransform((192,192,24), np.array((192,192,24)) // 2, 
                 do_elastic_deform=False, alpha=(0., 1500.), sigma=(30., 50.),
                 do_rotation=False, angle_x=(0, 0.5 * np.pi),
                 do_scale=False, scale=(0.1,0.3), 
                 border_mode_data='constant', border_cval_data=0, order_data=1,
                 random_crop=False)

multithreaded_generator = MultiThreadedAugmenter(batchgen, spatial_transform, 4, 2, seeds=None)

spatial_data = next(multithreaded_generator)

I have two questions:

(1) The first time I call next() on multithreaded_generator, the print statement in the class outputs several index arrays instead of just one. Is this the expected behavior?

(2) I'm assuming that the dictionary returned by generate_train_batch is the only data structure that matters, so I can perform any custom operations in the class as long as 'data' and 'seg' have the shape (b, c, x, y, z). Is that correct?
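To make the assumption in (2) concrete, here is a minimal sketch of a shape check that could be run on the returned dictionary before it reaches any transform. The helper name check_batch is hypothetical and not part of batchgenerators, the volumes are shrunk so the sketch runs quickly, and the rule that 'seg' must match 'data' in batch size and spatial size is my own assumption rather than something the library documents here.

```python
import numpy as np


def check_batch(batch, spatial_dims=3):
    """Hypothetical helper: raise ValueError unless 'data' and 'seg' are laid out as (b, c, *spatial)."""
    expected_ndim = 2 + spatial_dims  # batch axis + channel axis + spatial axes
    for key in ("data", "seg"):
        if key not in batch:
            raise ValueError(f"missing key {key!r}")
        if batch[key].ndim != expected_ndim:
            raise ValueError(
                f"{key!r} has shape {batch[key].shape}, expected {expected_ndim} axes"
            )
    data, seg = batch["data"], batch["seg"]
    # Assumption: batch size and spatial size must agree; channel counts may differ.
    if data.shape[0] != seg.shape[0] or data.shape[2:] != seg.shape[2:]:
        raise ValueError(f"shape mismatch: data {data.shape} vs seg {seg.shape}")
    return data.shape


# Small stand-in volumes (the post uses 20 x 256 x 256 x 24).
images = np.zeros((20, 8, 8, 4), dtype=np.float32)
labels = np.zeros((20, 8, 8, 4), dtype=np.float32)

idx = np.random.choice(len(images), 2, replace=False)
batch = {
    "data": np.expand_dims(images[idx], axis=1),  # add the channel axis
    "seg": np.expand_dims(labels[idx], axis=1),
    "subjects": idx,  # extra bookkeeping key, ignored by the check
}

batch_shape = check_batch(batch)
```

With the stand-in volumes above, batch_shape is (2, 1, 8, 8, 4). Leaving out the np.expand_dims calls makes check_batch raise a ValueError, which is the kind of mistake such a check is meant to catch early.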

Apologies if this is trivial. Thanks!
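Regarding (1), one likely explanation is prefetching: the augmenter is created with 4 workers and 2 cached batches per worker, so several batches may be generated in the background as soon as the first one is requested. The sketch below imitates that pattern with standard-library threads and bounded queues. It is not the batchgenerators implementation (which, as far as I understand, uses separate processes, each with its own copy of the loader), and CountingLoader and start_prefetching are hypothetical names used only for illustration.

```python
import queue
import threading
import time


class CountingLoader:
    """Hypothetical stand-in for a data loader; counts how often a batch is requested."""

    def __init__(self):
        self.calls = 0
        self._lock = threading.Lock()

    def generate_train_batch(self):
        with self._lock:
            self.calls += 1
            return {"batch_id": self.calls}


def start_prefetching(loader, num_workers, cache_per_worker):
    """Start daemon threads that each keep one bounded queue filled with batches."""
    queues = [queue.Queue(maxsize=cache_per_worker) for _ in range(num_workers)]

    def work(q):
        while True:
            q.put(loader.generate_train_batch())  # blocks while q is full

    for q in queues:
        threading.Thread(target=work, args=(q,), daemon=True).start()
    return queues


loader = CountingLoader()
queues = start_prefetching(loader, num_workers=4, cache_per_worker=2)

first_batch = queues[0].get()  # the consumer asks for a single batch

# Wait (at most 5 s) until every queue has been filled up again.
deadline = time.monotonic() + 5.0
while not all(q.full() for q in queues) and time.monotonic() < deadline:
    time.sleep(0.01)

# Once all queues are full, at least 4 * 2 batches are cached plus the one consumed.
calls_after_first_request = loader.calls
```

Although only one batch was consumed, the loader has been called at least nine times by the time all queues are full. If the real augmenter behaves similarly, a print statement inside generate_train_batch would fire once per prefetched batch, which would match the multiple arrays seen on the first next() call.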