StellaAthena closed this issue 3 years ago.
Hi @StellaAthena,
I expect you are encountering this error because two JAX processes are running at the same time (e.g. a notebook whose kernel is still running in the background while you launch the experiments from the command line). By default, JAX preallocates 90% of the GPU memory on first use, so a second process fails when it tries to allocate that memory again. To verify that this is the issue, check the GPU memory usage before running the script. If it is already at ~90%, the fix is either to run only one JAX process per GPU at a time, or to change JAX's memory allocation environment variables, e.g. `export XLA_PYTHON_CLIENT_PREALLOCATE=false`.
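Both steps above (check GPU memory, then relaunch with preallocation disabled) can be sketched in Python. This is a minimal illustration, not code from the experiments repo. It assumes an NVIDIA driver with `nvidia-smi` on `PATH`. The helper names (`parse_gpu_memory`, `busy_gpus`, `jax_memory_env`) and the 0.85 "already claimed" threshold are hypothetical choices for the example.

```python
import os
import subprocess
from typing import Dict, List, Optional


def query_gpu_memory() -> str:
    """Run nvidia-smi and return one 'used, total' line (MiB) per GPU."""
    return subprocess.run(
        ["nvidia-smi", "--query-gpu=memory.used,memory.total",
         "--format=csv,noheader,nounits"],
        capture_output=True, text=True, check=True,
    ).stdout


def parse_gpu_memory(nvidia_smi_csv: str) -> List[float]:
    """Convert nvidia-smi CSV output into a used/total fraction per GPU."""
    fractions = []
    for line in nvidia_smi_csv.strip().splitlines():
        used, total = (float(field) for field in line.split(","))
        fractions.append(used / total)
    return fractions


def busy_gpus(fractions: List[float], threshold: float = 0.85) -> List[int]:
    """Indices of GPUs whose memory looks already preallocated by another process.

    The 0.85 threshold is an assumption chosen to catch JAX's ~90% default.
    """
    return [i for i, frac in enumerate(fractions) if frac >= threshold]


def jax_memory_env(preallocate: bool = False,
                   base: Optional[Dict[str, str]] = None) -> Dict[str, str]:
    """Copy an environment and set XLA_PYTHON_CLIENT_PREALLOCATE for a child process."""
    env = dict(os.environ if base is None else base)
    env["XLA_PYTHON_CLIENT_PREALLOCATE"] = "true" if preallocate else "false"
    return env


# Usage sketch ("run_experiment.py" is a hypothetical script name):
#   if busy_gpus(parse_gpu_memory(query_gpu_memory())):
#       subprocess.run(["python", "run_experiment.py"],
#                      env=jax_memory_env(), check=True)
```

Passing the variable through the child's environment leaves your notebook untouched. If you set it inside a Python process instead, do so via `os.environ` before JAX is first imported, because JAX reads these variables when it initializes its GPU backend.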
Thanks! Worked like a charm
None of the experiments from the paper run for me, and they give
I can write my own code using the package without a problem, but your experiments won't work.