AI Toolkit for Healthcare Imaging
Apache License 2.0

OneOf transform fails in DataLoader if num_workers>0 #8222

Closed: sakvaua closed this issue 1 week ago.

sakvaua commented 1 week ago

I'm new to MONAI and could be doing something wrong, but for some reason my DataLoader fails if I include a OneOf transform. It only fails when I use multiprocessing (num_workers > 0 in the DataLoader). Everything works fine if I apply all the transforms inside the OneOf without the wrapper.
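
The difference between the two setups is that MONAI's OneOf applies exactly one of its wrapped transforms per call, chosen at random (optionally weighted), instead of applying all of them in sequence. The toy re-implementation below is plain Python, not MONAI's code; the class name ToyOneOf and the lambda "transforms" are hypothetical and only illustrate the per-call choice.

```python
import random
from typing import Any, Callable, Optional, Sequence


class ToyOneOf:
    """Toy stand-in for a OneOf-style wrapper: every call picks exactly one
    transform at random (optionally weighted) and applies only that one."""

    def __init__(self, transforms: Sequence[Callable[[Any], Any]],
                 weights: Optional[Sequence[float]] = None, seed: int = 0):
        if weights is None:
            weights = [1.0] * len(transforms)
        if len(weights) != len(transforms):
            raise ValueError("need one weight per transform")
        total = float(sum(weights))
        self.transforms = list(transforms)
        self.weights = [w / total for w in weights]  # normalise to probabilities
        self.rng = random.Random(seed)
        self.last_index = -1

    def __call__(self, data: Any) -> Any:
        # A fresh draw on every call: two samples passing through the same
        # wrapper can take different branches.
        self.last_index = self.rng.choices(range(len(self.transforms)), weights=self.weights)[0]
        return self.transforms[self.last_index](data)


# Usage with three hypothetical "transforms" acting on a number.
one_of = ToyOneOf([lambda x: x + 1, lambda x: x * 10, lambda x: -x], seed=42)
outputs = [one_of(3) for _ in range(6)]
print(outputs)
```

Each output is the result of a single branch (4, 30 or -3), never of all three chained together.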

To Reproduce
Here is a minimal working example:

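import numpy as np
from monai.transforms import *
from monai.data import DataLoader, Dataset

my_num_samples = 20
train_batch_size = 1

non_random_transforms = Compose([
    EnsureChannelFirstd(keys=["image", "label"], channel_dim="no_channel"),
    Orientationd(keys=["image", "label"], axcodes="RAS"),
])

random_transforms = Compose([
    # The name of this cropping transform was lost from the post; RandCropByPosNegLabeld
    # is an assumption (it takes spatial_size and num_samples, matching the leftover arguments).
    RandCropByPosNegLabeld(
        keys=["image", "label"],
        label_key="label",
        spatial_size=[10, 10, 10],
        num_samples=my_num_samples,
    ),
    RandRotate90d(keys=["image", "label"], prob=0.5, spatial_axes=[0, 2]),
    RandAxisFlipd(keys=["image", "label"], prob=0.5),
    OneOf([
        RandRotated(keys=["image", "label"], prob=0.25, range_x=0.2, range_y=0.2, range_z=0.2),  # - too much!!!
        RandZoomd(keys=["image", "label"], prob=0.25, min_zoom=0.9, max_zoom=1.1),
        Rand3DElasticd(keys=["image", "label"], prob=0.25, sigma_range=(5, 7), magnitude_range=(50, 150)),
        RandGridDistortiond(keys=["image", "label"], prob=0.25, num_cells=5, distort_limit=(-0.03, 0.03)),
    ]),
])

# Create dummy data list (the loop body was lost; the shapes and dtypes here are assumed)
data_list = []
for i in range(5):
    data_list.append({
        "image": np.random.rand(64, 64, 64).astype(np.float32),
        "label": (np.random.rand(64, 64, 64) > 0.5).astype(np.uint8),
    })

train_ds = Dataset(data=data_list, transform=non_random_transforms)
train_ds = Dataset(data=train_ds, transform=random_transforms)

# This works
train_loader = DataLoader(train_ds, batch_size=train_batch_size, num_workers=0)
for i in range(2):
    for p in train_loader:
        pass

# Throws an exception (num_workers=4 is an assumed value; per the title, any value > 0 fails)
train_loader = DataLoader(train_ds, batch_size=train_batch_size, num_workers=4)
for i in range(10):
    for p in train_loader:
        pass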
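
With my_num_samples = 20 and train_batch_size = 1, each loaded volume yields 20 crops that MONAI's DataLoader merges into a single batch, and because Compose maps the remaining transforms over each item of a list output by default, every crop makes its own OneOf draw. If, as the maintainer's reply suggests, some branches change the dtype while others leave it unchanged (with prob=0.25 the chosen branch often does nothing at all), one batch can end up with mixed dtypes. The plain-Python sketch below models only that step; all functions are hypothetical stand-ins, and the array typecodes "B" and "f" stand in for uint8 and float32.

```python
import random
from array import array


def toy_multi_sample_crop(sample, num_samples):
    # Stand-in for a multi-sample crop: one input dict becomes a list of dicts.
    return [dict(sample) for _ in range(num_samples)]


def leave_label(d):
    # Hypothetical branch that leaves the uint8 label alone (e.g. its prob check failed).
    return d


def resample_label(d):
    # Hypothetical branch that resamples the label and hands back float32 ("f").
    out = dict(d)
    out["label"] = array("f", (float(v) for v in d["label"]))
    return out


def toy_one_of(d, rng):
    # Exactly one branch per call, drawn independently for every sample.
    return rng.choice([leave_label, resample_label])(d)


rng = random.Random(0)
source = {"label": array("B", [0, 1, 1, 0])}  # "B" ~ uint8
crops = toy_multi_sample_crop(source, num_samples=20)
# The downstream transform runs on each crop separately.
augmented = [toy_one_of(c, rng) for c in crops]
print(sorted({d["label"].typecode for d in augmented}))
```

The label values are identical across crops; only their storage type depends on which branch each crop happened to draw.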

Expected behavior
I expect the code not to fail at all, or at least to fail in both cases.
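
One plausible reason the two cases differ (our reading of PyTorch's default collate, not something confirmed in this thread): inside a worker process, the default tensor collate stacks samples into a preallocated shared-memory output whose dtype is taken from the first sample, whereas in the main process torch.stack runs without an output buffer and can promote dtypes. Under that reading, a batch whose first label is uint8 and a later one float32 fails only with workers, and only when samples arrive in that order, which would fit a failure that needs several epochs to show up. The sketch below models this with Python's array module; the function names and the ALLOWED_INTO_BUFFER table are hypothetical simplifications, not PyTorch code.

```python
from array import array

# (sample type, buffer type) pairs the toy buffer accepts: uint8 ("B") fits into
# float32 ("f"), but float32 does not fit back into uint8.
ALLOWED_INTO_BUFFER = {("B", "B"), ("f", "f"), ("B", "f")}


def stack_with_promotion(batch):
    # Main-process style: choose a common type first (float32 wins), then copy.
    typecode = "f" if any(s.typecode == "f" for s in batch) else "B"
    out = array(typecode)
    for sample in batch:
        out.extend(array(typecode, list(sample)))
    return out


def stack_into_preallocated(batch):
    # Worker style: the buffer type is fixed by batch[0] before the rest is seen.
    buffer_type = batch[0].typecode
    out = array(buffer_type)
    for sample in batch:
        if (sample.typecode, buffer_type) not in ALLOWED_INTO_BUFFER:
            raise TypeError(f"cannot write {sample.typecode!r} data into a {buffer_type!r} buffer")
        out.extend(array(buffer_type, list(sample)))
    return out


uint8_label = array("B", [0, 1])
float_label = array("f", [0.0, 1.0])

print(stack_with_promotion([uint8_label, float_label]).typecode)     # f
print(stack_into_preallocated([float_label, uint8_label]).typecode)  # f (float32 arrives first)
try:
    stack_into_preallocated([uint8_label, float_label])  # uint8 arrives first
except TypeError as err:
    print("worker-style collate failed:", err)
```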

Environment (printed with the relevant Python executable):

Printing MONAI config...
MONAI version: 1.5.dev2445
Numpy version: 1.26.4
Pytorch version: 2.5.1+cu124
MONAI flags: HAS_EXT = False, USE_COMPILED = False, USE_META_DICT = False
MONAI rev id: 2af9926d853086b264680adcf954bf3232f5ec32
MONAI __file__: /home/<username>/anaconda3/lib/python3.10/site-packages/monai/

Optional dependencies:
Pytorch Ignite version: 0.4.11
ITK version: 5.4.0
Nibabel version: 5.3.2
scikit-image version: 0.24.0
scipy version: 1.14.1
Pillow version: 10.4.0
Tensorboard version: 2.16.2
gdown version: 5.2.0
TorchVision version: 0.20.1+cu124
tqdm version: 4.66.5
lmdb version: 1.5.1
psutil version: 5.9.0
pandas version: 2.2.2
einops version: 0.6.1
transformers version: 4.40.2
mlflow version: 2.17.2
pynrrd version: 1.1.1
clearml version: 1.16.5

For details about installing the optional dependencies, please visit:

Printing system config...
System: Linux
Linux version: Linux Mint 21.1
Platform: Linux-6.8.0-40-generic-x86_64-with-glibc2.35
Processor: x86_64
Machine: x86_64
Python version: 3.10.9
Process name: python
Command: ['/home/sakvaua/anaconda3/bin/python', '-m', 'ipykernel_launcher', '-f', '/home/sakvaua/.local/share/jupyter/runtime/kernel-a66532bf-3ca8-4fab-9c45-e2fdf16371b1.json']
Open files: [popenfile(path='/home/sakvaua/.ipython/profile_default/history.sqlite', fd=47, position=0, mode='r+', flags=688130), popenfile(path='/home/sakvaua/.ipython/profile_default/history.sqlite', fd=48, position=0, mode='r+', flags=688130), popenfile(path='/home/sakvaua/anaconda3/lib/python3.10/site-packages/matplotlib/mpl-data/fonts/ttf/DejaVuSans.ttf', fd=108, position=71562, mode='r', flags=557056)]
Num physical CPUs: 22
Num logical CPUs: 44
Num usable CPUs: 44
CPU usage (%): [10.9, 9.0, 7.2, 6.4, 6.5, 7.3, 6.8, 6.8, 6.5, 5.9, 6.3, 6.1, 6.1, 6.3, 5.7, 6.2, 6.2, 5.8, 5.9, 6.2, 6.3, 4.7, 4.4, 5.9, 5.5, 6.1, 5.7, 5.4, 6.3, 6.7, 5.9, 6.8, 5.7, 6.2, 6.3, 5.4, 5.7, 5.9, 6.0, 5.7, 5.8, 5.9, 5.8, 6.2]
CPU freq. (MHz): 1
Load avg. in last 1, 5, 15 mins (%): [2.9, 2.0, 2.2]
Disk usage (%): 40.0
Avg. sensor temp. (Celsius): UNKNOWN for given OS
Total physical memory (GB): 125.7
Available memory (GB): 67.7
Used memory (GB): 45.2

Printing GPU config...
Num GPUs: 1
Has CUDA: True
CUDA version: 12.4
cuDNN enabled: True
cuDNN version: 90100
Current device: 0
Library compiled for CUDA architectures: ['sm_50', 'sm_60', 'sm_70', 'sm_75', 'sm_80', 'sm_86', 'sm_90']
GPU 0 Name: NVIDIA GeForce RTX 4090
GPU 0 Is integrated: False
GPU 0 Is multi GPU board: False
GPU 0 Multi processor count: 128
GPU 0 Total memory (GB): 23.6
GPU 0 CUDA capability (maj.min): 8.9

KumoLiu commented 1 week ago

Hi @sakvaua, based on the error message, the issue seems to come from a dtype mismatch after the random transformation. To resolve it, you can add EnsureTyped after the OneOf transform so that the data always ends up with the expected dtype:

EnsureTyped(keys=["image", "label"], dtype=[torch.float32, torch.uint8])
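
In the reproduction above, this means adding EnsureTyped as the last entry of random_transforms, right after OneOf (with import torch at the top), so every crop leaves the pipeline with one fixed dtype per key before collation. The plain-Python sketch below illustrates only that idea; ensure_typed and collate_key are hypothetical helpers, not MONAI's implementation, and the array typecodes "f" and "B" stand in for float32 and uint8.

```python
from array import array

# Target typecode per key, mirroring dtype=[torch.float32, torch.uint8] for
# keys=["image", "label"] in the suggested call ("f" ~ float32, "B" ~ uint8).
TARGET_TYPES = {"image": "f", "label": "B"}


def ensure_typed(sample, target_types):
    # Hypothetical helper: convert every listed key to its target typecode.
    out = dict(sample)
    for key, typecode in target_types.items():
        values = list(sample[key])
        if typecode == "B":
            # Truncate floats to ints so they can be stored as uint8 (0..255).
            values = [int(v) for v in values]
        out[key] = array(typecode, values)
    return out


def collate_key(batch, key):
    # Strict collate: the buffer type comes from batch[0], and array.extend
    # raises TypeError unless every sample has exactly that typecode.
    out = array(batch[0][key].typecode)
    for sample in batch:
        out.extend(sample[key])
    return out


mixed = [
    {"image": array("f", [0.1, 0.2]), "label": array("B", [0, 1])},
    {"image": array("f", [0.3, 0.4]), "label": array("f", [1.0, 0.0])},  # label came back as float32
]
fixed = [ensure_typed(s, TARGET_TYPES) for s in mixed]
labels = collate_key(fixed, "label")
print(labels.typecode, list(labels))  # B [0, 1, 1, 0]
```

Calling collate_key on the unconverted mixed list raises TypeError at the second label, while the converted list collates regardless of sample order.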

Hope it helps, thanks. (Moving this to Discussions for now.)