facebookresearch / fastMRI

A large-scale dataset of both raw MRI measurements and clinical MRI images.
https://fastmri.org
MIT License

Training abruptly crashes on single GPU #318

Open pranavsinghps1 opened 1 year ago

pranavsinghps1 commented 1 year ago

While working with the knee dataset, training a VarNet built with the PyTorch Lightning library and using the FastMriDataModule data loaders, I observed that training is unstable and crashes fairly often. I looked for similar issues in this repo but couldn't find any. I then checked PyTorch's forum and found that this kind of crash occurs when the data loader doesn't work well with multiprocessing (https://github.com/pytorch/pytorch/issues/8976). The recommendation there was to set num_workers=0, which stabilized my training for a while, but after some time it crashed as well.

lightning    1.8.6
torch          2.0.1

  File "/ext3/miniconda3/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1132, in _try_get_data
    data = self._data_queue.get(timeout=timeout)
  File "/ext3/miniconda3/lib/python3.10/multiprocessing/queues.py", line 122, in get
    return _ForkingPickler.loads(res)
  File "/ext3/miniconda3/lib/python3.10/site-packages/torch/multiprocessing/reductions.py", line 307, in rebuild_storage_fd
    fd = df.detach()
  File "/ext3/miniconda3/lib/python3.10/multiprocessing/resource_sharer.py", line 57, in detach
    with _resource_sharer.get_connection(self._id) as conn:
  File "/ext3/miniconda3/lib/python3.10/multiprocessing/resource_sharer.py", line 86, in get_connection
    c = Client(address, authkey=process.current_process().authkey)
  File "/ext3/miniconda3/lib/python3.10/multiprocessing/connection.py", line 508, in Client
    answer_challenge(c, authkey)
  File "/ext3/miniconda3/lib/python3.10/multiprocessing/connection.py", line 752, in answer_challenge
    message = connection.recv_bytes(256) # reject large message
  File "/ext3/miniconda3/lib/python3.10/multiprocessing/connection.py", line 216, in recv_bytes
    buf = self._recv_bytes(maxlength)
  File "/ext3/miniconda3/lib/python3.10/multiprocessing/connection.py", line 414, in _recv_bytes
    buf = self._recv(4)
  File "/ext3/miniconda3/lib/python3.10/multiprocessing/connection.py", line 379, in _recv
    chunk = read(handle, remaining)
ConnectionResetError: [Errno 104] Connection reset by peer

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/scratch/ps4364/fmri2020/varnet_l1_2/unet_knee_sc.py", line 192, in <module>
    run_cli()
  File "/scratch/ps4364/fmri2020/varnet_l1_2/unet_knee_sc.py", line 188, in run_cli
    cli_main(args)
  File "/scratch/ps4364/fmri2020/varnet_l1_2/unet_knee_sc.py", line 72, in cli_main
    trainer.fit(model, datamodule=data_module)
  File "/ext3/miniconda3/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 603, in fit
    call._call_and_handle_interrupt(
  File "/ext3/miniconda3/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py", line 38, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/ext3/miniconda3/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 645, in _fit_impl
    self._run(model, ckpt_path=self.ckpt_path)
  File "/ext3/miniconda3/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1098, in _run
    results = self._run_stage()
  File "/ext3/miniconda3/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1177, in _run_stage
    self._run_train()
  File "/ext3/miniconda3/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1200, in _run_train
    self.fit_loop.run()
  File "/ext3/miniconda3/lib/python3.10/site-packages/pytorch_lightning/loops/loop.py", line 199, in run
    self.advance(*args, **kwargs)
  File "/ext3/miniconda3/lib/python3.10/site-packages/pytorch_lightning/loops/fit_loop.py", line 267, in advance
    self._outputs = self.epoch_loop.run(self._data_fetcher)
  File "/ext3/miniconda3/lib/python3.10/site-packages/pytorch_lightning/loops/loop.py", line 199, in run
    self.advance(*args, **kwargs)
  File "/ext3/miniconda3/lib/python3.10/site-packages/pytorch_lightning/loops/epoch/training_epoch_loop.py", line 188, in advance
    batch = next(data_fetcher)
  File "/ext3/miniconda3/lib/python3.10/site-packages/pytorch_lightning/utilities/fetching.py", line 184, in __next__
    return self.fetching_function()
  File "/ext3/miniconda3/lib/python3.10/site-packages/pytorch_lightning/utilities/fetching.py", line 265, in fetching_function
    self._fetch_next_batch(self.dataloader_iter)
  File "/ext3/miniconda3/lib/python3.10/site-packages/pytorch_lightning/utilities/fetching.py", line 280, in _fetch_next_batch
    batch = next(iterator)
  File "/ext3/miniconda3/lib/python3.10/site-packages/pytorch_lightning/trainer/supporters.py", line 568, in __next__
    return self.request_next_batch(self.loader_iters)
  File "/ext3/miniconda3/lib/python3.10/site-packages/pytorch_lightning/trainer/supporters.py", line 580, in request_next_batch
    return apply_to_collection(loader_iters, Iterator, next)
  File "/ext3/miniconda3/lib/python3.10/site-packages/lightning_utilities/core/apply_func.py", line 51, in apply_to_collection
    return function(data, *args, **kwargs)
  File "/ext3/miniconda3/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 633, in __next__
    data = self._next_data()
  File "/ext3/miniconda3/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1328, in _next_data
    idx, data = self._get_data()
  File "/ext3/miniconda3/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1294, in _get_data
    success, data = self._try_get_data()
  File "/ext3/miniconda3/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1132, in _try_get_data
    data = self._data_queue.get(timeout=timeout)
  File "/ext3/miniconda3/lib/python3.10/site-packages/torch/utils/data/_utils/signal_handling.py", line 66, in handler
    _error_if_any_worker_fails()
RuntimeError: DataLoader worker (pid 3489789) is killed by signal: Killed.

mmuckley commented 1 year ago

Hello @pranavsinghps1, this is a confusing error. I don't see a single line in the trace that mentions fastMRI. Are you sure there isn't an issue with your install?

Also, we don't actually test VarNet with the single-coil data; it's really meant for multicoil data with a batch size of 1. Is there a reference showing that VarNet works for single-coil data that you're trying to reproduce?

pranavsinghps1 commented 1 year ago

I see, thank you for your prompt response. I will try to align my environment with the requirements listed here (https://github.com/facebookresearch/fastMRI/blob/main/setup.cfg).

As for using VarNet for single-coil reconstruction: I did see that [1] states VarNet is used exclusively for multicoil reconstruction, while U-Net is used for both. Is there a rationale for this? I have been trying to figure that out. For my VarNet, I have removed the sensitivity network and am just using a vanilla VarNet with a ResNet-18 backbone.

[1] Sriram, Anuroop, et al. "End-to-end variational networks for accelerated MRI reconstruction." Medical Image Computing and Computer Assisted Intervention–MICCAI 2020: 23rd International Conference, Lima, Peru, October 4–8, 2020, Proceedings, Part II 23. Springer International Publishing, 2020.

mmuckley commented 1 year ago

Hello @pranavsinghps1, the main innovation of that paper is the end-to-end aspect, where the model estimates both the sensitivity maps and the final image. In non-E2E VarNets, the sensitivity maps are precomputed with a separate method (such as ESPIRiT), so those approaches are not end-to-end.

However, in the single-coil case there are no sensitivities, so you just have a regular VarNet.
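A small NumPy sketch may help show the distinction. It implements one unrolled cascade of the general form k ← k − η·M(k − k̃) + E(CNN(R(k))). Here E ("expand") maps an image to per-coil k-space using sensitivity maps, and R ("reduce") combines the coils back into one image. Everything in it is illustrative. The function names are hypothetical, and the learned CNN is replaced by a caller-supplied `regularizer` callable. Multicoil maps are assumed to be normalized so that the sum of |S_i|² over coils is 1 at each pixel. In the single-coil case the map is simply all ones, so E and R reduce to a plain FFT and inverse FFT, which gives the regular VarNet described above.

```python
import numpy as np


def fft2c(x):
    # Centered, orthonormal 2D FFT over the last two axes.
    x = np.fft.ifftshift(x, axes=(-2, -1))
    return np.fft.fftshift(np.fft.fft2(x, norm="ortho"), axes=(-2, -1))


def ifft2c(k):
    # Inverse of fft2c.
    k = np.fft.ifftshift(k, axes=(-2, -1))
    return np.fft.fftshift(np.fft.ifft2(k, norm="ortho"), axes=(-2, -1))


def sens_expand(image, sens_maps):
    # E: image (H, W) -> per-coil k-space (C, H, W).
    return fft2c(sens_maps * image[None])


def sens_reduce(kspace, sens_maps):
    # R: per-coil k-space (C, H, W) -> coil-combined image (H, W).
    return np.sum(np.conj(sens_maps) * ifft2c(kspace), axis=0)


def varnet_cascade(k, k_measured, mask, sens_maps, eta, regularizer):
    # Soft data consistency: only changes sampled k-space locations.
    soft_dc = eta * mask * (k - k_measured)
    # `regularizer` is a hypothetical stand-in for the learned image-domain CNN.
    model_term = sens_expand(regularizer(sens_reduce(k, sens_maps)), sens_maps)
    return k - soft_dc + model_term


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    h, w = 32, 32
    image = rng.standard_normal((h, w))

    # Single coil: one all-ones sensitivity map, so E and R are a plain FFT/IFFT.
    maps = np.ones((1, h, w))
    mask = np.zeros((1, 1, w), dtype=bool)
    mask[..., ::4] = True  # illustrative mask: keep every 4th column

    k_measured = sens_expand(image, maps) * mask
    k = rng.standard_normal((1, h, w)) + 1j * rng.standard_normal((1, h, w))
    k_next = varnet_cascade(k, k_measured, mask, maps, eta=1.0,
                            regularizer=np.zeros_like)
    sampled = np.broadcast_to(mask, k.shape)
    # With eta=1 and a zero regularizer, sampled entries snap to the measurements.
    print(np.allclose(k_next[sampled], k_measured[sampled]))  # True

    # Multicoil: random maps normalized so that sum_i |S_i|^2 == 1 per pixel.
    s = rng.standard_normal((8, h, w)) + 1j * rng.standard_normal((8, h, w))
    s /= np.sqrt(np.sum(np.abs(s) ** 2, axis=0, keepdims=True))
    print(np.allclose(sens_reduce(sens_expand(image, s), s), image))  # True
```

In the multicoil case, the maps `s` would come from ESPIRiT (non-E2E) or from a learned sensitivity network (E2E). In the single-coil case there is nothing to estimate.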

We never prioritized the development of a single-coil VarNet because in the real world, all MRI scanners are multicoil. There are enormous benefits of multicoil over single coil in terms of SNR and image quality. The single-coil data is only a sort of toy setting for interested people initially getting into the area, but only works done on the multi-coil data are likely to have any impact on real-world scanners.

pranavsinghps1 commented 1 year ago

Thank you @mmuckley for the detailed information. I have one question: why is the multicoil model trained with a batch size of 1?

Update on the issue: rewriting the dataloaders using SliceDataset solved the issue.

mmuckley commented 1 year ago

Hello @pranavsinghps1, the main reason is that many of the multicoil volumes have different matrix sizes. Since the VarNet performs data consistency on the raw k-space data, there is no way to do simple batching. In the end, we made the VarNet large enough to use all of one GPU's memory, and we found that a batch size of 1 with a large model was the most effective training strategy.
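To show why simple batching breaks down, here is a small NumPy sketch with two hypothetical multicoil k-space arrays whose matrix sizes differ. The shapes and helper names are made up for illustration. Stacking the arrays into one batch tensor fails, whereas treating each volume as its own batch of size 1 works without altering the raw data.

```python
import numpy as np


def stack_batch(items):
    # Naive batching: np.stack requires every item to have the same shape.
    return np.stack(items, axis=0)


def batches_of_one(items):
    # Batch size 1: add a leading batch axis to each item individually.
    for kspace in items:
        yield kspace[None]


if __name__ == "__main__":
    # Hypothetical (coils, height, width) k-space arrays with different widths.
    dataset = [
        np.zeros((4, 64, 32), dtype=np.complex64),
        np.zeros((4, 64, 36), dtype=np.complex64),
    ]
    try:
        stack_batch(dataset)
    except ValueError as err:
        print("cannot stack:", err)
    for batch in batches_of_one(dataset):
        print(batch.shape)  # (1, 4, 64, 32), then (1, 4, 64, 36)
```

Each volume keeps its native size, so data consistency can be applied directly against the measured k-space.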

As for the issue, could you post more details about your solution? If there is no problem with the core repository, please close the issue.