MIC-DKFZ / batchgenerators

A framework for data augmentation for 2D and 3D image classification and segmentation
Apache License 2.0

Odd scaling of slowdown when going from 2D to 3D #33

Closed: ksarma closed this issue 5 years ago.

ksarma commented 5 years ago

Hi,

I'm not sure whether this is a bug, an error on my part, or just the way things are expected to be, but I thought it was odd, so I figured I'd bring it to your attention.

I've recently switched from looking at 2D images to 3D volumes, and I've found that the slowdown in augmenting (using SpatialTransform) is significantly higher than I would have expected.

I did some tests using dummy code from @FabianIsensee in #5 and compared the speed of the augmenter there using:

1) Dummy data size (32, 1, 256, 256) and patch_size (128, 128)
2) Dummy data size (32, 1, 256, 256, 256) and patch_size (128, 128, 128)

and found that with the SingleThreadedAugmenter, generation took ~0.2 s/batch for test 1 and ~57 s/batch for test 2, a factor of nearly 300 rather than the 128 I was expecting.
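The expected factor of 128 matches the ratio of output voxels per batch. A quick sketch of that arithmetic, assuming (for illustration only) that the cost of SpatialTransform grows with the number of patch voxels it has to produce, and using the batch size, patch sizes and rounded timings quoted above:

```python
import math

# Assumption: SpatialTransform's cost grows with the number of output
# (patch) voxels it produces, so compare those between the two tests.
batch_size = 32
patch_2d = (128, 128)
patch_3d = (128, 128, 128)

voxels_2d = batch_size * math.prod(patch_2d)  # 524288 output pixels per batch
voxels_3d = batch_size * math.prod(patch_3d)  # 67108864 output voxels per batch
expected_ratio = voxels_3d / voxels_2d        # 128.0

# Rounded single-threaded timings from the report, in seconds per batch.
time_2d, time_3d = 0.2, 57.0
observed_ratio = time_3d / time_2d            # roughly 285

print(f"expected ~{expected_ratio:.0f}x, observed ~{observed_ratio:.0f}x")
print(f"implied extra cost per voxel in 3D: ~{observed_ratio / expected_ratio:.1f}x")
```

Under that assumption, each 3D output voxel costs a bit more than twice as much as a 2D output pixel, which is the gap the rest of this thread tries to explain.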

Actual code here:

from batchgenerators.dataloading import SingleThreadedAugmenter, MultiThreadedAugmenter
from batchgenerators.transforms.color_transforms import \
    BrightnessMultiplicativeTransform
from batchgenerators.transforms.spatial_transforms import SpatialTransform
from batchgenerators.transforms.abstract_transforms import Compose
from batchgenerators.transforms.sample_normalization_transforms import \
    MeanStdNormalizationTransform
from batchgenerators.dataloading.data_loader import SlimDataLoaderBase
import numpy as np
from time import time

# First 2D

class DummyLoader(SlimDataLoaderBase):
    def __init__(self):
        super(DummyLoader, self).__init__(None, None, None)

    def generate_train_batch(self):
        return {'data': np.random.random((32, 1, 256, 256))}

transforms = Compose([
    BrightnessMultiplicativeTransform(multiplier_range=(0.7, 1.3),
                                      per_channel=True),

    SpatialTransform(patch_size=(128, 128),
                     do_elastic_deform=True,
                     alpha=(90., 750.),
                     sigma=(9., 11.),
                     do_scale=True,
                     random_crop=False,
                     do_rotation=False,
                     order_data=1,
                     border_mode_data='reflect'),

    MeanStdNormalizationTransform(mean=[0.485],
                                  std=[0.229])
])

single_threaded_gen = SingleThreadedAugmenter(DummyLoader(), transforms)
multi_threaded_gen_one_thread = MultiThreadedAugmenter(DummyLoader(), transforms, 1, 1, None)
multi_threaded_gen_eight_threads = MultiThreadedAugmenter(DummyLoader(), transforms, 8, 1, None)

num_batches_warmup = 16
num_batches_run = 16

print("Running 2D tests")

####### SingleThreadedAugmenter #######
# warmup
_ = [next(single_threaded_gen) for _ in range(num_batches_warmup)]
# run
start = time()
_ = [next(single_threaded_gen) for _ in range(num_batches_run)]
end = time()
print("Generated %d batches with SingleThreadedAugmenter in %f seconds; %f s/batch" % (num_batches_run, end - start, (end - start) / num_batches_run))

####### MultiThreadedAugmenter (1 thread) #######
# warmup
_ = [next(multi_threaded_gen_one_thread) for _ in range(num_batches_warmup)]
# run
start = time()
_ = [next(multi_threaded_gen_one_thread) for _ in range(num_batches_run)]
end = time()
print("Generated %d batches with MultiThreadedAugmenter (1 thread) in %f seconds; %f s/batch" % (num_batches_run, end - start, (end - start) / num_batches_run))

####### MultiThreadedAugmenter (8 threads) #######
# warmup
_ = [next(multi_threaded_gen_eight_threads) for _ in range(num_batches_warmup)]
# run
start = time()
_ = [next(multi_threaded_gen_eight_threads) for _ in range(num_batches_run)]
end = time()
print("Generated %d batches with MultiThreadedAugmenter (8 threads) in %f seconds; %f s/batch" % (num_batches_run, end - start, (end - start) / num_batches_run))

# Now 3D

class DummyLoader(SlimDataLoaderBase):
    def __init__(self):
        super(DummyLoader, self).__init__(None, None, None)

    def generate_train_batch(self):
        return {'data': np.random.random((32, 1, 256, 256, 256))}

transforms = Compose([
    BrightnessMultiplicativeTransform(multiplier_range=(0.7, 1.3),
                                      per_channel=True),

    SpatialTransform(patch_size=(128, 128, 128),
                     do_elastic_deform=True,
                     alpha=(90., 750.),
                     sigma=(9., 11.),
                     do_scale=True,
                     random_crop=False,
                     do_rotation=False,
                     order_data=1,
                     border_mode_data='reflect'),

    MeanStdNormalizationTransform(mean=[0.485],
                                  std=[0.229])
])

single_threaded_gen = SingleThreadedAugmenter(DummyLoader(), transforms)
multi_threaded_gen_one_thread = MultiThreadedAugmenter(DummyLoader(), transforms, 1, 1, None)
multi_threaded_gen_eight_threads = MultiThreadedAugmenter(DummyLoader(), transforms, 8, 1, None)

num_batches_warmup = 16
num_batches_run = 16

print("Running 3D tests")
####### SingleThreadedAugmenter #######
# warmup
_ = [next(single_threaded_gen) for _ in range(num_batches_warmup)]
# run
start = time()
_ = [next(single_threaded_gen) for _ in range(num_batches_run)]
end = time()
print("Generated %d batches with SingleThreadedAugmenter in %f seconds; %f s/batch" % (num_batches_run, end - start, (end - start) / num_batches_run))

####### MultiThreadedAugmenter (1 thread) #######
# warmup
_ = [next(multi_threaded_gen_one_thread) for _ in range(num_batches_warmup)]
# run
start = time()
_ = [next(multi_threaded_gen_one_thread) for _ in range(num_batches_run)]
end = time()
print("Generated %d batches with MultiThreadedAugmenter (1 thread) in %f seconds; %f s/batch" % (num_batches_run, end - start, (end - start) / num_batches_run))

####### MultiThreadedAugmenter (8 threads) #######
# warmup
_ = [next(multi_threaded_gen_eight_threads) for _ in range(num_batches_warmup)]
# run
start = time()
_ = [next(multi_threaded_gen_eight_threads) for _ in range(num_batches_run)]
end = time()
print("Generated %d batches with MultiThreadedAugmenter (8 threads) in %f seconds; %f s/batch" % (num_batches_run, end - start, (end - start) / num_batches_run))

and output here:

Running 2D tests
Generated 16 batches with SingleThreadedAugmenter in 3.183878 seconds; 0.198992 s/batch
Generated 16 batches with MultiThreadedAugmenter (1 thread) in 3.397480 seconds; 0.212343 s/batch
Generated 16 batches with MultiThreadedAugmenter (8 threads) in 0.509470 seconds; 0.031842 s/batch
Running 3D tests
Generated 16 batches with SingleThreadedAugmenter in 919.891915 seconds; 57.493245 s/batch
Generated 16 batches with MultiThreadedAugmenter (1 thread) in 931.741837 seconds; 58.233865 s/batch
Generated 16 batches with MultiThreadedAugmenter (8 threads) in 232.098892 seconds; 14.506181 s/batch

PS: Thanks so much for your work on this excellent package! It really is absolutely fantastic.

FabianIsensee commented 5 years ago

Hi there, thank you for the detailed information and especially for providing code! It is so nice to work on issues when you can reproduce the problem exactly.

I had a hard time understanding why things behave this way, but I think I found an explanation. The cause of the discrepancy lies in the SpatialTransform. If, for example, you remove it and instead use a bunch of transforms that operate on a per-pixel level, the discrepancy between 2D and 3D disappears.
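To see why per-pixel transforms don't care about dimensionality, here is a minimal NumPy sketch of a multiplicative brightness augmentation in the spirit of BrightnessMultiplicativeTransform with per_channel=True. It is not the library's implementation, and the function name multiply_brightness is hypothetical. It does one multiplication per voxel whatever the number of spatial axes, so its cost depends only on the total voxel count. The two batch shapes are the cropped shapes of the equal-voxel batches used in the code below:

```python
import numpy as np


def multiply_brightness(data, multiplier_range=(0.7, 1.3), rng=None):
    """Hypothetical stand-in for a per-channel multiplicative brightness
    augmentation. `data` has shape (batch, channels, *spatial) with any
    number of spatial axes; each (sample, channel) image gets its own
    random factor."""
    rng = np.random.default_rng() if rng is None else rng
    b, c = data.shape[:2]
    factors = rng.uniform(*multiplier_range, size=(b, c))
    # Append one singleton axis per spatial dimension so the factors
    # broadcast over every voxel: the work is one multiply per voxel.
    factors = factors.reshape((b, c) + (1,) * (data.ndim - 2))
    return data * factors


rng = np.random.default_rng(0)
batch_2d = rng.random((128, 1, 128, 128))      # 2,097,152 voxels
batch_3d = rng.random((1, 1, 128, 128, 128))   # 2,097,152 voxels
out_2d = multiply_brightness(batch_2d, rng=rng)
out_3d = multiply_brightness(batch_3d, rng=rng)
print(batch_2d.size == batch_3d.size)  # True: same number of multiplies
```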

Please have a look at the following code. I made some changes to yours:

1) Batches now have exactly the same number of voxels in 2D and 3D. That is fairer (e.g. the inter-process communication overhead is the same for both) and easier to compare.
2) Transforms that operate on a per-pixel level should come after the SpatialTransform. You don't need to adapt the brightness of pixels that are cropped away anyway.
3) I use 200 BrightnessMultiplicativeTransforms to create more CPU load, and I removed the SpatialTransform (replacing it with a CenterCropTransform).
4) Each DummyLoader stores its batch as an attribute (self.data) so that it doesn't have to be re-created for every batch.
5) I removed everything except multi_threaded_gen_eight_threads so that I could test things more quickly.

from batchgenerators.dataloading import SingleThreadedAugmenter, MultiThreadedAugmenter
from batchgenerators.transforms.color_transforms import \
    BrightnessMultiplicativeTransform
from batchgenerators.transforms.crop_and_pad_transforms import CenterCropTransform
from batchgenerators.transforms.spatial_transforms import SpatialTransform
from batchgenerators.transforms.abstract_transforms import Compose
from batchgenerators.transforms.sample_normalization_transforms import \
    MeanStdNormalizationTransform
from batchgenerators.dataloading.data_loader import SlimDataLoaderBase
import numpy as np
from time import time, sleep

num_batches_warmup = 16
num_batches_run = 16

# First 2D

class DummyLoader(SlimDataLoaderBase):
    def __init__(self):
        super(DummyLoader, self).__init__(None, None, None)
        self.data = np.random.random((128, 1, 256, 256))

    def generate_train_batch(self):
        return {'data': self.data}

transforms = Compose([
    # SpatialTransform(patch_size=(128, 128),
    #                  do_elastic_deform=True,
    #                  alpha=(90., 750.),
    #                  sigma=(9., 11.),
    #                  do_scale=True,
    #                  random_crop=False,
    #                  do_rotation=False,
    #                  order_data=1,
    #                  border_mode_data='reflect'),
    CenterCropTransform((128, 128)),
    *[BrightnessMultiplicativeTransform(multiplier_range=(0.7, 1.3),
                                      per_channel=True) for _ in range(200)],

    MeanStdNormalizationTransform(mean=[0.485],
                                  std=[0.229]),
])

multi_threaded_gen_eight_threads = MultiThreadedAugmenter(DummyLoader(), transforms, 8, 1, None)

print("Running 2D tests")
####### MultiThreadedAugmenter (8 threads) #######
# warmup
multi_threaded_gen_eight_threads.restart()
_ = [next(multi_threaded_gen_eight_threads) for _ in range(multi_threaded_gen_eight_threads.num_processes * 2)]
# run
start = time()
_ = [next(multi_threaded_gen_eight_threads) for _ in range(num_batches_run)]
end = time()
print("Generated %d batches with MultiThreadedAugmenter (8 threads) in %f seconds; %f s/batch" % (num_batches_run, end - start, (end - start) / num_batches_run))

# Now 3D

class DummyLoader(SlimDataLoaderBase):
    def __init__(self):
        super(DummyLoader, self).__init__(None, None, None)
        self.data = np.random.random((1, 1, 256, 256, 128))

    def generate_train_batch(self):
        return {'data': self.data}

transforms = Compose([
    # SpatialTransform(patch_size=(128, 128, 128),
    #                  do_elastic_deform=True,
    #                  alpha=(90., 750.),
    #                  sigma=(9., 11.),
    #                  do_scale=True,
    #                  random_crop=False,
    #                  do_rotation=False,
    #                  order_data=1,
    #                  border_mode_data='reflect'),
    CenterCropTransform((128, 128, 128)),
    *[BrightnessMultiplicativeTransform(multiplier_range=(0.7, 1.3),
                                      per_channel=True) for _ in range(200)],
    MeanStdNormalizationTransform(mean=[0.485],
                                  std=[0.229]),
])

multi_threaded_gen_eight_threads = MultiThreadedAugmenter(DummyLoader(), transforms, 8, 1, None)

print("Running 3D tests")
####### MultiThreadedAugmenter (8 threads) #######
# warmup
multi_threaded_gen_eight_threads.restart()
_ = [next(multi_threaded_gen_eight_threads) for _ in range(multi_threaded_gen_eight_threads.num_processes * 2)]
# run
start = time()
_ = [next(multi_threaded_gen_eight_threads) for _ in range(num_batches_run)]
end = time()
print("Generated %d batches with MultiThreadedAugmenter (8 threads) in %f seconds; %f s/batch" % (num_batches_run, end - start, (end - start) / num_batches_run))

Running 2D tests
Generated 16 batches with MultiThreadedAugmenter (8 threads) in 2.841806 seconds; 0.177613 s/batch
Running 3D tests
Generated 16 batches with MultiThreadedAugmenter (8 threads) in 2.532192 seconds; 0.158262 s/batch

As you can see, the difference is gone. Now why is that? I'm afraid I don't have a really good answer for you. My best guess is that interpolating in 3D space is a lot more computationally intensive. Linear interpolation in 2D (bilinear) computes each new value from the 4 nearest pixels. In 3D (trilinear) it needs the 8 nearest voxels, twice as many, and combining them takes 7 one-dimensional interpolations instead of 3, so more than 2x the computation.
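The neighbour count is easy to see in code. Below is a minimal NumPy sketch of n-linear interpolation at a single point; it is not the interpolation code used by batchgenerators or SciPy, and the helper name nlinear_interpolate is hypothetical. It visits 2**ndim corner samples (4 in 2D, 8 in 3D), and in this naive form each corner's weight is also a product of one factor per axis, so the work per output value grows faster than the neighbour count alone:

```python
import itertools

import numpy as np


def nlinear_interpolate(image, point):
    """Hypothetical helper: n-linear interpolation of `image` at a
    fractional `point` (one coordinate per axis, strictly inside the
    array). Returns (value, number_of_corner_samples_read)."""
    point = np.asarray(point, dtype=float)
    base = np.floor(point).astype(int)  # lower corner of the enclosing cell
    frac = point - base                 # position inside the cell, in [0, 1)
    value = 0.0
    n_samples = 0
    # Visit every corner of the cell: 2**ndim of them.
    for offsets in itertools.product((0, 1), repeat=image.ndim):
        weight = 1.0
        for f, o in zip(frac, offsets):
            weight *= f if o else 1.0 - f  # one weight factor per axis
        value += weight * image[tuple(base + np.array(offsets))]
        n_samples += 1
    return value, n_samples


# n-linear interpolation reproduces linear functions exactly,
# which makes the sketch easy to sanity-check.
img_2d = np.fromfunction(lambda y, x: 2 * y + 3 * x, (8, 8))
img_3d = np.fromfunction(lambda z, y, x: z + 2 * y + 3 * x, (8, 8, 8))

v2, n2 = nlinear_interpolate(img_2d, (2.5, 4.25))
v3, n3 = nlinear_interpolate(img_3d, (1.5, 2.5, 4.25))
print(n2, n3)  # 4 8
print(v2, v3)  # 17.75 19.25
```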

Best, Fabian

FabianIsensee commented 5 years ago

Another thing I just noticed is that the elastic deformation is expensive to compute in 3D! It involves a Gaussian smoothing step that takes a while, and this filter also takes longer in 3D.
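A rough cost model makes this concrete. It assumes (not verified against the batchgenerators source here) that the elastic deformation draws one random displacement field per spatial axis, of patch size, and smooths each one with a separable Gaussian like scipy.ndimage.gaussian_filter, which runs one 1D pass per axis with the kernel truncated at 4 sigma by default. The function names are hypothetical:

```python
import math


def gaussian_kernel_length(sigma, truncate=4.0):
    """Length of a 1D Gaussian kernel truncated at `truncate` * sigma,
    using the same radius rule as scipy.ndimage.gaussian_filter1d."""
    radius = int(truncate * sigma + 0.5)
    return 2 * radius + 1


def smoothing_cost(patch_shape, sigma):
    """Approximate multiply-adds for smoothing one displacement field per
    spatial axis with a separable Gaussian (one 1D pass per axis)."""
    ndim = len(patch_shape)
    n_fields = ndim          # assumption: one displacement field per axis
    passes_per_field = ndim  # separable filter: one 1D pass per axis
    voxels = math.prod(patch_shape)
    return n_fields * passes_per_field * voxels * gaussian_kernel_length(sigma)


sigma = 10.0  # middle of the sigma=(9., 11.) range used in this issue
cost_2d = smoothing_cost((128, 128), sigma)
cost_3d = smoothing_cost((128, 128, 128), sigma)
print(gaussian_kernel_length(sigma))  # 81
print(cost_3d / cost_2d)              # 288.0
```

Per voxel, this model makes 3D smoothing 9/4 = 2.25x as expensive as 2D, on top of the 128x growth in voxel count. That the product lands close to the roughly 289x single-threaded factor in the original report (57.49 s vs. 0.199 s per batch) is suggestive at best, since the model ignores interpolation and everything else SpatialTransform does.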