Project-MONAI / MONAI

AI Toolkit for Healthcare Imaging
https://monai.io/
Apache License 2.0

Compose cannot work with both Decollated and MultiSampleTrait transforms #8186

Status: Open. ziw-liu opened this issue 3 weeks ago.

Describe the bug

If Compose is used with the following types of transforms, in this order, the third transform does not receive the expected input:

  1. Decollated, which splits the batched input into one dict per item. This transform is effectively:
    Callable[[dict[str, Tensor]], list[dict[str, Tensor]]]
  2. MultiSampleTrait, which splits each item again into several samples. The accumulated effect is now:
    Callable[[dict[str, Tensor]], list[list[dict[str, Tensor]]]]
  3. MapTransform, which expects a dict[str, Tensor] from its caller (Compose in this case), raises an error.
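The type accumulation described above can be modelled in plain Python. This is a minimal sketch under stated assumptions: `decollate`, `crop_samples` and `apply_one_level` are hypothetical stand-ins (not MONAI code), strings play the role of tensors, and `apply_one_level` assumes Compose maps each transform over the items of a list output exactly one level deep, which matches the behaviour reported here but is not MONAI's actual implementation.

```python
from typing import Any, Callable

# Hypothetical stand-ins, NOT MONAI code: strings play the role of tensors.
def decollate(batch: dict) -> list[dict]:
    """Split {"img": [x0, x1, ...]} into one dict per batch item."""
    return [{"img": x} for x in batch["img"]]

def crop_samples(item: dict, num_samples: int = 2) -> list[dict]:
    """Return several samples from one dict, like a MultiSampleTrait transform."""
    return [{"img": (item["img"], k)} for k in range(num_samples)]

def apply_one_level(fn: Callable[[Any], Any], data: Any) -> Any:
    """Assumed model of Compose's mapping: unroll only the outermost list."""
    if isinstance(data, list):
        return [fn(x) for x in data]
    return fn(data)

data: Any = {"img": ["a", "b", "c"]}         # a "batch" of 3 items
data = apply_one_level(decollate, data)      # list[dict]: 3 dicts
data = apply_one_level(crop_samples, data)   # list[list[dict]]: 3 lists of 2 dicts
print(type(data[0]).__name__, len(data), len(data[0]))  # list 3 2
```

Under this model, the next map transform would be called once per outer element, so it receives a `list[dict]` of two crops rather than a single dict.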

To Reproduce

Code snippet:

import torch
from monai.transforms import Compose, RandSpatialCropSamplesd, Resized, Decollated

transform = Compose(
    [
        Decollated(keys=["img"]),
        RandSpatialCropSamplesd(
            keys=["img"],
            roi_size=(1, 192, 192),
            num_samples=2,
            max_roi_size=(1, 320, 320),
            random_center=True,
            random_size=True,
        ),
        Resized(keys=["img"], spatial_size=(1, 224, 224)),
    ]
)

img = {"img": torch.rand(3, 1, 1, 512, 512)}

transform(img)

Running this fails with:

  File ".../python3.11/site-packages/monai/transforms/spatial/dictionary.py", line 846, in __call__
    d = dict(data)
        ^^^^^^^^^^
ValueError: dictionary update sequence element #0 has length 1; 2 is required
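The `ValueError` comes from Python's built-in `dict` constructor rather than from the resize logic itself. The following minimal sketch reproduces it without MONAI, assuming (as the traceback suggests) that the map transform receives one inner list of crop dicts instead of a single dict; the values are placeholder strings, not tensors.

```python
# When a map transform receives a list of dicts instead of a dict, dict(data)
# treats each element as a (key, value) pair. Iterating a one-key dict yields
# just that key, i.e. a sequence of length 1, hence the error.
samples = [{"img": "crop0"}, {"img": "crop1"}]  # what the transform receives per item
message = ""
try:
    dict(samples)
except ValueError as err:
    message = str(err)
print(message)  # dictionary update sequence element #0 has length 1; 2 is required
```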

Expected behavior

Compose should handle this case just as it handles ordinary MultiSampleTrait transforms. For now, I can work around it by inserting a custom transform that flattens the nested list.
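The workaround mentioned above flattens the nested list. A minimal sketch of such a flattening step, assuming only one extra level of nesting needs removing; `flatten_one_level` and `two_samples` are hypothetical names, and how the step is wired into a Compose pipeline is not specified in the report, so this only shows the list manipulation itself.

```python
from itertools import chain

def flatten_one_level(nested: list) -> list:
    """Turn list[list[dict]] into list[dict]; items that are not lists are kept as-is."""
    return list(chain.from_iterable(x if isinstance(x, list) else [x] for x in nested))

# Hypothetical multi-sample step (NOT MONAI code): two samples per input dict.
def two_samples(item: dict) -> list[dict]:
    return [{"img": (item["img"], k)} for k in range(2)]

items = [{"img": "a"}, {"img": "b"}, {"img": "c"}]  # output of a decollate step
nested = [two_samples(x) for x in items]            # list[list[dict]], 3 x 2
flat = flatten_one_level(nested)                    # list[dict], length 6
copies = [dict(d) for d in flat]                    # dict(...) now succeeds per element
print(len(flat))  # 6
```

Flattening keeps the samples in order (all samples of the first item, then the second, and so on), so a downstream map transform sees each dict exactly once.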

Environment

================================
Printing MONAI config...
================================
MONAI version: 1.4.0
Numpy version: 1.26.4
Pytorch version: 2.5.0+cu124
MONAI flags: HAS_EXT = False, USE_COMPILED = False, USE_META_DICT = False
MONAI rev id: 46a5272196a6c2590ca2589029eed8e4d56ff008
MONAI __file__: /hpc/mydata/<username>/anaconda/2022.05/x86_64/envs/viscy/lib/python3.11/site-packages/monai/__init__.py

Optional dependencies:
Pytorch Ignite version: NOT INSTALLED or UNKNOWN VERSION.
ITK version: NOT INSTALLED or UNKNOWN VERSION.
Nibabel version: NOT INSTALLED or UNKNOWN VERSION.
scikit-image version: 0.24.0
scipy version: 1.14.0
Pillow version: 10.4.0
Tensorboard version: 2.17.1
gdown version: NOT INSTALLED or UNKNOWN VERSION.
TorchVision version: 0.20.0+cu124
tqdm version: 4.66.5
lmdb version: NOT INSTALLED or UNKNOWN VERSION.
psutil version: 6.0.0
pandas version: 2.2.2
einops version: NOT INSTALLED or UNKNOWN VERSION.
transformers version: NOT INSTALLED or UNKNOWN VERSION.
mlflow version: NOT INSTALLED or UNKNOWN VERSION.
pynrrd version: NOT INSTALLED or UNKNOWN VERSION.
clearml version: NOT INSTALLED or UNKNOWN VERSION.

For details about installing the optional dependencies, please visit:
    https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies

================================
Printing system config...
================================
System: Linux
Linux version: Rocky Linux 8.10 (Green Obsidian)
Platform: Linux-4.18.0-553.16.1.el8_10.x86_64-x86_64-with-glibc2.28
Processor: x86_64
Machine: x86_64
Python version: 3.11.9
Process name: python
Command: ['python', '-c', 'import monai; monai.config.print_debug_info()']
Open files: [popenfile(path='/home/<username>/.vscode-server/data/logs/20241031T101853/remoteagent.log', fd=19, position=5336, mode='a', flags=33793), popenfile(path='/home/<username>/.vscode-server/data/logs/20241031T101853/ptyhost.log', fd=20, position=4686, mode='a', flags=33793)]
Num physical CPUs: 16
Num logical CPUs: 16
Num usable CPUs: 16
CPU usage (%): [8.5, 8.5, 3.5, 8.1, 3.9, 3.9, 3.2, 4.6, 5.0, 18.6, 23.9, 3.5, 3.5, 3.6, 4.3, 4.6]
CPU freq. (MHz): 2935
Load avg. in last 1, 5, 15 mins (%): [0.6, 0.5, 1.4]
Disk usage (%): 93.3
Avg. sensor temp. (Celsius): UNKNOWN for given OS
Total physical memory (GB): 503.8
Available memory (GB): 440.0
Used memory (GB): 27.3

================================
Printing GPU config...
================================
Num GPUs: 1
Has CUDA: True
CUDA version: 12.4
cuDNN enabled: True
NVIDIA_TF32_OVERRIDE: None
TORCH_ALLOW_TF32_CUBLAS_OVERRIDE: None
cuDNN version: 90100
Current device: 0
Library compiled for CUDA architectures: ['sm_50', 'sm_60', 'sm_70', 'sm_75', 'sm_80', 'sm_86', 'sm_90']
GPU 0 Name: NVIDIA A40
GPU 0 Is integrated: False
GPU 0 Is multi GPU board: False
GPU 0 Multi processor count: 84
GPU 0 Total memory (GB): 44.7
GPU 0 CUDA capability (maj.min): 8.6