fepegar / torchio

Medical imaging toolkit for deep learning
https://torchio.org
Apache License 2.0

Cannot use Patch Queue together with multi-GPU training (via pytorch-lightning) #890

Open jxchen01 opened 2 years ago

jxchen01 commented 2 years ago

Is there an existing issue for this?

Bug summary

I am trying to use PyTorch Lightning to handle model training (it makes multi-GPU training and other training tricks easy to use) while using TorchIO for data loading, but I always get errors.

Code for reproduction

# define a new datamodule with a patch queue
from typing import Optional

import pytorch_lightning as pl
import torchio as tio
from torch.utils.data import DataLoader, random_split


class MyDataModule(pl.LightningDataModule):
    def __init__(self, data_cfg):
        super().__init__()
        self.config = data_cfg
        self.preproc = XXXX  # pseudocode: preprocessing transform placeholder
        self.transform = YYYY  # pseudocode: training transform placeholder
        self.subjects = []

    def prepare_data(self):
        # build the list of subjects
        for ds in self.config["data_list"]:
            subject = tio.Subject(
                source=tio.ScalarImage(ds["source_fn"]),
                target=tio.LabelMap(ds["target_fn"]),
            )
            self.subjects.append(subject)

    def setup(self, stage: Optional[str] = None):
        num_subjects = len(self.subjects)
        num_val_subjects = int(round(num_subjects * self.config["train_val_ratio"]))
        num_train_subjects = num_subjects - num_val_subjects
        splits = num_train_subjects, num_val_subjects
        train_subjects, val_subjects = random_split(self.subjects, splits)
        self.val_set = tio.SubjectsDataset(val_subjects, transform=self.preproc)
        train_set = tio.SubjectsDataset(train_subjects, transform=self.transform)

        train_sampler = tio.data.UniformSampler(self.config["patch_size"])
        self.train_set = tio.Queue(
            train_set,
            sampler=train_sampler,
            num_workers=12,
            max_length=600,
            samples_per_volume=6,
        )

    def train_dataloader(self):
        return DataLoader(self.train_set, shuffle=True)


# define a model and trainer
dm = MyDataModule(config)
model = Model()
# trainer definition not shown in the original snippet; multi-GPU spawn settings assumed from the traceback
trainer = pl.Trainer(accelerator="gpu", devices=2)
trainer.fit(model, datamodule=dm)

Actual outcome

The code above is just pseudocode; the actual run fails with the error messages below.

Error messages

Traceback (most recent call last):
  File "/mnt/eternus/users/Jianxu/projects/mmv_im2im/mmv_im2im/bin/run_im2im.py", line 86, in main
    exe.run_training()
  File "/mnt/eternus/users/Jianxu/projects/mmv_im2im/mmv_im2im/proj_trainer.py", line 75, in run_training
    trainer.fit(model=self.model, datamodule=self.data)
  File "/mnt/data/ISAS.DE/jianxu.chen/anaconda3/envs/im2im/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 768, in fit
    self._call_and_handle_interrupt(
  File "/mnt/data/ISAS.DE/jianxu.chen/anaconda3/envs/im2im/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 719, in _call_and_handle_interrupt
    return self.strategy.launcher.launch(trainer_fn, *args, trainer=self, **kwargs)
  File "/mnt/data/ISAS.DE/jianxu.chen/anaconda3/envs/im2im/lib/python3.8/site-packages/pytorch_lightning/strategies/launchers/spawn.py", line 78, in launch
    mp.spawn(
  File "/mnt/data/ISAS.DE/jianxu.chen/anaconda3/envs/im2im/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 230, in spawn
    return start_processes(fn, args, nprocs, join, daemon, start_method='spawn')
  File "/mnt/data/ISAS.DE/jianxu.chen/anaconda3/envs/im2im/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 179, in start_processes
    process.start()
  File "/mnt/data/ISAS.DE/jianxu.chen/anaconda3/envs/im2im/lib/python3.8/multiprocessing/process.py", line 121, in start
    self._popen = self._Popen(self)
  File "/mnt/data/ISAS.DE/jianxu.chen/anaconda3/envs/im2im/lib/python3.8/multiprocessing/context.py", line 284, in _Popen
    return Popen(process_obj)
  File "/mnt/data/ISAS.DE/jianxu.chen/anaconda3/envs/im2im/lib/python3.8/multiprocessing/popen_spawn_posix.py", line 32, in __init__
    super().__init__(process_obj)
  File "/mnt/data/ISAS.DE/jianxu.chen/anaconda3/envs/im2im/lib/python3.8/multiprocessing/popen_fork.py", line 19, in __init__
    self._launch(process_obj)
  File "/mnt/data/ISAS.DE/jianxu.chen/anaconda3/envs/im2im/lib/python3.8/multiprocessing/popen_spawn_posix.py", line 47, in _launch
    reduction.dump(process_obj, fp)
  File "/mnt/data/ISAS.DE/jianxu.chen/anaconda3/envs/im2im/lib/python3.8/multiprocessing/reduction.py", line 60, in dump
    ForkingPickler(file, protocol).dump(obj)
  File "/mnt/data/ISAS.DE/jianxu.chen/anaconda3/envs/im2im/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 543, in __getstate__
    raise NotImplementedError("{} cannot be pickled", self.__class__.__name__)
NotImplementedError: ('{} cannot be pickled', '_MultiProcessingDataLoaderIter')

[ERROR:  98 2022-05-29 23:06:32,892] =============================================
[ERROR:  99 2022-05-29 23:06:32,892] 

('{} cannot be pickled', '_MultiProcessingDataLoaderIter')

Expected outcome

I hope to use the TorchIO patch queue as a dataloader in a multi-GPU training script.

System info

Platform:   Linux-5.4.0-90-generic-x86_64-with-glibc2.17
TorchIO:    0.18.76
PyTorch:    1.8.2
SimpleITK:  2.1.1 (ITK 5.2)
NumPy:      1.21.2
Python:     3.8.12 (default, Oct 12 2021, 13:49:34) 
[GCC 7.5.0]
snipdome commented 2 years ago

More than a bug, this should be considered a request for enhancement. It would require reviewing the torchio code and making the queue picklable, so that it can be sent to the subprocesses spawned by pytorch-lightning.

Have you tried using another multi-GPU strategy with pytorch-lightning, such as "ddp" or "deepspeed"? I was able to start training in both cases (see the sketch below). The only drawback is that the dataloaders/queues are duplicated in each of the processes created this way (one per GPU).
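
For reference, a minimal sketch of what switching to regular DDP might look like, assuming pytorch-lightning 1.6-style Trainer arguments and the model/datamodule names from the snippet above:

import pytorch_lightning as pl

# Regular DDP launches one full Python process per GPU instead of pickling the
# datamodule into spawned subprocesses, which is the step that fails above.
trainer = pl.Trainer(
    accelerator="gpu",
    devices=2,        # one tio.Queue (and its workers) is created per process
    strategy="ddp",   # or "deepspeed", if that package is installed
    max_epochs=100,
)
trainer.fit(model, datamodule=dm)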

EddyJens commented 2 years ago

I'm having the same issue when using:

sampler = tio.data.LabelSampler(
    patch_size=96,
    label_name="Label",
    label_probabilities={0: 0.4, 1: 0.6},
)

and

train_patches_queue = tio.Queue(
    training_set,
    max_length=40,
    samples_per_volume=5,
    sampler=sampler,
    num_workers=8,
)

val_patches_queue = tio.Queue(
    validation_set,
    max_length=40,
    samples_per_volume=5,
    sampler=sampler,
    num_workers=8,
)

in my DataLoaders:

batch_size = 2

train_loader = torch.utils.data.DataLoader(train_patches_queue, batch_size=batch_size, num_workers=8)
val_loader = torch.utils.data.DataLoader(val_patches_queue, batch_size=batch_size, num_workers=8)

I'm using a single GPU with 6 GB of memory.
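
As a side note on the pattern above: the tio.Queue documentation example passes num_workers=0 to the DataLoader that pops patches from the queue, since the queue already loads and transforms volumes in its own worker processes. A minimal sketch of that pattern, reusing the names from the snippet above:

import torch

# The queue's own workers (num_workers=8 above) handle volume loading;
# the DataLoader that draws patches from the queue should not spawn more workers.
train_loader = torch.utils.data.DataLoader(
    train_patches_queue,
    batch_size=2,
    num_workers=0,  # the queue already uses multiprocessing internally
)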

fepegar commented 8 months ago

The tutorial works fine with Lightning. If anyone is having this issue, can you please share a minimal reproducible example?