fepegar / torchio

Medical imaging toolkit for deep learning
http://www.torchio.org
Apache License 2.0

Issue when using PyTorch Lightning and TorchIO #1189

Open rousseau opened 3 days ago

rousseau commented 3 days ago


Bug summary

Training crashes with a TypeError when a TorchIO SubjectsDataset is used with a PyTorch Lightning Trainer.

Code for reproduction

import torch
import torch.nn as nn 
import pytorch_lightning as pl
import torchio as tio
import monai

#%% Lightning module
class meta_model(pl.LightningModule):
    def __init__(self): 
        super().__init__()  
        self.unet = monai.networks.nets.Unet(spatial_dims=3, in_channels=1, out_channels=1, channels=(8,16,32), strides=(2,2))
        self.loss = nn.MSELoss()

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=0.0001)
        return optimizer

    def training_step(self, batch, batch_idx):
        target = batch['t1'][tio.DATA]
        source = batch['t1'][tio.DATA]
        loss = self.loss(target, self.unet(source))
        return loss

#%% 
subjects = []
subject = tio.datasets.Colin27()
subjects.append(subject) 
training_set = tio.SubjectsDataset(subjects)    
training_loader = torch.utils.data.DataLoader(training_set, batch_size=1)

#%%
net = meta_model()    
trainer_reg = pl.Trainer(max_epochs=1)          
trainer_reg.fit(net, training_loader)

Actual outcome

The crash is raised in torchio/data/image.py, line 471, in _parse_path: TypeError: The path argument cannot be a dictionary. It occurs while Lightning transfers the batch to the device (see the traceback below).

Error messages

Epoch 0:   0%|          | 0/1 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "/home/rousseau/Experiments/tmp/bug_lightning_torchio.py", line 38, in <module>
    trainer_reg.fit(net, training_loader)  
  File "/home/rousseau/miniconda3/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 538, in fit
    call._call_and_handle_interrupt(
  File "/home/rousseau/miniconda3/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py", line 47, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/home/rousseau/miniconda3/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 574, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/home/rousseau/miniconda3/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 981, in _run
    results = self._run_stage()
  File "/home/rousseau/miniconda3/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1025, in _run_stage
    self.fit_loop.run()
  File "/home/rousseau/miniconda3/lib/python3.10/site-packages/pytorch_lightning/loops/fit_loop.py", line 205, in run
    self.advance()
  File "/home/rousseau/miniconda3/lib/python3.10/site-packages/pytorch_lightning/loops/fit_loop.py", line 363, in advance
    self.epoch_loop.run(self._data_fetcher)
  File "/home/rousseau/miniconda3/lib/python3.10/site-packages/pytorch_lightning/loops/training_epoch_loop.py", line 140, in run
    self.advance(data_fetcher)
  File "/home/rousseau/miniconda3/lib/python3.10/site-packages/pytorch_lightning/loops/training_epoch_loop.py", line 223, in advance
    batch = call._call_strategy_hook(trainer, "batch_to_device", batch, dataloader_idx=0)
  File "/home/rousseau/miniconda3/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py", line 319, in _call_strategy_hook
    output = fn(*args, **kwargs)
  File "/home/rousseau/miniconda3/lib/python3.10/site-packages/pytorch_lightning/strategies/strategy.py", line 277, in batch_to_device
    return model._apply_batch_transfer_handler(batch, device=device, dataloader_idx=dataloader_idx)
  File "/home/rousseau/miniconda3/lib/python3.10/site-packages/pytorch_lightning/core/module.py", line 356, in _apply_batch_transfer_handler
    batch = self._call_batch_hook("transfer_batch_to_device", batch, device, dataloader_idx)
  File "/home/rousseau/miniconda3/lib/python3.10/site-packages/pytorch_lightning/core/module.py", line 345, in _call_batch_hook
    return trainer_method(trainer, hook_name, *args)
  File "/home/rousseau/miniconda3/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py", line 167, in _call_lightning_module_hook
    output = fn(*args, **kwargs)
  File "/home/rousseau/miniconda3/lib/python3.10/site-packages/pytorch_lightning/core/hooks.py", line 611, in transfer_batch_to_device
    return move_data_to_device(batch, device)
  File "/home/rousseau/miniconda3/lib/python3.10/site-packages/lightning_fabric/utilities/apply_func.py", line 110, in move_data_to_device
    return apply_to_collection(batch, dtype=_TransferableDataType, function=batch_to)
  File "/home/rousseau/miniconda3/lib/python3.10/site-packages/lightning_utilities/core/apply_func.py", line 72, in apply_to_collection
    return _apply_to_collection_slow(
  File "/home/rousseau/miniconda3/lib/python3.10/site-packages/lightning_utilities/core/apply_func.py", line 104, in _apply_to_collection_slow
    v = _apply_to_collection_slow(
  File "/home/rousseau/miniconda3/lib/python3.10/site-packages/lightning_utilities/core/apply_func.py", line 118, in _apply_to_collection_slow
    return elem_type(OrderedDict(out))
  File "/home/rousseau/miniconda3/lib/python3.10/site-packages/torchio/data/image.py", line 858, in __init__
    super().__init__(*args, **kwargs)
  File "/home/rousseau/miniconda3/lib/python3.10/site-packages/torchio/data/image.py", line 179, in __init__
    self.path = self._parse_path(path)
  File "/home/rousseau/miniconda3/lib/python3.10/site-packages/torchio/data/image.py", line 471, in _parse_path
    raise TypeError('The path argument cannot be a dictionary')
TypeError: The path argument cannot be a dictionary
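The last frames of the traceback suggest the mechanism: while moving the batch to the device, Lightning's apply_to_collection rebuilds every mapping it finds with elem_type(OrderedDict(out)). Here elem_type is a TorchIO image class (a dict subclass whose constructor treats its first positional argument as a path), so the rebuild passes a dictionary as the path and _parse_path rejects it. The simplified, stdlib-only sketch below reproduces that pattern; FakeImage and rebuild_collection are hypothetical stand-ins, not the real TorchIO or Lightning code.

```python
from collections import OrderedDict
from collections.abc import Mapping


class FakeImage(dict):
    """Hypothetical stand-in for a TorchIO image: a dict subclass whose
    first positional argument is interpreted as a file path."""

    def __init__(self, path=None, **kwargs):
        if isinstance(path, dict):
            # Mirrors the check that raises in torchio/data/image.py
            raise TypeError('The path argument cannot be a dictionary')
        super().__init__(**kwargs)
        self.path = path


def rebuild_collection(data, function):
    """Simplified stand-in for a recursive apply_to_collection: applies
    `function` to leaves and rebuilds each mapping with its own type."""
    if isinstance(data, Mapping):
        out = [(key, rebuild_collection(value, function)) for key, value in data.items()]
        elem_type = type(data)
        # Same pattern as in the traceback: elem_type(OrderedDict(out))
        return elem_type(OrderedDict(out))
    return function(data)


batch = {'t1': FakeImage(data=[1.0, 2.0])}
try:
    rebuild_collection(batch, lambda x: x)
except TypeError as err:
    print(err)  # The path argument cannot be a dictionary
```

Plain dictionaries survive the rebuild because dict(OrderedDict(...)) is a valid call; only mapping subclasses with a different constructor signature break.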

Expected outcome

This is dummy code; training is expected to run without errors.

System info

Platform:   Linux-6.8.0-44-generic-x86_64-with-glibc2.39
TorchIO:    0.19.9
PyTorch:    2.4.1+cu124
SimpleITK:  2.4.0 (ITK 5.4)
NumPy:      1.26.3
Python:     3.10.9 (main, Jan 11 2023, 15:21:40) [GCC 11.2.0]

romainVala commented 3 days ago

Hi Francois

Interesting: this is a good example of a side effect related to issue #1179.

The first solution is to downgrade PyTorch to a version earlier than 2.3.
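As a rough guard based on the version boundary mentioned above, the hypothetical helper below checks whether a PyTorch version string is 2.3 or newer, where the problem is reported to appear. It is only a sketch: it parses the leading major.minor numbers and ignores local suffixes such as +cu124.

```python
import re


def torch_version_needs_workaround(version: str) -> bool:
    """Return True if `version` (e.g. '2.4.1+cu124') is 2.3 or newer.

    Hypothetical helper: only the leading major.minor numbers are parsed;
    pre-release and local suffixes are ignored.
    """
    match = re.match(r'(\d+)\.(\d+)', version)
    if match is None:
        raise ValueError(f'Unrecognized version string: {version!r}')
    major, minor = int(match.group(1)), int(match.group(2))
    return (major, minor) >= (2, 3)


print(torch_version_needs_workaround('2.4.1+cu124'))  # True
print(torch_version_needs_workaround('2.2.2'))        # False
```

In a training script you would call it with torch.__version__ and fall back to the workaround when it returns True.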

The second solution is to use the "hack" proposed here, https://github.com/fepegar/torchio/issues/1179#issuecomment-2254715480, and then adapt your code.

Instead of using tio.datasets.Colin27() and tio.SubjectsDataset(subjects) directly, you must use the redefined Subject and SubjectsDataset classes (as proposed by c-winder). I did:

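# Subject and SubjectsDataset are the redefined classes from #1179, not tio.Subject / tio.SubjectsDataset
suj = Subject(t1=subject.t1, head=subject.head)
training_set = SubjectsDataset([suj])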

Also, instead of

    def training_step(self, batch, batch_idx):
        target = batch['t1'][tio.DATA]
        source = batch['t1'][tio.DATA]

use

    def training_step(self, batch, batch_idx):
        target = batch['t1']
        source = batch['t1']

I just realized that c-winder also proposed this (https://github.com/fepegar/torchio/issues/1179#issuecomment-2254691162): if you use his SubjectDataLoader instead of torch.utils.data.DataLoader, it will work too (with the same change in the training_step function).

Note that you will then get another error because the U-Net does not handle odd shapes (spatial dimensions that are not divisible by its total downsampling factor). This can be solved with CropOrPad, for instance:

# Every target dimension is a multiple of 4, matching the two stride-2 levels of the U-Net
tc = tio.CropOrPad(target_shape=[184, 216, 184])
suj = tc(suj)
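The U-Net in the reproduction uses strides=(2, 2), so each spatial dimension is downsampled twice. A reasonable rule of thumb (assumed here, not quoted from the MONAI documentation) is that every dimension should be divisible by the product of the strides, 2 × 2 = 4. The hypothetical helper below checks a shape against that rule: the target shape [184, 216, 184] passes, while a shape containing 181 does not.

```python
import math


def fits_unet(shape, strides):
    """Check that every spatial size is divisible by the product of the
    U-Net strides (assumed requirement for matching skip connections)."""
    factor = math.prod(strides)
    return all(size % factor == 0 for size in shape)


print(fits_unet([184, 216, 184], strides=(2, 2)))  # True
print(fits_unet([181, 217, 181], strides=(2, 2)))  # False
```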
fepegar commented 2 days ago

Thank you both. You can also take a look at EnsureShapeMultiple for the U-Net error.
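EnsureShapeMultiple avoids hard-coding a target shape such as [184, 216, 184]: it pads (or crops) each spatial dimension to a multiple of a given number. The arithmetic it is meant to achieve can be sketched as below; target_shape is a hypothetical helper, and the exact rounding TorchIO applies should be checked in its documentation.

```python
def target_shape(shape, multiple, method='pad'):
    """Round each dimension up (pad) or down (crop) to a multiple."""
    if method == 'pad':
        # Ceiling division, then scale back up
        return [-(-size // multiple) * multiple for size in shape]
    if method == 'crop':
        # Floor division, then scale back up
        return [(size // multiple) * multiple for size in shape]
    raise ValueError("method must be 'pad' or 'crop'")


print(target_shape([181, 217, 181], 4))                 # [184, 220, 184]
print(target_shape([181, 217, 181], 4, method='crop'))  # [180, 216, 180]
```

With TorchIO itself, the corresponding transform for this U-Net would presumably be tio.EnsureShapeMultiple(4), applied to each subject or passed as the transform argument of the dataset.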