Open · HGangloff opened this issue 1 year ago
Hi HGangloff,

- Prioritize JIT compilation: compile your code with `jax.jit` whenever possible to benefit from JAX's optimizations and potentially avoid the RAM issue (see the first sketch after this list).
- Investigate the RAdam implementation: explore the RAdam implementation in Optax (https://github.com/deepmind/optax/blob/fc5de3d3951c4dfd87513c6426e30baf505d89ae/optax/_src/transform.py#L685C7-L685C7), focusing on areas that might create large temporary arrays or perform memory-intensive operations, and consider profiling memory usage to pinpoint the specific lines or functions causing excessive consumption.
- Experiment with alternative optimizers: if RAdam's behavior is crucial for your research, consider modifying its implementation to reduce the memory footprint (if feasible), or exploring a related optimizer such as Yogi, which shares similarities with RAdam but might have different memory characteristics.
- Report to the Optax maintainers: share your findings and code examples with the Optax maintainers to bring attention to the issue and potentially contribute to a fix.

Additional considerations:

- Memory profiling: use tools like `jax.profiler` or external profilers to track memory usage and identify bottlenecks (see the second sketch below).
- Batch size: experiment with smaller batch sizes to reduce memory requirements per step.
- Hardware constraints: consider the available RAM and potential hardware limitations.

I'm happy to assist further if you have more questions or need additional guidance.
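For the JIT point, a minimal sketch of what a fully jitted update step looks like (the loss, shapes, and hyperparameters here are placeholders, not taken from the original report):

```python
import jax
import jax.numpy as jnp
import optax

def loss_fn(params, xs, ys):
  # Placeholder least-squares loss; substitute your own model.
  preds = xs @ params['w'] + params['b']
  return jnp.mean((preds - ys) ** 2)

optimizer = optax.radam(learning_rate=1e-2)

# Compiling the whole update step is the configuration in which the
# RAM growth is reportedly not observed.
@jax.jit
def step(params, opt_state, xs, ys):
  grads = jax.grad(loss_fn)(params, xs, ys)
  updates, opt_state = optimizer.update(grads, opt_state)
  params = optax.apply_updates(params, updates)
  return params, opt_state
```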
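And for the memory-profiling point, `jax.profiler` can write a device-memory snapshot that `pprof` can display. A sketch (note this covers device allocations only, so a host-RAM leak may additionally need an OS-level monitor):

```python
import jax
import jax.numpy as jnp

x = jnp.ones((1024, 1024))  # stand-in for the real workload

# Dump a pprof-readable snapshot of live device allocations; inspect it
# with e.g. `pprof --web memory.prof` (pprof must be installed).
jax.profiler.save_device_memory_profile('memory.prof')
```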
Hi,
My RAM gets used up to the point of overflow when I use the `scale_by_radam` gradient transform (or, equivalently, `optax.radam`) without JIT-compiling the code. The problem appears on both CPU and GPU, but it does not appear when I use JIT compilation, and it does not seem to exist with `optax.adam`.
Here is a MWE derived from the Optax quick start tutorial:
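A sketch along those lines (adapted from the Optax quick start, with `optax.radam` and no `jax.jit`; the exact original may differ):

```python
import jax
import jax.numpy as jnp
import optax

# Toy linear-regression data, as in the Optax quick start.
key = jax.random.PRNGKey(0)
xs = jax.random.normal(key, (100, 10))
ys = xs @ jnp.arange(10.0) + 1.0

def loss_fn(params, xs, ys):
  preds = xs @ params['w'] + params['b']
  return jnp.mean((preds - ys) ** 2)

optimizer = optax.radam(learning_rate=1e-2)  # optax.adam does not show the issue
params = {'w': jnp.zeros((10,)), 'b': jnp.zeros(())}
opt_state = optimizer.init(params)

# Deliberately *not* wrapped in jax.jit: RAM usage keeps growing here.
for _ in range(100_000):
  grads = jax.grad(loss_fn)(params, xs, ys)
  updates, opt_state = optimizer.update(grads, opt_state)
  params = optax.apply_updates(params, updates)
```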
Of course, this example is simple enough that it takes a long time to saturate the RAM, but the issue is really problematic in another research project.
The problem seems to be linked with this computation specific to RAdam: https://github.com/deepmind/optax/blob/fc5de3d3951c4dfd87513c6426e30baf505d89ae/optax/_src/transform.py#L685C7-L685C7. But I do not know how to investigate further.
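For reference, one way to narrow it down might be to exercise the transform in isolation while watching the process RSS (a sketch only; `psutil` is assumed to be available and is used just for the readout):

```python
import os

import jax.numpy as jnp
import optax
import psutil  # assumed available; any process-memory monitor works

process = psutil.Process(os.getpid())

tx = optax.scale_by_radam()
grads = {'w': jnp.ones((1_000,))}
state = tx.init(grads)

# Apply the transform eagerly (no jax.jit) and print host RSS over time.
for i in range(10_000):
  updates, state = tx.update(grads, state)
  if i % 1_000 == 0:
    print(f'step {i}: rss = {process.memory_info().rss / 1e6:.1f} MB')
```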
Thanks for your feedback.