google-deepmind / jax_privacy

Algorithms for Privacy-Preserving Machine Learning in JAX
Apache License 2.0

CUDA memory issue #10

Closed · kamadforge closed this issue 1 year ago

kamadforge commented 1 year ago

Even on a machine with over 30GB of GPU memory, and after lowering batch_size.per_device_per_step to 4, the memory error persists.

Original error: RESOURCE_EXHAUSTED: Failed to allocate request for 176.00MiB (184549376B) on device ordinal 0

What could be the reason? Is there a flag or parameter that could be causing this memory failure?

ahasanpour commented 1 year ago

I got the same error with a 24GB GPU. Even with a 40GB GPU, I had to use export XLA_PYTHON_CLIENT_MEM_FRACTION=0.7 to be able to run the code with batch size 256. Hope this helps!
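For reference, a minimal sketch of how this can be set from Python instead of the shell; the variable has to be set before JAX initializes its GPU backend, so it should come before the first import of jax:

```python
# Minimal sketch: cap XLA's GPU memory preallocation to ~70% of the device.
# Must be set before JAX touches the GPU, i.e. before importing jax.
import os

os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.7"

import jax

# Sanity check that the GPU is visible after the cap is applied.
print(jax.devices())
```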

lberrada commented 1 year ago

Thanks for the comment @ahasanpour. Indeed, this would have been expected to fit in memory. Additional potential solutions are listed at https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html#gpu-memory-allocation.
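A sketch of the other options described on that page, assuming a single-GPU setup; pick one and set it before JAX initializes its backend:

```python
import os

# Option 1: disable preallocation entirely. JAX then allocates GPU memory
# on demand rather than reserving a large fraction up front; this can
# fragment memory but avoids grabbing most of the device at startup.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

# Option 2: use the platform allocator, which allocates and frees exactly
# what is requested. Slower, but useful when debugging OOM errors.
# os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"

import jax

print(jax.devices())
```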