sct-pipeline / contrast-agnostic-softseg-spinalcord

Contrast-agnostic spinal cord segmentation project with softseg
MIT License

Memory issue with training on Monai with large datasets #59

Open louisfb01 opened 1 year ago

louisfb01 commented 1 year ago

Hey all! As discussed with @naga-karthik, I've been having some issues with the "aggregated training".

I am running into memory issues that none of the fixes I've tried seem to resolve.

I am using Naga's training script and have tried pretty much every solution I could find, but I always get RuntimeError('Pin memory thread exited unexpectedly'). This can be avoided by setting num_workers=0 in these lines, but that makes training very slow (15+ minutes per epoch).

I also tried the exact same configuration as Naga. The only difference is the number of images. Moreover, training works when I remove most images from the dataset.json file (i.e., with a much smaller set).

I am still investigating this issue...

naga-karthik commented 1 year ago

This is indeed strange. Can you also post the exact arguments that you're using with main.py?

louisfb01 commented 1 year ago

Pretty much no arguments: just the correct paths to the data, the default UNet, 100 epochs, and your hard-coded values.

jcohenadad commented 1 year ago

@louisfb01 can you please list in this issue thread the various discussions on this topic on MONAI GitHub, Slack, forums, etc.

louisfb01 commented 1 year ago

The source seems to be a memory leak caused by storing too much data from a DataLoader (from this). However, in that case it happens when running the test set, whereas here it happens during training.

Solutions tried:

- Adding torch.multiprocessing.set_sharing_strategy('file_system') to the training script, as in this issue, did not work.

The current workaround is to set num_workers=0, as in this issue, this one, and others.

New temporary fix:

jcohenadad commented 1 year ago

@louisfb01 can you please

louisfb01 commented 1 year ago

Here is more information about the environment (Python 3.9.17) and the output I get from running main.py (training with PyTorch, PyTorch Lightning, and MONAI).

STDOUT:

```
(monai_training) lobouz@romane:~/github/contrast-agnostic-softseg-spinalcord/monai$ CUDA_VISIBLE_DEVICES=3 python main.py -m unet -nspv 4 -ncv 1 -initf 8 -bs 4 -lr 1e-3 -cve 4 -stp -epb
Global seed set to 42
2023-07-18 10:01:13.007 | INFO | __main__:main:472 - Training on fold 1 out of 1 folds!
/home/GRAMES.POLYMTL.CA/lobouz/miniconda3/envs/monai_training/lib/python3.9/site-packages/pytorch_lightning/utilities/parsing.py:196: UserWarning: Attribute 'loss_function' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['loss_function'])`.
  rank_zero_warn(
wandb: W&B API key is configured. Use `wandb login --relogin` to force relogin
wandb: Tracking run with wandb version 0.15.5
wandb: Run data is saved locally in /home/GRAMES.POLYMTL.CA/lobouz/contrast-agnostic/saved_models/wandb/run-20230718_100119-9xlhu2uf
wandb: Run `wandb offline` to turn off syncing.
wandb: Syncing run unet_nf=8_nrs=2_lr=0.001_20230718-1001
wandb: ⭐️ View project at https://wandb.ai/whats_ai/contrast-agnostic
wandb: 🚀 View run at https://wandb.ai/whats_ai/contrast-agnostic/runs/9xlhu2uf
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/home/GRAMES.POLYMTL.CA/lobouz/miniconda3/envs/monai_training/lib/python3.9/site-packages/monai/utils/deprecate_utils.py:321: FutureWarning: monai.transforms.io.dictionary LoadImaged.__init__:image_only: Current default value of argument `image_only=False` has been deprecated since version 1.1. It will be changed to `image_only=True` in version 1.3.
  warn_deprecated(argname, msg, warning_category)
Loading dataset: 100%|██████████| 476/476 [01:45<00:00, 4.53it/s]
Loading dataset: 100%|██████████| 119/119 [00:37<00:00, 3.20it/s]
Loading dataset: 100%|██████████| 29/29 [00:10<00:00, 2.78it/s]
You are using a CUDA device ('NVIDIA RTX A6000') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [3]

  | Name          | Type         | Params
-----------------------------------------------
0 | net           | UNet         | 1.2 M
1 | loss_function | SoftDiceLoss | 0
-----------------------------------------------
1.2 M     Trainable params
0         Non-trainable params
1.2 M     Total params
4.809     Total estimated model params size (MB)
Sanity Checking DataLoader 0: 100%|██████████| 2/2 [00:02<00:00, 1.34s/it]
Current epoch: 0
Average Soft Dice (VAL): 0.0071
Average Hard Dice (VAL): 0.0020
Best Average Soft Dice: 0.0071 at Epoch: 0
----------------------------------------------------
Epoch 0:   3%|████▌     | 14/477 [00:23<12:59, 1.68s/it, v_num=u2uf]
Exception in thread Thread-28:
Traceback (most recent call last):
  File "/home/GRAMES.POLYMTL.CA/lobouz/miniconda3/envs/monai_training/lib/python3.9/threading.py", line 980, in _bootstrap_inner
    self.run()
  File "/home/GRAMES.POLYMTL.CA/lobouz/miniconda3/envs/monai_training/lib/python3.9/threading.py", line 917, in run
    self._target(*self._args, **self._kwargs)
  File "/home/GRAMES.POLYMTL.CA/lobouz/miniconda3/envs/monai_training/lib/python3.9/site-packages/torch/utils/data/_utils/pin_memory.py", line 51, in _pin_memory_loop
    do_one_step()
  File "/home/GRAMES.POLYMTL.CA/lobouz/miniconda3/envs/monai_training/lib/python3.9/site-packages/torch/utils/data/_utils/pin_memory.py", line 28, in do_one_step
    r = in_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
  File "/home/GRAMES.POLYMTL.CA/lobouz/miniconda3/envs/monai_training/lib/python3.9/multiprocessing/queues.py", line 122, in get
    return _ForkingPickler.loads(res)
  File "/home/GRAMES.POLYMTL.CA/lobouz/miniconda3/envs/monai_training/lib/python3.9/site-packages/torch/multiprocessing/reductions.py", line 307, in rebuild_storage_fd
    fd = df.detach()
  File "/home/GRAMES.POLYMTL.CA/lobouz/miniconda3/envs/monai_training/lib/python3.9/multiprocessing/resource_sharer.py", line 58, in detach
    return reduction.recv_handle(conn)
  File "/home/GRAMES.POLYMTL.CA/lobouz/miniconda3/envs/monai_training/lib/python3.9/multiprocessing/reduction.py", line 189, in recv_handle
    return recvfds(s, 1)[0]
  File "/home/GRAMES.POLYMTL.CA/lobouz/miniconda3/envs/monai_training/lib/python3.9/multiprocessing/reduction.py", line 164, in recvfds
    raise RuntimeError('received %d items of ancdata' %
RuntimeError: received 0 items of ancdata
Traceback (most recent call last):
  File "/home/GRAMES.POLYMTL.CA/lobouz/github/contrast-agnostic-softseg-spinalcord/monai/main.py", line 605, in <module>
    main(args)
  File "/home/GRAMES.POLYMTL.CA/lobouz/github/contrast-agnostic-softseg-spinalcord/monai/main.py", line 516, in main
    trainer.fit(pl_model)
  File "/home/GRAMES.POLYMTL.CA/lobouz/miniconda3/envs/monai_training/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 531, in fit
    call._call_and_handle_interrupt(
  File "/home/GRAMES.POLYMTL.CA/lobouz/miniconda3/envs/monai_training/lib/python3.9/site-packages/pytorch_lightning/trainer/call.py", line 42, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/home/GRAMES.POLYMTL.CA/lobouz/miniconda3/envs/monai_training/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 570, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/home/GRAMES.POLYMTL.CA/lobouz/miniconda3/envs/monai_training/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 975, in _run
    results = self._run_stage()
  File "/home/GRAMES.POLYMTL.CA/lobouz/miniconda3/envs/monai_training/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1018, in _run_stage
    self.fit_loop.run()
  File "/home/GRAMES.POLYMTL.CA/lobouz/miniconda3/envs/monai_training/lib/python3.9/site-packages/pytorch_lightning/loops/fit_loop.py", line 201, in run
    self.advance()
  File "/home/GRAMES.POLYMTL.CA/lobouz/miniconda3/envs/monai_training/lib/python3.9/site-packages/pytorch_lightning/loops/fit_loop.py", line 354, in advance
    self.epoch_loop.run(self._data_fetcher)
  File "/home/GRAMES.POLYMTL.CA/lobouz/miniconda3/envs/monai_training/lib/python3.9/site-packages/pytorch_lightning/loops/training_epoch_loop.py", line 133, in run
    self.advance(data_fetcher)
  File "/home/GRAMES.POLYMTL.CA/lobouz/miniconda3/envs/monai_training/lib/python3.9/site-packages/pytorch_lightning/loops/training_epoch_loop.py", line 189, in advance
    batch = next(data_fetcher)
  File "/home/GRAMES.POLYMTL.CA/lobouz/miniconda3/envs/monai_training/lib/python3.9/site-packages/pytorch_lightning/loops/fetchers.py", line 136, in __next__
    self._fetch_next_batch(self.dataloader_iter)
  File "/home/GRAMES.POLYMTL.CA/lobouz/miniconda3/envs/monai_training/lib/python3.9/site-packages/pytorch_lightning/loops/fetchers.py", line 150, in _fetch_next_batch
    batch = next(iterator)
  File "/home/GRAMES.POLYMTL.CA/lobouz/miniconda3/envs/monai_training/lib/python3.9/site-packages/pytorch_lightning/utilities/combined_loader.py", line 284, in __next__
    out = next(self._iterator)
  File "/home/GRAMES.POLYMTL.CA/lobouz/miniconda3/envs/monai_training/lib/python3.9/site-packages/pytorch_lightning/utilities/combined_loader.py", line 65, in __next__
    out[i] = next(self.iterators[i])
  File "/home/GRAMES.POLYMTL.CA/lobouz/miniconda3/envs/monai_training/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 634, in __next__
    data = self._next_data()
  File "/home/GRAMES.POLYMTL.CA/lobouz/miniconda3/envs/monai_training/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1329, in _next_data
    idx, data = self._get_data()
  File "/home/GRAMES.POLYMTL.CA/lobouz/miniconda3/envs/monai_training/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1290, in _get_data
    raise RuntimeError('Pin memory thread exited unexpectedly')
RuntimeError: Pin memory thread exited unexpectedly
```
Environment details (pip list):

```
(monai_training) lobouz@romane:~$ pip list
Package Version
----------------------------- --------------
absl-py 1.1.0
aiohttp 3.8.4
aiosignal 1.3.1
appdirs 1.4.4
astor 0.8.1
asttokens 2.2.1
astunparse 1.6.3
async-timeout 4.0.2
attrs 21.2.0
awscli 1.22.34
backcall 0.2.0
backports.functools-lru-cache 1.6.5
beautifulsoup4 4.11.2
beniget 0.4.1
bids-validator 1.9.9
blinker 1.4
botocore 1.23.34
Brotli 1.0.9
brz-etckeeper 0.0.0
cachetools 5.2.0
certifi 2023.5.7
cffi 1.15.1
chardet 4.0.0
charset-normalizer 3.1.0
click 8.1.3
cmake 3.26.4
colorama 0.4.6
coloredlogs 15.0.1
comm 0.1.3
command-not-found 0.3
commonmark 0.9.1
contourpy 1.1.0
cryptography 3.4.8
csv-diff 1.1
cycler 0.11.0
dbus-python 1.2.18
debugpy 1.6.7
decorator 5.1.1
Deprecated 1.2.13
dictdiffer 0.9.0
dill 0.3.5.1
distlib 0.3.4
distro 1.7.0
distro-info 1.1build1
dnspython 2.1.0
docker 5.0.3
docker-compose 1.29.2
docker-pycreds 0.4.0
dockerpty 0.4.1
docopt 0.6.2
docutils 0.17.1
entrypoints 0.4
executing 1.2.0
filelock 3.12.2
flatbuffers 2.0.7
fonttools 4.40.0
formulaic 0.3.4
frozenlist 1.3.3
fsleyes 1.7.0
fsleyes-props 1.8.2
fsleyes-widgets 0.14.2
fslpy 3.13.0
fsspec 2023.6.0
gast 0.4.0
gdown 4.6.4
gitdb 4.0.10
GitPython 3.1.31
gmpy2 2.1.2
google-auth 2.19.0
google-auth-oauthlib 1.0.0
google-pasta 0.2.0
gpg 1.16.0-unknown
grpcio 1.54.2
h5py 3.7.0
httplib2 0.20.2
humanfriendly 10.0
humanize 4.4.0
idna 3.4
imageio 2.22.4
imgaug 0.2.5
importlib-metadata 6.8.0
importlib-resources 6.0.0
interface-meta 1.3.0
iotop 0.6
ipykernel 6.24.0
ipython 8.14.0
ivadomed 2.9.7
jax 0.4.11
jedi 0.18.2
jeepney 0.7.1
Jinja2 3.1.2
jmespath 0.10.0
joblib 1.3.0
jsonschema 3.2.0
jupyter_client 8.3.0
jupyter_core 5.3.1
keras 2.12.0
Keras-Preprocessing 1.1.2
keyring 23.5.0
kiwisolver 1.4.4
launchpadlib 1.10.16
lazr.restfulclient 0.14.4
lazr.uri 1.0.6
libclang 14.0.1
lightning-utilities 0.9.0
lit 16.0.6
loguru 0.7.0
Markdown 3.3.6
MarkupSafe 2.1.3
matplotlib 3.7.2
matplotlib-inline 0.1.6
ml-dtypes 0.1.0
monai 1.2.0
monai-weekly 1.2.dev2311
more-itertools 8.10.0
mpmath 1.3.0
multidict 6.0.4
nest-asyncio 1.5.6
netifaces 0.11.0
networkx 3.1
nibabel 5.1.0
num2words 0.5.12
numpy 1.25.0
nvidia-cublas-cu11 11.10.3.66
nvidia-cuda-cupti-cu11 11.7.101
nvidia-cuda-nvrtc-cu11 11.7.99
nvidia-cuda-runtime-cu11 11.7.99
nvidia-cudnn-cu11 8.5.0.96
nvidia-cufft-cu11 10.9.0.58
nvidia-curand-cu11 10.2.10.91
nvidia-cusolver-cu11 11.4.0.1
nvidia-cusparse-cu11 11.7.4.91
nvidia-nccl-cu11 2.14.3
nvidia-nvtx-cu11 11.7.91
oauthlib 3.2.0
onnxruntime 1.13.1
opt-einsum 3.3.0
osfclient 0.0.5
packaging 23.1
pandas 2.0.3
parso 0.8.3
pathtools 0.1.2
pexpect 4.8.0
pickleshare 0.7.5
Pillow 10.0.0
pip 23.1.2
platformdirs 3.8.0
ply 3.11
pooch 1.7.0
promise 2.3
prompt-toolkit 3.0.39
protobuf 3.20.3
psutil 5.9.5
ptyprocess 0.7.0
pure-eval 0.2.2
pyasn1 0.4.8
pyasn1-modules 0.2.8
pybids 0.15.5
pycparser 2.21
Pygments 2.15.1
PyGObject 3.42.1
PyJWT 2.3.0
pymacaroons 0.13.0
PyNaCl 1.5.0
PyOpenGL 3.1.6
pyparsing 3.0.9
pyrsistent 0.18.1
PySocks 1.7.1
python-apt 2.4.0+ubuntu1
python-dateutil 2.8.2
python-dotenv 0.19.2
python-magic 0.4.24
pythran 0.10.0
pytorch-ignite 0.4.11
pytorch-lightning 2.0.4
pytz 2023.3
PyWavelets 1.4.1
PyYAML 6.0
pyzmq 25.1.0
requests 2.31.0
requests-oauthlib 1.3.1
requests-toolbelt 0.9.1
rich 12.6.0
roman 3.3
rsa 4.8
s3transfer 0.5.0
scikit-image 0.19.3
scikit-learn 1.3.0
scipy 1.11.1
screen-resolution-extra 0.0.0
seaborn 0.12.1
SecretStorage 3.3.1
sentry-sdk 1.21.1
setproctitle 1.3.2
setuptools 68.0.0
shellingham 1.5.0
shortuuid 1.0.11
SimpleITK 2.2.1
six 1.16.0
smmap 3.0.5
sos 4.4
soupsieve 2.4
SQLAlchemy 1.3.24
ssh-import-id 5.11
stack-data 0.6.2
sympy 1.12
systemd-python 234
tensorboard 2.12.3
tensorboard-data-server 0.7.0
tensorboard-plugin-wit 1.8.1
tensorflow 2.12.0
tensorflow-estimator 2.12.0
tensorflow-io-gcs-filesystem 0.26.0
termcolor 1.1.0
texttable 1.6.4
threadpoolctl 3.1.0
tifffile 2022.10.10
torch 2.0.0
torchaudio 2.0.1
torchio 0.18.86
torchmetrics 0.11.4
torchvision 0.15.1
tornado 6.3.2
tqdm 4.65.0
traitlets 5.9.0
triton 2.0.0
typer 0.7.0
typing_extensions 4.7.1
tzdata 2023.3
ufw 0.36.1
unattended-upgrades 0.1
urllib3 2.0.3
virtualenv 20.13.0+ds
wadllib 1.3.6
wandb 0.15.5
wcwidth 0.2.6
websocket-client 1.2.3
Werkzeug 2.1.2
wheel 0.40.0
wrapt 1.14.1
wxPython 4.0.7
xkit 0.0.0
yarl 1.9.2
zipp 3.15.0
```
louisfb01 commented 1 year ago

Updated the answer above with a new temporary fix.

Adding torch.multiprocessing.set_sharing_strategy('file_system') to the training script, as in this issue and here, did work. I believe I had previously put it in the wrong section of the code by mistake. It now works as expected!

It is not super clean, but at least it allows us to train normally for now.

jcohenadad commented 1 year ago

It is not super clean

Why is that?

louisfb01 commented 1 year ago

It is not super clean

Why is that?

Never mind that. I didn't like the idea of having to add this line and thought it was a "hard-coded" fix for a PyTorch issue, but after further research it seems to be normal behaviour. This issue can be closed, with the solution being to add the torch.multiprocessing.set_sharing_strategy('file_system') line at the beginning of your training script.
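A minimal sketch of that placement follows. `build_loader()` and its toy `TensorDataset` are hypothetical stand-ins for the project's real data pipeline; the only point illustrated is that the strategy is set once, at module level, before any `DataLoader` starts its worker processes (workers started with `fork` inherit whatever setting is in effect at that moment).

```python
import torch
import torch.multiprocessing as mp
from torch.utils.data import DataLoader, TensorDataset

# Set the sharing strategy once, at the very top of the training script and
# before any DataLoader starts its workers. On Linux the default strategy is
# 'file_descriptor'; 'file_system' identifies shared memory by file name.
mp.set_sharing_strategy("file_system")


def build_loader(num_workers: int = 2) -> DataLoader:
    """Hypothetical stand-in for the real training DataLoader."""
    dataset = TensorDataset(torch.arange(8, dtype=torch.float32))
    return DataLoader(
        dataset,
        batch_size=4,
        num_workers=num_workers,
        pin_memory=torch.cuda.is_available(),
    )
```

Creating the loader does not start any workers yet; they are only launched when the loader is first iterated, so the strategy just has to be set before that point.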

jcohenadad commented 1 year ago

I just realized this magic syntax also fixed an issue for me in the past 😅 https://github.com/jcohenadad/model-seg-ms-mp2rage-monai/commit/450f72e471b33e668109f55195db92811ef15d78

naga-karthik commented 1 year ago

Is there any deeper explanation anywhere as to why this fix is working?

louisfb01 commented 1 year ago

Is there any deeper explanation anywhere as to why this fix is working?

This is what I found:

torch.multiprocessing is a wrapper around the native multiprocessing module. It registers custom reducers, that use shared memory to provide shared views on the same data in different processes. Once the tensor/storage is moved to shared_memory (see share_memory_()), it will be possible to send it to other processes without making any copies. (from torch documentation)
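To make the quoted mechanism concrete, here is a small sketch using `Tensor.share_memory_()` and `Tensor.is_shared()`. When such a tensor is pickled for another process, only a handle is sent: a file descriptor under the default `file_descriptor` strategy, or a shared-memory file name under `file_system`. This matches the tracebacks in this thread, which fail in `rebuild_storage_fd` before the switch and in `rebuild_storage_filename` after it.

```python
import torch

# A regular CPU tensor lives in ordinary, process-private memory.
t = torch.zeros(4)
print(t.is_shared())  # False

# share_memory_() moves the underlying storage to shared memory, in place.
# From then on, torch.multiprocessing can pass it to another process
# without copying the data; only a handle to the shared region travels.
t.share_memory_()
print(t.is_shared())  # True
```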

And from what I understand, this memory issue comes from using the CacheDataset and it has to do with pytorch's sharing strategy. Calling torch.multiprocessing.set_sharing_strategy('file_system') switches PyTorch to a strategy that identifies shared memory by file name instead of keeping an extra file descriptor open for each shared tensor (more information here).

It still seems to be a workaround for not "having a high enough limit" on our system. The proper fix would be to increase the ulimit, but that is not possible in our case since our user permissions on Romane are restricted.

Quote from the PyTorch docs:

Still, if your system has high enough limits, and file_descriptor is a supported strategy, we do not recommend switching to this one.

Here, "this one" refers to the file_system strategy.
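For reference, the limit discussed here is the per-process cap on open file descriptors (`ulimit -n`). The sketch below uses Python's standard `resource` module (Unix only) and assumes you only want to go as high as the hard limit already allows: an unprivileged process may raise its soft limit up to the hard limit without sudo, and worker processes started afterwards inherit the new value. Whether that is enough on Romane depends on how its hard limit is configured, which this thread does not say. `raise_nofile_soft_limit` is a hypothetical helper name.

```python
import resource
from typing import Tuple


def raise_nofile_soft_limit() -> Tuple[int, int]:
    """Try to raise the soft open-file limit (ulimit -n) to the hard limit.

    Returns the (soft, hard) limits in effect afterwards.
    """
    soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
    # Skip an unlimited hard limit: some platforms (e.g. macOS) refuse an
    # unlimited soft limit for open files.
    if hard != resource.RLIM_INFINITY and soft < hard:
        try:
            # Raising the soft limit up to (never past) the hard limit
            # does not require root.
            resource.setrlimit(resource.RLIMIT_NOFILE, (hard, hard))
        except (ValueError, OSError):
            pass  # the platform refused; keep the current limits
    return resource.getrlimit(resource.RLIMIT_NOFILE)


if __name__ == "__main__":
    print("RLIMIT_NOFILE (soft, hard):", raise_nofile_soft_limit())
```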

naga-karthik commented 1 year ago

And from what I understand, this memory issue comes from using the CacheDataset and it has to do with pytorch's sharing strategy.

Ah, this is a good point! MONAI also provides a counterpart of PyTorch's native Dataset class. Could you please try using that once, without the multiprocessing fix, to see if CacheDataset is really the culprit? (You just have to switch to Dataset and use the right arguments; it's a one-line change.)

louisfb01 commented 1 year ago

I tried using MONAI's Dataset class instead of CacheDataset.

Switching to Dataset does not let us remove torch.multiprocessing.set_sharing_strategy('file_system'): the error still occurs without it.

louisfb01 commented 1 year ago

New error, even with torch.multiprocessing.set_sharing_strategy('file_system').

It happens with both MONAI's Dataset and CacheDataset classes.

The error occurs with all the aggregated datasets (approx. 7k images in total across train/val/test), and it only appears partway through the first epoch, at 73% with CacheDataset and at 75% with Dataset:

```console
Epoch 0:  75%|███████▍  | 1069/1426 [1:01:50<20:39,  3.47s/it, v_num=ymin]
Exception in thread Thread-7:
Traceback (most recent call last):
  File "/home/GRAMES.POLYMTL.CA/lobouz/miniconda3/envs/monai_training/lib/python3.9/threading.py", line 980, in _bootstrap_inner
    self.run()
  File "/home/GRAMES.POLYMTL.CA/lobouz/miniconda3/envs/monai_training/lib/python3.9/threading.py", line 917, in run
    self._target(*self._args, **self._kwargs)
  File "/home/GRAMES.POLYMTL.CA/lobouz/miniconda3/envs/monai_training/lib/python3.9/site-packages/torch/utils/data/_utils/pin_memory.py", line 51, in _pin_memory_loop
    do_one_step()
  File "/home/GRAMES.POLYMTL.CA/lobouz/miniconda3/envs/monai_training/lib/python3.9/site-packages/torch/utils/data/_utils/pin_memory.py", line 28, in do_one_step
    r = in_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
  File "/home/GRAMES.POLYMTL.CA/lobouz/miniconda3/envs/monai_training/lib/python3.9/multiprocessing/queues.py", line 122, in get
    return _ForkingPickler.loads(res)
  File "/home/GRAMES.POLYMTL.CA/lobouz/miniconda3/envs/monai_training/lib/python3.9/site-packages/torch/multiprocessing/reductions.py", line 324, in rebuild_storage_filename
    storage = torch.UntypedStorage._new_shared_filename_cpu(manager, handle, size)
RuntimeError: unable to mmap 160 bytes from file </torch_2814340_3051998398_49246>: Cannot allocate memory (12)
Traceback (most recent call last):
  File "/home/GRAMES.POLYMTL.CA/lobouz/github/contrast-agnostic-softseg-spinalcord/monai/main.py", line 636, in <module>
    main(args)
  File "/home/GRAMES.POLYMTL.CA/lobouz/github/contrast-agnostic-softseg-spinalcord/monai/main.py", line 536, in main
    trainer.fit(pl_model)
  File "/home/GRAMES.POLYMTL.CA/lobouz/miniconda3/envs/monai_training/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 531, in fit
    call._call_and_handle_interrupt(
  File "/home/GRAMES.POLYMTL.CA/lobouz/miniconda3/envs/monai_training/lib/python3.9/site-packages/pytorch_lightning/trainer/call.py", line 42, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/home/GRAMES.POLYMTL.CA/lobouz/miniconda3/envs/monai_training/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 570, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/home/GRAMES.POLYMTL.CA/lobouz/miniconda3/envs/monai_training/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 975, in _run
    results = self._run_stage()
  File "/home/GRAMES.POLYMTL.CA/lobouz/miniconda3/envs/monai_training/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1018, in _run_stage
    self.fit_loop.run()
  File "/home/GRAMES.POLYMTL.CA/lobouz/miniconda3/envs/monai_training/lib/python3.9/site-packages/pytorch_lightning/loops/fit_loop.py", line 201, in run
```

Moreover, one epoch takes over 1 h 30 min to run. I think the next step is to look into using Compute Canada to train with more GPUs and memory. In any case, as @naga-karthik and I noticed, Romane is getting pretty crowded these days, and it is hard to find free GPUs when you need them.

jcohenadad commented 1 year ago

I think the next step is to investigate using compute Canada and train with more GPUs and memory.

👍

louisfb01 commented 1 year ago

I looked further into the issue above, and it seems nobody has an explanation for it (?). For now, the only solution, based on this thread, is once again to set num_workers=0. As a reminder, num_workers=0 is not viable in our case: with this setting, one epoch takes close to 5 hours instead of 1 h 40 min.

Some threads are still open without a root cause, but they suggest other fixes (I will update this comment once I've tried them):

"We have no root cause but this issue disappeared now after we fixed a host mem leak issue in our project. If you also met this problem, you can set the OS mmap limitation much higher as a work around. Or set the worker = 0 in dataloader also helps." (from this issue) -> Requires sudo, so I don't think it is pertinent.

Another potential solution is to implement our own version of the dataset class with this modification, which converts the data into torch tensors. -> Did not work.
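Regarding the "OS mmap limitation" mentioned above: on Linux, `mmap` fails with `ENOMEM` (errno 12, the `Cannot allocate memory (12)` in the traceback) when a process would exceed its maximum number of memory mappings, which is set by the `vm.max_map_count` sysctl. Assuming that is the limit being hit (this thread does not confirm it), the Linux-only sketch below, with hypothetical helper names, reads the cap and counts the current process's mappings from `/proc` without sudo; raising the cap itself does require root.

```python
from pathlib import Path
from typing import Optional


def max_map_count() -> Optional[int]:
    """Kernel cap on memory-map areas per process (Linux only, else None)."""
    path = Path("/proc/sys/vm/max_map_count")
    return int(path.read_text()) if path.exists() else None


def current_map_count(pid: str = "self") -> Optional[int]:
    """Number of memory-map areas a process currently uses (Linux only)."""
    path = Path(f"/proc/{pid}/maps")
    if not path.exists():
        return None
    # Each line of /proc/<pid>/maps describes one mapping.
    with path.open() as f:
        return sum(1 for _ in f)


if __name__ == "__main__":
    print(f"{current_map_count()} of {max_map_count()} map areas in use")
```

Running the check from inside a worker (or passing a worker's PID) near the point of the crash would show whether the count is approaching the cap.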

I'm now adapting the code to train with multiple GPUs on Compute Canada, and will report back if we hit the same problem.

naga-karthik commented 2 months ago

Closing this as it is not relevant anymore: I have been able to train the model on 11 datasets so far.

plbenveniste commented 1 month ago

Hi @naga-karthik! I am reopening this issue as I am facing the same problem as @louisfb01 when training with MONAI on large datasets. My training crashes during the validation step even though I am using torch.multiprocessing.set_sharing_strategy('file_system'). Here is my code.

The error I get is the following:

Error message:

```console
Traceback (most recent call last):
  File "/home/plbenveniste/miniconda3/envs/venv_monai/lib/python3.9/threading.py", line 980, in _bootstrap_inner
    self.run()
  File "/home/plbenveniste/miniconda3/envs/venv_monai/lib/python3.9/threading.py", line 917, in run
    self._target(*self._args, **self._kwargs)
  File "/home/plbenveniste/miniconda3/envs/venv_monai/lib/python3.9/site-packages/torch/utils/data/_utils/pin_memory.py", line 51, in _pin_memory_loop
    do_one_step()
  File "/home/plbenveniste/miniconda3/envs/venv_monai/lib/python3.9/site-packages/torch/utils/data/_utils/pin_memory.py", line 28, in do_one_step
    r = in_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
  File "/home/plbenveniste/miniconda3/envs/venv_monai/lib/python3.9/multiprocessing/queues.py", line 122, in get
    return _ForkingPickler.loads(res)
  File "/home/plbenveniste/miniconda3/envs/venv_monai/lib/python3.9/site-packages/torch/multiprocessing/reductions.py", line 324, in rebuild_storage_filename
    storage = torch.UntypedStorage._new_shared_filename_cpu(manager, handle, size)
RuntimeError: unable to mmap 68 bytes from file : Cannot allocate memory (12)
Traceback (most recent call last):
  File "/home/plbenveniste/ms_lesion_agnostic/ms-lesion-agnostic/monai/train_monai_unet_lightning.py", line 823, in <module>
    main()
  File "/home/plbenveniste/ms_lesion_agnostic/ms-lesion-agnostic/monai/train_monai_unet_lightning.py", line 815, in main
    trainer.fit(pl_model)
  File "/home/plbenveniste/miniconda3/envs/venv_monai/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 544, in fit
    call._call_and_handle_interrupt(
  File "/home/plbenveniste/miniconda3/envs/venv_monai/lib/python3.9/site-packages/pytorch_lightning/trainer/call.py", line 44, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/home/plbenveniste/miniconda3/envs/venv_monai/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 580, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/home/plbenveniste/miniconda3/envs/venv_monai/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 987, in _run
    results = self._run_stage()
  File "/home/plbenveniste/miniconda3/envs/venv_monai/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1033, in _run_stage
    self.fit_loop.run()
  File "/home/plbenveniste/miniconda3/envs/venv_monai/lib/python3.9/site-packages/pytorch_lightning/loops/fit_loop.py", line 205, in run
    self.advance()
  File "/home/plbenveniste/miniconda3/envs/venv_monai/lib/python3.9/site-packages/pytorch_lightning/loops/fit_loop.py", line 363, in advance
    self.epoch_loop.run(self._data_fetcher)
  File "/home/plbenveniste/miniconda3/envs/venv_monai/lib/python3.9/site-packages/pytorch_lightning/loops/training_epoch_loop.py", line 141, in run
    self.on_advance_end(data_fetcher)
  File "/home/plbenveniste/miniconda3/envs/venv_monai/lib/python3.9/site-packages/pytorch_lightning/loops/training_epoch_loop.py", line 295, in on_advance_end
    self.val_loop.run()
  File "/home/plbenveniste/miniconda3/envs/venv_monai/lib/python3.9/site-packages/pytorch_lightning/loops/utilities.py", line 182, in _decorator
    return loop_run(self, *args, **kwargs)
  File "/home/plbenveniste/miniconda3/envs/venv_monai/lib/python3.9/site-packages/pytorch_lightning/loops/evaluation_loop.py", line 128, in run
    batch, batch_idx, dataloader_idx = next(data_fetcher)
  File "/home/plbenveniste/miniconda3/envs/venv_monai/lib/python3.9/site-packages/pytorch_lightning/loops/fetchers.py", line 133, in __next__
    batch = super().__next__()
  File "/home/plbenveniste/miniconda3/envs/venv_monai/lib/python3.9/site-packages/pytorch_lightning/loops/fetchers.py", line 60, in __next__
    batch = next(self.iterator)
  File "/home/plbenveniste/miniconda3/envs/venv_monai/lib/python3.9/site-packages/pytorch_lightning/utilities/combined_loader.py", line 341, in __next__
    out = next(self._iterator)
  File "/home/plbenveniste/miniconda3/envs/venv_monai/lib/python3.9/site-packages/pytorch_lightning/utilities/combined_loader.py", line 142, in __next__
    out = next(self.iterators[0])
  File "/home/plbenveniste/miniconda3/envs/venv_monai/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 633, in __next__
    data = self._next_data()
  File "/home/plbenveniste/miniconda3/envs/venv_monai/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1328, in _next_data
    idx, data = self._get_data()
  File "/home/plbenveniste/miniconda3/envs/venv_monai/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1289, in _get_data
    raise RuntimeError('Pin memory thread exited unexpectedly')
RuntimeError: Pin memory thread exited unexpectedly
```

I had to set num_workers=0 for the validation DataLoader to get the code to work. Is there a cleaner way around this? (By the way, I am running this on a GPU/CPU server with 512 cores-threads, Epyc 3.0GHz, 128G RAM and 4 GPUs (4x Nvidia A6000), and memory usage is not reaching its limit.)
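The workaround described here (worker processes for training, but validation batches loaded in the main process) can be sketched with plain PyTorch `DataLoader`s as follows. The random tensors and batch sizes are hypothetical placeholders for the real MONAI datasets; in a Lightning module the two loaders would typically be returned from `train_dataloader()` and `val_dataloader()`.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical placeholder datasets standing in for the real train/val sets.
train_ds = TensorDataset(torch.randn(16, 1, 8, 8))
val_ds = TensorDataset(torch.randn(4, 1, 8, 8))

pin = torch.cuda.is_available()

# Training keeps background workers so data loading overlaps GPU compute.
train_loader = DataLoader(train_ds, batch_size=4, shuffle=True,
                          num_workers=4, pin_memory=pin,
                          persistent_workers=True)

# Validation loads in the main process: no worker processes, so no tensors
# are handed between processes through shared memory, and no separate
# pin-memory thread is started (pinning, if enabled, happens inline).
val_loader = DataLoader(val_ds, batch_size=1, shuffle=False,
                        num_workers=0, pin_memory=pin)
```

The cost is that validation batches are loaded serially, so this trade-off is only attractive while the validation set is small relative to the training set.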