Project-MONAI / MONAI

AI Toolkit for Healthcare Imaging
https://monai.io/
Apache License 2.0

Can DataLoader be made to support dynamic batch shapes? #6264

Closed · nslay closed this issue 1 year ago

nslay commented 1 year ago

Is your feature request related to a problem? Please describe.

I'm always frustrated when I need to make Transforms that consistently produce the same size tensor for DataLoader. DataLoader does not take kindly to variably-sized tensors. This is not a problem for 2D or patch-based methods, but it is annoying for FCNs or other methods that can operate on whole images (because they're implicitly sliding-window methods anyway). In my case, I have a model that needs to see the whole shape of a segmentation mask because it learns something about the shape. Randomly cropping the segmentation mask to appease DataLoader teaches it that a structure can be unnaturally flat sometimes. Computing the maximum shape a mask can occupy results in a pessimistic 3D tensor shape that would exhaust GPU memory at appreciable batch sizes (e.g. 8). However, I have my own horribly-written custom DataLoader from before I learned MONAI that can dynamically determine the size of a tensor needed to contain the variably-sized batch instances, and with it I can train with batch sizes even as high as 8 or 16! That's because no segmentation mask (even when batched) occupies the pessimistic maximum extent in size.
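For context, the default collate just stacks same-shaped tensors, so two whole-image masks of different depths fail immediately. A minimal illustration (shapes made up):

```python
import torch
from torch.utils.data.dataloader import default_collate

a = torch.zeros(1, 64, 64, 48)  # channel-first mask from one case
b = torch.zeros(1, 64, 64, 52)  # same layout, but a different depth

# raises RuntimeError: stack expects each tensor to be equal size
batch = default_collate([a, b])
```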

Describe the solution you'd like

What I'd like to see is DataLoader being able to accept variably-sized tensors from Transforms and then determine the minimum spatial size for the output batch tensor that can fully contain even the largest instance in the batch. It should provide a pad strategy (e.g. a fill_value option) for the smaller tensors in the batch.
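Roughly, something along these lines; this is just a sketch of the behavior I mean (channel-first tensors assumed), not an existing API:

```python
import torch
import torch.nn.functional as F

def min_pad_collate(batch, fill_value=0):
    # target spatial size = per-axis maximum over the batch, i.e. the smallest
    # shape that can fully contain every instance
    ndim = batch[0].ndim - 1  # number of spatial dims (channel-first)
    target = [max(t.shape[1 + d] for t in batch) for d in range(ndim)]
    padded = []
    for t in batch:
        pads = []  # F.pad takes (before, after) pairs starting from the last dim
        for d in reversed(range(ndim)):
            diff = target[d] - t.shape[1 + d]
            pads.extend([diff // 2, diff - diff // 2])
        padded.append(F.pad(t, pads, value=fill_value))
    return torch.stack(padded)
```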

Describe alternatives you've considered

To work around DataLoader's limitation with variably-sized tensors, I use SpatialPadd with spatial_size set to the maximum extent of the segmentation masks (to ensure DataLoader is happy!) and fill_value=-1 on the "label" key. After reading a batch from DataLoader, I compute the minimal bounding shape containing ybatch >= 0, spatially crop the batch to it, and transfer it to the GPU. This is a CPU-inefficient way to work around DataLoader's limitations... but it works!
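In code, the workaround looks roughly like this (a sketch; MAX_EXTENT stands in for my precomputed maximum mask extent, and the pad-value keyword shown is numpy's constant_values, which MONAI's pad transforms forward to np.pad):

```python
import torch
from monai.transforms import Compose, SpatialPadd

# pad everything up to a pessimistic global maximum so default collation succeeds;
# -1 marks the padded don't-care voxels in the label
pad = Compose([
    SpatialPadd(keys=["image"], spatial_size=MAX_EXTENT, mode="constant"),
    SpatialPadd(keys=["label"], spatial_size=MAX_EXTENT, mode="constant", constant_values=-1),
])

def crop_batch_to_labels(images, labels):
    # shrink the (B, C, D, H, W) batch back down to the minimal box with label >= 0
    valid = (labels >= 0).any(dim=0).any(dim=0)  # boolean (D, H, W) mask
    slices = []
    for axis in range(valid.ndim):
        profile = valid
        for d in sorted((d for d in range(valid.ndim) if d != axis), reverse=True):
            profile = profile.any(dim=d)  # collapse the other spatial axes
        idx = torch.nonzero(profile).flatten()
        slices.append(slice(int(idx[0]), int(idx[-1]) + 1))
    sl = (slice(None), slice(None), *slices)
    return images[sl].cuda(), labels[sl].cuda()
```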

Even when a batch fits in memory without this cropping strategy, I already use, for example, ignore_index=-1 in nn.CrossEntropyLoss to prevent loss and gradient calculations from happening in these don't-care regions. That keeps the model from learning in places where SpatialPadd applied padding, but it still makes the GPU work harder and consume more GPU memory (which is the problem for me when using MONAI).
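For example:

```python
import torch.nn as nn

# voxels padded with -1 contribute neither to the loss nor to gradients
criterion = nn.CrossEntropyLoss(ignore_index=-1)
```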

Additional context

I completely understand that my request may make DataLoader less computationally optimal (I don't care though! I just want GPU memory efficiency!). Things like pinned memory and so forth would probably be harder to use effectively if you have to deal with variably-sized tensors.

wyli commented 1 year ago

have you tried

```python
train_loader = DataLoader(train_ds, batch_size=2, collate_fn=pad_list_data_collate)
```

that's from monai.data.utils.pad_list_data_collate:

https://github.com/Project-MONAI/MONAI/blob/1005eacd1fa55fdae7df07a243bd682f170c17ea/monai/data/utils.py#L640-L651

nslay commented 1 year ago

I never noticed the collate_fn argument to DataLoader. I'll give it a shot later today! Yeah, I was looking at those Collate transforms, but my initial impression was that they were just another Transform, and it wasn't clear how using them on the Dataset side was going to appease DataLoader.

nslay commented 1 year ago

OK, this seems to work. Now I have another problem with collate_fn: it does not work with Compose transforms. It does work with a lambda wrapping pad_list_data_collate or with a standalone PadListDataCollate.

So why would I need a Compose? Well, some models need spatial sizes to be divisible by a power of 2, so I need to follow up PadListDataCollate with, for example, DivisiblePadd. I also need to do some other special transform on the "label" key.

So here's how it complains with a Compose containing only PadListDataCollate (and no other transform):

```
Traceback (most recent call last):
  File "/gpfs/gsfs12/users/layns/ProstateCAD/KITS19_2d_3d_monai/RCCSeg.py", line 585, in <module>
    cad.Train(train_list, snapshot_dir, val_list=val_list)
  File "/gpfs/gsfs12/users/layns/ProstateCAD/KITS19_2d_3d_monai/RCCSeg.py", line 511, in Train
    for batch_dict in train_loader:
  File "/gpfs/gsfs12/users/layns/torch_monai_v100/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 681, in __next__
    data = self._next_data()
  File "/gpfs/gsfs12/users/layns/torch_monai_v100/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1376, in _next_data
    return self._process_data(data)
  File "/gpfs/gsfs12/users/layns/torch_monai_v100/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1402, in _process_data
    data.reraise()
  File "/gpfs/gsfs12/users/layns/torch_monai_v100/lib/python3.9/site-packages/torch/_utils.py", line 461, in reraise
    raise exception
RuntimeError: Caught RuntimeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/gpfs/gsfs12/users/layns/torch_monai_v100/lib/python3.9/site-packages/monai/transforms/transform.py", line 101, in apply_transform
    return [_apply_transform(transform, item, unpack_items) for item in data]
  File "/gpfs/gsfs12/users/layns/torch_monai_v100/lib/python3.9/site-packages/monai/transforms/transform.py", line 101, in <listcomp>
    return [_apply_transform(transform, item, unpack_items) for item in data]
  File "/gpfs/gsfs12/users/layns/torch_monai_v100/lib/python3.9/site-packages/monai/transforms/transform.py", line 66, in _apply_transform
    return transform(parameters)
  File "/gpfs/gsfs12/users/layns/torch_monai_v100/lib/python3.9/site-packages/monai/transforms/croppad/batch.py", line 76, in __call__
    is_list_of_dicts = isinstance(batch[0], dict)
KeyError: 0

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/gpfs/gsfs12/users/layns/torch_monai_v100/lib/python3.9/site-packages/torch/utils/data/_utils/worker.py", line 302, in _worker_loop
    data = fetcher.fetch(index)
  File "/gpfs/gsfs12/users/layns/torch_monai_v100/lib/python3.9/site-packages/torch/utils/data/_utils/fetch.py", line 52, in fetch
    return self.collate_fn(data)
  File "/gpfs/gsfs12/users/layns/torch_monai_v100/lib/python3.9/site-packages/monai/transforms/compose.py", line 174, in __call__
    input_ = apply_transform(_transform, input_, self.map_items, self.unpack_items, self.log_stats)
  File "/gpfs/gsfs12/users/layns/torch_monai_v100/lib/python3.9/site-packages/monai/transforms/transform.py", line 129, in apply_transform
    raise RuntimeError(f"applying transform {transform}") from e
RuntimeError: applying transform <monai.transforms.croppad.batch.PadListDataCollate object at 0x2aab84501730>
```

I don't know why Compose is unhappy, since the PadListDataCollate transform works by itself without being wrapped by Compose. My guess from the traceback: Compose maps transforms over list items by default, so PadListDataCollate receives each dict individually instead of the whole batch list, hence the KeyError: 0 from batch[0].
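In the meantime, a plain function collate sidesteps Compose entirely. A sketch of what I mean (the keys, divisor, and train_ds are placeholders for my actual setup):

```python
from monai.data import DataLoader
from monai.data.utils import pad_list_data_collate
from monai.transforms import DivisiblePadd

# pad each sample so its spatial sizes are multiples of k *before* collation
div_pad = DivisiblePadd(keys=["image", "label"], k=16)

def my_collate(batch):
    # pad_list_data_collate pads every dim to the per-axis batch maximum; since
    # each sample is already a multiple of k, the stacked batch stays divisible
    return pad_list_data_collate([div_pad(item) for item in batch])

train_loader = DataLoader(train_ds, batch_size=8, collate_fn=my_collate)
```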

I'll go ahead and close this and file a problem report.