junyuchen245 / TransMorph_Transformer_for_Medical_Image_Registration

TransMorph: Transformer for Unsupervised Medical Image Registration (PyTorch)
MIT License

Keep getting an out-of-memory RuntimeError while reproducing the results on the IXI dataset. #5

Closed bhosalems closed 2 years ago

bhosalems commented 2 years ago

RuntimeError: CUDA out of memory. Tried to allocate 420.00 MiB (GPU 1; 11.91 GiB total capacity; 10.87 GiB already allocated; 316.25 MiB free; 11.07 GiB reserved in total by PyTorch)

I keep getting the above error. I tried freeing the cache and printing the memory usage summary, but I don't understand what each metric means:

|===========================================================================|
|                  PyTorch CUDA memory summary, device ID 1                 |
|---------------------------------------------------------------------------|
|            CUDA OOMs: 1            |        cudaMalloc retries: 1         |
|===========================================================================|
|        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |
|---------------------------------------------------------------------------|
| Allocated memory      |   10656 MB |   10932 MB |   19094 MB |    8437 MB |
|       from large pool |   10627 MB |   10903 MB |   19040 MB |    8412 MB |
|       from small pool |      29 MB |      31 MB |      53 MB |      24 MB |
|---------------------------------------------------------------------------|
| Active memory         |   10656 MB |   10932 MB |   19094 MB |    8437 MB |
|       from large pool |   10627 MB |   10903 MB |   19040 MB |    8412 MB |
|       from small pool |      29 MB |      31 MB |      53 MB |      24 MB |
|---------------------------------------------------------------------------|
| GPU reserved memory   |   11328 MB |   11328 MB |   11328 MB |       0 B  |
|       from large pool |   11294 MB |   11294 MB |   11294 MB |       0 B  |
|       from small pool |      34 MB |      34 MB |      34 MB |       0 B  |
|---------------------------------------------------------------------------|
| Non-releasable memory |  322610 KB |     940 MB |    4760 MB |    4445 MB |
|       from large pool |  317682 KB |     936 MB |    4707 MB |    4396 MB |
|       from small pool |    4928 KB |       4 MB |      53 MB |      48 MB |
|---------------------------------------------------------------------------|
| Allocations           |     449    |     456    |     845    |     396    |
|       from large pool |     192    |     198    |     405    |     213    |
|       from small pool |     257    |     258    |     440    |     183    |
|---------------------------------------------------------------------------|
| Active allocs         |     449    |     456    |     845    |     396    |
|       from large pool |     192    |     198    |     405    |     213    |
|       from small pool |     257    |     258    |     440    |     183    |
|---------------------------------------------------------------------------|
| GPU reserved segments |     107    |     107    |     107    |       0    |
|       from large pool |      90    |      90    |      90    |       0    |
|       from small pool |      17    |      17    |      17    |       0    |
|---------------------------------------------------------------------------|
| Non-releasable allocs |      49    |      50    |     345    |     296    |
|       from large pool |      37    |      37    |     203    |     166    |
|       from small pool |      12    |      14    |     142    |     130    |
|===========================================================================|
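As a rough reading aid for the numbers above (an interpretation, not an official definition of every row): "Allocated memory" is what live tensors occupy, "GPU reserved memory" is what PyTorch's caching allocator currently holds from the driver, and the difference between the two is memory PyTorch has cached but is not using. The sketch below only does arithmetic on figures quoted in this thread; the variable names are made up for illustration.

```python
# Figures copied from the error message and the memory summary in this thread.
requested_mib = 420.00       # "Tried to allocate 420.00 MiB"
free_mib = 316.25            # "316.25 MiB free"
allocated_mb = 10656         # "Allocated memory", Cur Usage
reserved_mb = 11328          # "GPU reserved memory", Cur Usage
non_releasable_kb = 322610   # "Non-releasable memory", Cur Usage

# Memory held by PyTorch's caching allocator but not occupied by tensors.
cached_unused_mb = reserved_mb - allocated_mb

# The summary reports this row in KB; convert it to MB for comparison.
non_releasable_mb = non_releasable_kb / 1024

# How far the failed request overshoots what the driver reports as free.
shortfall_mib = requested_mib - free_mib

print(f"cached but unused: {cached_unused_mb} MB")
print(f"non-releasable:    {non_releasable_mb:.0f} MB")
print(f"shortfall:         {shortfall_mib:.2f} MiB")
```

The request is about 104 MiB larger than the free memory, and a single allocation needs one contiguous block, so the 672 MB of cached slack does not necessarily help if it is split across smaller blocks. With the batch size already at 1, this is consistent with the model simply needing more memory than the card provides.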

I was specifically running train_TransMorph.py. One suggestion was to reduce the batch size, but it's already set to 1. It might be possible to delete unused variables and collect their memory, along with a few other things suggested on the PyTorch forum, but I am not yet confident about changing the training loop.

There is also a known issue, the allocator's inability to reuse fragmented blocks, which was fixed in https://github.com/pytorch/pytorch/pull/44742. I am not sure which PyTorch version includes this fix; I am following up on that.

In the meantime, any thoughts on how to resolve this, or on a workaround? Also, is it possible to know how much peak memory is needed during training?

Thanks
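On the peak-memory question above: one way to find out empirically is to reset PyTorch's peak statistics, run a single training step, and read back the peak. The following is a minimal sketch under stated assumptions: `measure_peak_mib` and `toy_step` are hypothetical names, and the tiny Conv3d model is only a stand-in for the real training step in train_TransMorph.py. The calls `torch.cuda.reset_peak_memory_stats` and `torch.cuda.max_memory_allocated` are standard PyTorch. Of course, this only helps on a GPU large enough to complete one step.

```python
import torch


def measure_peak_mib(step_fn, device=None):
    """Run step_fn once; return peak tensor memory in MiB, or None without CUDA."""
    if not torch.cuda.is_available():
        step_fn()  # still run the step so the sketch works on CPU-only machines
        return None
    device = torch.device("cuda") if device is None else torch.device(device)
    torch.cuda.reset_peak_memory_stats(device)  # forget earlier peaks
    step_fn()
    torch.cuda.synchronize(device)  # wait for queued kernels to finish
    return torch.cuda.max_memory_allocated(device) / 2**20


# Stand-in for one training iteration (hypothetical; not the repository's model).
dev = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.nn.Conv3d(1, 4, kernel_size=3, padding=1).to(dev)


def toy_step():
    x = torch.randn(1, 1, 16, 16, 16, device=dev)
    loss = model(x).pow(2).mean()
    loss.backward()


peak = measure_peak_mib(toy_step)
print("peak MiB:", peak)
```

Note that `max_memory_allocated` counts memory occupied by tensors; the memory reserved by the caching allocator (`torch.cuda.max_memory_reserved`) is usually somewhat higher and is closer to what the GPU must actually provide.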

junyuchen245 commented 2 years ago

Hi @lsbmsb,

The TransMorph training script takes around 18 GB of memory for 160x192x224 images, so 11 GB is not enough for training. If you are only interested in reproducing the results, 11 GB should be plenty for inference. You can find the pre-trained models here: https://github.com/junyuchen245/TransMorph_Transformer_for_Medical_Image_Registration/blob/main/TransMorph_on_IXI.md#transmorph-variants

You may also try TransMorph-bspl and TransMorph-diff, as they require less memory.

One last thing you could try is changing the Transformer configurations: https://github.com/junyuchen245/TransMorph_Transformer_for_Medical_Image_Registration/blob/495e7c9fa76ff885a358c291367d1e41cb5f9052/IXI/TransMorph/models/configs_TransMorph.py#L28-L53

Try a smaller embed_dim, num_heads, or reg_head_chan. You can also turn on use_checkpoint, which saves about 3 GB of memory.
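For readers unfamiliar with what a flag like use_checkpoint typically does: gradient checkpointing discards intermediate activations in the forward pass and recomputes them during backward, trading compute for memory. Below is a minimal, generic sketch using `torch.utils.checkpoint.checkpoint`; the `Block` and `Stack` classes are hypothetical and are not the repository's code, and it is an assumption that the TransMorph flag works along these lines.

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint


class Block(nn.Module):
    """A small residual MLP block (hypothetical stand-in for a Transformer block)."""

    def __init__(self, dim):
        super().__init__()
        self.mlp = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))

    def forward(self, x):
        return x + self.mlp(x)


class Stack(nn.Module):
    def __init__(self, dim=32, depth=4, use_checkpoint=False):
        super().__init__()
        self.blocks = nn.ModuleList([Block(dim) for _ in range(depth)])
        self.use_checkpoint = use_checkpoint

    def forward(self, x):
        for blk in self.blocks:
            if self.use_checkpoint and x.requires_grad:
                # Activations inside blk are recomputed during backward
                # instead of being stored. use_reentrant needs PyTorch >= 1.11.
                x = checkpoint(blk, x, use_reentrant=False)
            else:
                x = blk(x)
        return x


# Same weights with and without checkpointing should give the same gradients.
torch.manual_seed(0)
plain = Stack(use_checkpoint=False)
ckpt = Stack(use_checkpoint=True)
ckpt.load_state_dict(plain.state_dict())

x1 = torch.randn(2, 8, 32, requires_grad=True)
x2 = x1.detach().clone().requires_grad_(True)
plain(x1).sum().backward()
ckpt(x2).sum().backward()
print(torch.allclose(x1.grad, x2.grad))
```

The results are numerically the same either way; the cost is an extra forward computation per checkpointed block during the backward pass.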

Junyu

bhosalems commented 2 years ago

I updated the reg_head_chan from 96 to 64. It is working now, thanks for your help.

junyuchen245 commented 2 years ago

Glad to be of help :)