yang-song / score_inverse_problems

Official repo for "Solving Inverse Problems in Medical Imaging with Score-Based Generative Models"

Host memory leak during training #6

Open cobalamin opened 2 years ago

cobalamin commented 2 years ago

I'm experiencing host memory usage that increases continually during training until my machine eventually freezes or the process is killed for running out of memory (I have 32GB available). Everything else about training seems to work fine until it crashes (after around 5000 iterations), and GPU memory is also fine: its usage stays completely constant. I've tried several versions of jax/jaxlib/flax, but none of them changes this. I've attached the output of pip freeze from my virtualenv.
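To tell a steady leak apart from growth that eventually levels off, one option is to log the process's memory every few hundred training steps. A minimal sketch, assuming a Unix host (Python's `resource` module is unavailable on Windows); `peak_rss_mb` and `is_still_growing` are hypothetical helpers, not part of this repository or of JAX:

```python
import resource
import sys
from typing import List


def peak_rss_mb() -> float:
    """Peak resident set size of the current process, in MB.

    ru_maxrss is reported in kilobytes on Linux and in bytes on macOS.
    It covers native allocations (e.g. from XLA or TensorFlow), not just
    Python objects.
    """
    peak = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
    bytes_per_unit = 1 if sys.platform == "darwin" else 1024
    return peak * bytes_per_unit / (1024 * 1024)


def is_still_growing(samples_mb: List[float], window: int = 5,
                     tol_mb: float = 50.0) -> bool:
    """True if peak RSS rose by more than tol_mb over the last `window` samples.

    Peak RSS never decreases, so a flat tail means memory has stopped climbing.
    """
    if len(samples_mb) <= window:
        return True  # too few samples to call it a plateau yet
    return samples_mb[-1] - samples_mb[-1 - window] > tol_mb
```

Inside the training loop, one would append `peak_rss_mb()` to a list every N steps. If `is_still_growing` stays True well past one full pass over the training data, that suggests a genuine leak rather than a one-time fill-up. RSS is used here instead of `tracemalloc` because the latter only sees Python-level allocations.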

Any clues what could be causing this? I searched for JAX memory leaks on Google/StackOverflow, but didn't find anything that seemed useful/related.

pip-environment.txt

tianzhijiaoziA commented 2 years ago

Hi, I'm a student at SYSU. In my experience, the GPU cannot be used with this JAX version, so it is better to use a TPU. From my testing, it is also better to have at least 48GB of GPU memory (e.g. an A100), and jaxlib 1.69-1.73 works better. This problem troubled me for a long time when I first tried the JAX framework, so I hope this helps.

cobalamin commented 2 years ago

Hi, thank you for your comment. I'm unable to use a TPU and only have GPUs available for training. GPU memory is not the issue for me in this problem, it's host memory.

I was able to sort of work around the issue by making my swapfile ridiculously large (32GB), since the memory usage does seem to stop increasing eventually. It still looks like a memory leak to me, though. Perhaps memory grows until every example has been seen once (are they all loaded into and kept in memory?).
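The guess in the comment above matches what an in-memory dataset cache would do: host memory grows during the first pass over the training set and then stays flat. The toy sketch below is not this repository's input pipeline. It only illustrates that pattern with a plain Python cache measured by `tracemalloc`, and `example_stream` and `memory_at_epoch_ends` are hypothetical names. In a `tf.data` pipeline, calling `.cache()` without a filename keeps elements in memory in a similar way, but whether this repo's pipeline does that is an assumption to check, not something established in this thread.

```python
import tracemalloc
from typing import Iterator, List


def example_stream(num_examples: int, num_epochs: int, example_bytes: int,
                   keep_in_memory: bool) -> Iterator[bytes]:
    """Toy input pipeline (hypothetical). Each 'example' is a zero-filled bytes object.

    keep_in_memory=True mimics an in-memory cache: every example is stored the
    first time it is produced and served from host memory on later epochs.
    """
    cache: List[bytes] = []
    for epoch in range(num_epochs):
        for i in range(num_examples):
            if not keep_in_memory:
                yield bytes(example_bytes)          # produced, consumed, then freed
            elif epoch == 0:
                cache.append(bytes(example_bytes))  # first pass: load and keep
                yield cache[-1]
            else:
                yield cache[i]                      # later passes: no new allocation


def memory_at_epoch_ends(keep_in_memory: bool, num_examples: int = 200,
                         example_bytes: int = 64 * 1024,
                         num_epochs: int = 3) -> List[int]:
    """Bytes currently traced by tracemalloc at the end of each epoch."""
    readings: List[int] = []
    tracemalloc.start()
    try:
        stream = example_stream(num_examples, num_epochs, example_bytes, keep_in_memory)
        for step, _example in enumerate(stream, start=1):
            if step % num_examples == 0:  # last example of an epoch
                current, _peak = tracemalloc.get_traced_memory()
                readings.append(current)
    finally:
        tracemalloc.stop()
    return readings


if __name__ == "__main__":
    for cached in (True, False):
        mb = [r / 2**20 for r in memory_at_epoch_ends(cached)]
        print(f"keep_in_memory={cached}: " + ", ".join(f"{m:.1f} MB" for m in mb))
```

With `keep_in_memory=True`, traced memory should reach its full size by the end of the first epoch and stay roughly constant afterwards; with `False` it should stay around the size of a single example. That plateau is consistent with the behaviour described above, though it does not by itself show that caching is the cause.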