Project-MONAI / MONAI

AI Toolkit for Healthcare Imaging
https://monai.io/
Apache License 2.0

RandWeightedCropd in batch #7851

Open 7oud opened 2 weeks ago

7oud commented 2 weeks ago

Describe the bug

When RandWeightedCropd is used to randomly crop patches from images of different sizes, the program crashes because weight_map has a different shape in each sample.

To Reproduce

  1. After spacing, image 1 has shape (1, 258, 358, 358) and image 2 has shape (1, 245, 424, 424).
  2. Load the weight maps, which have the same size as the corresponding image: (1, 258, 358, 358) and (1, 245, 424, 424).
  3. Use RandWeightedCropd to crop patches of fixed size (224, 224, 224).
  4. Use DataLoader(ds, batch_size=2, collate_fn=pad_list_data_collate); the program crashes because the weight maps have different sizes.
  5. Using RandSpatialCropSamplesd without a weight map works fine.
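import os

from monai.data import DataLoader, Dataset, pad_list_data_collate
from monai.transforms import (
    Compose,
    EnsureChannelFirstd,
    LoadImaged,
    Orientationd,
    RandSpatialCropSamplesd,  # used by the commented-out alternative below
    RandWeightedCropd,
    Spacingd,
)


def get_transforms():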
    transforms = [
        LoadImaged(keys=['image']),
        EnsureChannelFirstd(keys=['image']),
        Orientationd(keys=['image'], axcodes='SPL'),
        Spacingd(keys=['image'], pixdim=[0.5, 0.5, 0.5], mode=["bilinear"]),
        LoadImaged(keys=['wgtmap']),
        EnsureChannelFirstd(keys=['wgtmap']),
        RandWeightedCropd(
            keys=['image'],  # 'wgtmap' is not listed here, so it is not cropped
            spatial_size=(224, 224, 224),
            num_samples=1,
            w_key='wgtmap',
        ),
        # RandSpatialCropSamplesd(
        #     keys=['image'],
        #     roi_size=(224, 224, 224),
        #     num_samples=1,
        # ),
    ]

    return Compose(transforms)

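def main():
    # placeholder path (not in the original report): folder holding the files below
    root = "/path/to/data"

    data1 = {
        'image': os.path.join(root, 'image035.nii.gz'),
        'wgtmap': os.path.join(root, 'image035-wgt.npy'),
    }
    data2 = {
        'image': os.path.join(root, '10106129_img-orig.nii.gz'),
        'wgtmap': os.path.join(root, '10106129_img-orig-wgt.npy'),
    }

    trans = get_transforms()
    ds = Dataset(data=[data1, data2], transform=trans)
    dl = DataLoader(ds, batch_size=2, num_workers=2, collate_fn=pad_list_data_collate)

    for i, batch_data in enumerate(dl):
        inputs = batch_data["image"]
        print(inputs.shape)
        wgtmap = batch_data["wgtmap"]
        print(wgtmap.shape)


# The __main__ guard is needed when DataLoader workers are started with
# "spawn", the default start method on macOS (where this traceback came from).
if __name__ == "__main__":
    main()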
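The log below shows that "image" collates cleanly while "wgtmap" reaches the collate step at its full, uncropped sizes. That is consistent with how the dictionary crop appears to work: RandWeightedCropd reads w_key only to choose crop centers and crops just the entries listed in keys, copying everything else through unchanged. The NumPy sketch below imitates that behaviour; weighted_crop is a hypothetical, simplified stand-in (not MONAI's sampler) and the small arrays are made-up data. With keys=["image"] the weight maps keep their two different shapes, which a stacking collate cannot handle, while listing both keys gives every entry the crop size.

```python
import numpy as np


def weighted_crop(sample, keys, w_key, spatial_size, rng):
    """Hypothetical, simplified weighted random crop on a dict of channel-first arrays.

    The crop center is drawn with probability proportional to sample[w_key].
    Only entries named in `keys` are cropped; all others (including the weight
    map, unless it is listed in `keys`) are passed through unchanged.
    Assumes spatial_size <= the spatial shape in every dimension.
    """
    weights = sample[w_key][0]  # spatial weights of the single channel
    probs = weights.ravel() / weights.sum()
    center = np.unravel_index(rng.choice(probs.size, p=probs), weights.shape)

    # Place a window of `spatial_size` around the center, shifted to stay in bounds.
    window = [slice(None)]  # keep the channel dimension whole
    for c, size, dim in zip(center, spatial_size, weights.shape):
        start = min(max(int(c) - size // 2, 0), dim - size)
        window.append(slice(start, start + size))

    out = dict(sample)  # shallow copy: unlisted keys keep their original arrays
    for key in keys:
        out[key] = sample[key][tuple(window)]
    return out


rng = np.random.default_rng(0)
samples = [  # made-up volumes standing in for the two images of different size
    {"image": np.zeros((1, 10, 12, 12)), "wgtmap": np.ones((1, 10, 12, 12))},
    {"image": np.zeros((1, 9, 14, 14)), "wgtmap": np.ones((1, 9, 14, 14))},
]

only_image = [weighted_crop(s, ["image"], "wgtmap", (8, 8, 8), rng) for s in samples]
print([s["image"].shape for s in only_image])   # [(1, 8, 8, 8), (1, 8, 8, 8)]
print([s["wgtmap"].shape for s in only_image])  # [(1, 10, 12, 12), (1, 9, 14, 14)]

both = [weighted_crop(s, ["image", "wgtmap"], "wgtmap", (8, 8, 8), rng) for s in samples]
print([s["wgtmap"].shape for s in both])        # [(1, 8, 8, 8), (1, 8, 8, 8)]
```

In the reported pipeline, the analogous change would be keys=['image', 'wgtmap'] in RandWeightedCropd if the cropped weight map is wanted in the batch; this has not been checked against the reporter's data.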

OUTPUT

collate dict key "image" out of 2 keys
collate/stack a list of tensors
collate dict key "wgtmap" out of 2 keys
collate/stack a list of tensors
E: stack expects each tensor to be equal size, but got [1, 258, 358, 358] at entry 0 and [1, 245, 424, 424] at entry 1, shape [torch.Size([1, 258, 358, 358]), torch.Size([1, 245, 424, 424])] in collate([metatensor([[[[0., 0., 0., ..., 0., 0., 0.], [0., 0., 0., ..., 0., 0., 0.],

Traceback (most recent call last):
  File "/Users/z/repo/github/7oud/tst_py_store/monai_store/transform copy.py", line 83, in <module>
    main()
  File "/Users/xxx/repo/github/7oud/tst_py_store/monai_store/transform copy.py", line 68, in main
    for i, batch_data in enumerate(dl):
  File "/Users/xxx/opt/miniconda3/envs/monai/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 628, in __next__
    data = self._next_data()
  File "/Users/xxx/opt/miniconda3/envs/monai/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1333, in _next_data
    return self._process_data(data)
  File "/Users/xxx/opt/miniconda3/envs/monai/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1359, in _process_data
    data.reraise()
  File "/Users/xxx/opt/miniconda3/envs/monai/lib/python3.9/site-packages/torch/_utils.py", line 543, in reraise
    raise exception
RuntimeError: Caught RuntimeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/Users/xxx/opt/miniconda3/envs/monai/lib/python3.9/site-packages/monai/data/utils.py", line 516, in list_data_collate
    ret = collate_fn(data)
  File "/Users/xxx/opt/miniconda3/envs/monai/lib/python3.9/site-packages/torch/utils/data/_utils/collate.py", line 265, in default_collate
    return collate(batch, collate_fn_map=default_collate_fn_map)
  File "/Users/xxx/opt/miniconda3/envs/monai/lib/python3.9/site-packages/torch/utils/data/_utils/collate.py", line 128, in collate
    return elem_type({key: collate([d[key] for d in batch], collate_fn_map=collate_fn_map) for key in elem})
  File "/Users/xxx/opt/miniconda3/envs/monai/lib/python3.9/site-packages/torch/utils/data/_utils/collate.py", line 128, in <dictcomp>
    return elem_type({key: collate([d[key] for d in batch], collate_fn_map=collate_fn_map) for key in elem})
  File "/Users/xxx/opt/miniconda3/envs/monai/lib/python3.9/site-packages/torch/utils/data/_utils/collate.py", line 120, in collate
    return collate_fn_map[elem_type](batch, collate_fn_map=collate_fn_map)
  File "/Users/xxx/opt/miniconda3/envs/monai/lib/python3.9/site-packages/monai/data/utils.py", line 458, in collate_meta_tensor_fn
    collated = collate_fn(batch)  # type: ignore
  File "/Users/xxx/opt/miniconda3/envs/monai/lib/python3.9/site-packages/torch/utils/data/_utils/collate.py", line 163, in collate_tensor_fn
    return torch.stack(batch, 0, out=out)
  File "/Users/xxx/opt/miniconda3/envs/monai/lib/python3.9/site-packages/monai/data/meta_tensor.py", line 282, in __torch_function__
    ret = super().__torch_function__(func, types, args, kwargs)
  File "/Users/xxx/opt/miniconda3/envs/monai/lib/python3.9/site-packages/torch/_tensor.py", line 1279, in __torch_function__
    ret = func(*args, **kwargs)
RuntimeError: stack expects each tensor to be equal size, but got [1, 258, 358, 358] at entry 0 and [1, 245, 424, 424] at entry 1

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/Users/xxx/opt/miniconda3/envs/monai/lib/python3.9/site-packages/torch/utils/data/_utils/worker.py", line 302, in _worker_loop
    data = fetcher.fetch(index)
  File "/Users/xxx/opt/miniconda3/envs/monai/lib/python3.9/site-packages/torch/utils/data/_utils/fetch.py", line 61, in fetch
    return self.collate_fn(data)
  File "/Users/xxx/opt/miniconda3/envs/monai/lib/python3.9/site-packages/monai/data/utils.py", line 696, in pad_list_data_collate
    return PadListDataCollate(method=method, mode=mode, **kwargs)(batch)
  File "/Users/xxx/opt/miniconda3/envs/monai/lib/python3.9/site-packages/monai/transforms/croppad/batch.py", line 114, in __call__
    return list_data_collate(batch)
  File "/Users/xxx/opt/miniconda3/envs/monai/lib/python3.9/site-packages/monai/data/utils.py", line 529, in list_data_collate
    raise RuntimeError(re_str) from re
RuntimeError: stack expects each tensor to be equal size, but got [1, 258, 358, 358] at entry 0 and [1, 245, 424, 424] at entry 1

MONAI hint: if your transforms intentionally create images of different shapes, creating your DataLoader with collate_fn=pad_list_data_collate might solve this problem (check its documentation).
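The hint refers to padding samples to a common size before stacking. As a rough illustration of that idea (not MONAI's PadListDataCollate code), the NumPy sketch below zero-pads each channel-first array centrally to the largest size seen in every spatial dimension and then stacks them; central_pad_to and pad_and_stack are hypothetical helper names and the arrays are made-up data. For the two weight maps in this report the common shape would be (1, 258, 424, 424), since 258 = max(258, 245) and 424 = max(358, 424).

```python
import numpy as np


def central_pad_to(arr, spatial_shape):
    """Zero-pad a channel-first array so its spatial dims equal `spatial_shape`.

    Padding is split between both sides; an odd amount puts the extra voxel
    at the end. Assumes every target size is >= the current size.
    """
    pad = [(0, 0)]  # never pad the channel dimension
    for size, target in zip(arr.shape[1:], spatial_shape):
        total = int(target) - size
        pad.append((total // 2, total - total // 2))
    return np.pad(arr, pad, mode="constant")


def pad_and_stack(arrays):
    """Pad all arrays to the per-dimension maximum spatial shape, then stack."""
    target = np.max([a.shape[1:] for a in arrays], axis=0)
    return np.stack([central_pad_to(a, target) for a in arrays])


# Small stand-ins for two channel-first volumes of different size.
a = np.ones((1, 6, 7, 7))
b = np.ones((1, 5, 9, 9))

try:
    np.stack([a, b])  # plain stacking rejects mismatched shapes
except ValueError as err:
    print("plain stack fails:", err)

batch = pad_and_stack([a, b])
print(batch.shape)  # (2, 1, 6, 9, 9)
```

Padding lets differently sized weight maps share a batch, but it also adds a zero border to each one, so whether that is acceptable depends on how the batched wgtmap is used afterwards.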

Environment

MONAI version: 1.3.1
Numpy version: 1.23.5
Pytorch version: 1.13.1
MONAI flags: HAS_EXT = False, USE_COMPILED = False, USE_META_DICT = False
MONAI rev id: 96bfda00c6bd290297f5e3514ea227c6be4d08b4
MONAI file: /Users//opt/miniconda3/envs/monai/lib/python3.9/site-packages/monai/__init__.py