ManuelEberhardinger closed this issue 1 year ago.
I couldn't find out why JAX behaves this way, so I switched back to the official MinAtar implementation. Now the code runs 10x faster and I no longer have memory issues. I will close this issue for now, but I find it strange that JAX accumulates memory without releasing it.
Hi Robert,
Thanks for this awesome library!
I use the gymnax library on CPU to collect data from the Breakout MinAtar environment. I generate thousands of random programs and want to execute them in the env. Somehow the memory accumulates over time until I run into RAM problems. Using the Python memory profiler, I found that each call to the env's step function adds about 10 MB. Do you know why that is the case? Could this happen only when running JAX on CPU?
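One thing worth ruling out when a JAX step loop grows in memory or slows down is repeated tracing and compilation: `jax.jit` caches compiled code per function object (and per input shape/dtype), so wrapping a fresh function in `jax.jit` on every iteration forces a new trace each time. The sketch below is only an illustration under that assumption, not a diagnosis of this issue; `toy_step` is a hypothetical stand-in for an environment step and is not the gymnax API. A Python side effect inside the function runs only while JAX traces it, so it works as a trace counter.

```python
import jax
import jax.numpy as jnp

# Hypothetical trace counter: the Python code inside toy_step runs only
# while JAX traces it, not when a cached compiled version is reused.
TRACE_COUNT = {"n": 0}


def toy_step(state, action):
    """Hypothetical stand-in for an environment step (not the gymnax API)."""
    TRACE_COUNT["n"] += 1
    new_state = state + action
    reward = jnp.sum(new_state)
    return new_state, reward


def run_rejit_each_call(num_steps):
    """Anti-pattern: build a fresh jitted wrapper on every iteration."""
    TRACE_COUNT["n"] = 0
    state = jnp.zeros(3)
    for _ in range(num_steps):
        # A new lambda object each iteration means a jit cache miss, so a new trace.
        step = jax.jit(lambda s, a: toy_step(s, a))
        state, _ = step(state, jnp.ones(3))
    return TRACE_COUNT["n"]


def run_jit_once(num_steps):
    """Preferred: jit once outside the loop and reuse the compiled function."""
    TRACE_COUNT["n"] = 0
    step = jax.jit(toy_step)
    state = jnp.zeros(3)
    for _ in range(num_steps):
        # Same function object and same input shapes/dtypes: cached after the first call.
        state, _ = step(state, jnp.ones(3))
    return TRACE_COUNT["n"]


rejit_traces = run_rejit_each_call(20)
once_traces = run_jit_once(20)
print(rejit_traces)  # 20
print(once_traces)  # 1
```

In a real loop, setting `jax.config.update("jax_log_compiles", True)` logs every compilation, which makes unexpected recompiles (for example from changing shapes or static arguments per random program) easy to spot.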
I had problems getting JAX and PyTorch to run in the same virtual env with CUDA, so I decided to run gymnax on the CPU to avoid CUDA problems. The memory is also not released in the next iteration of the loop or at the end of the function.
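Pinning JAX to the CPU, as described above, can be done with the `JAX_PLATFORMS` environment variable, which must be set before `jax` is imported. A minimal sketch, assuming a reasonably recent JAX release that honors `JAX_PLATFORMS`:

```python
import os

# Must be set before importing jax: restrict JAX to the CPU backend so it
# never initializes CUDA and does not compete with PyTorch for the GPU.
os.environ["JAX_PLATFORMS"] = "cpu"

import jax
import jax.numpy as jnp

# Confirm which backend JAX actually picked.
platform = jax.devices()[0].platform
x = jnp.arange(4.0)

print(platform)  # cpu
print(float(x.sum()))  # 6.0
```

If JAX and PyTorch should instead share a GPU, setting `XLA_PYTHON_CLIENT_PREALLOCATE=false` stops JAX from reserving a large fraction of GPU memory up front, which is a frequent source of conflicts between the two frameworks.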
I used the code from the visualization notebook as a reference.
Thanks a lot for your answer!
Best wishes, Manuel