qGentry opened this issue
I've also tried the following implementation, inspired by https://github.com/jax-ml/jax/issues/23614#issuecomment-2350773816, which wraps the entire apply_module with flax's custom_vjp, but it doesn't work properly.
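For context, that pattern looks roughly like the sketch below, written with plain jax.custom_vjp rather than flax's wrapper and with a made-up apply_block standing in for the real apply_module; it is an illustration of the idea on a single device, not the author's actual code:

```python
import jax
import jax.numpy as jnp
from jax.sharding import SingleDeviceSharding

# Illustrative stand-in for the real flax module's apply function.
def apply_block(params, x):
    return x + jnp.tanh(x @ params["w"])

host = SingleDeviceSharding(jax.devices()[0], memory_kind="pinned_host")
device = SingleDeviceSharding(jax.devices()[0], memory_kind="device")

@jax.custom_vjp
def offloaded_block(params, x):
    return apply_block(params, x)

def offloaded_block_fwd(params, x):
    out = apply_block(params, x)
    # Push the residuals to pinned host memory right after the forward pass.
    residuals = jax.device_put((params, x), host)
    return out, residuals

def offloaded_block_bwd(residuals, g):
    # Bring the residuals back to the device before recomputing the VJP.
    params, x = jax.device_put(residuals, device)
    _, vjp_fn = jax.vjp(apply_block, params, x)
    return vjp_fn(g)

offloaded_block.defvjp(offloaded_block_fwd, offloaded_block_bwd)
```

The intent is that the residuals live in host memory between the forward and backward passes and only come back to the device when the backward pass needs them.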
Total size device = 21.00781624764204 GB, host = 0.5 GB
Here's part of the trace:

As you can see, activations are indeed being offloaded during the forward pass, but during the backward pass they are not loaded back to the devices. It looks like these offloaded activations are immediately dropped on the CPU and the GPU activations are saved instead, which is why host memory is only 0.5 GB: it is only reserved for the activations of a single layer.

I've also noticed that this approach produces wrong loss & grad calculations, but if I comment out all of the "jax.device_put" transfers, everything works as expected again.
Description
Hi guys, I'm very excited about the recent activation offloading mechanism introduced in JAX/XLA:GPU, but I'm unable to make it work with scan. My setup is the following: I'm training a classic transformer whose transformer block is scanned over the inputs "number of layers" times. I'm also using rematerialization to reduce the memory footprint of the model. I basically wrap the apply_block function with jax.remat using the "nothing_saveable" policy and then scan this block over the inputs to achieve the desired behavior: the only activations saved during the forward pass are the residual stream (embeddings) between the scanned blocks.
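As a rough illustration of that setup (the block contents, names, and shapes here are placeholders, not the actual model):

```python
import jax
import jax.numpy as jnp

def apply_block(params, x):
    # One transformer block; a toy dense layer + residual connection here.
    return x + jnp.tanh(x @ params["w"])

# Recompute everything inside the block during the backward pass.
remat_block = jax.remat(apply_block,
                        policy=jax.checkpoint_policies.nothing_saveable)

def forward(stacked_params, x):
    # Scan the remat'd block over the stacked per-layer parameters, so the
    # only activation saved between layers is the residual stream itself.
    def body(carry, layer_params):
        return remat_block(layer_params, carry), None
    out, _ = jax.lax.scan(body, x, stacked_params)
    return out
```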
With the recent introduction of the "save_and_offload_only_these_names" policy, I thought it would be enough to mark the output of the scanned block with
jax.ad_checkpoint.checkpoint_name(x, "output")
and then specify names_which_can_be_offloaded=["output"], but it didn't work. I've implemented a repro to showcase what is going on.
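A minimal sketch of the marking-plus-policy combination being described, assuming the public checkpoint_name / save_and_offload_only_these_names API (illustrative, not the exact repro):

```python
import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name

def apply_block(params, x):
    out = x + jnp.tanh(x @ params["w"])
    # Give the block output a name that checkpoint policies can refer to.
    return checkpoint_name(out, "output")

# Offload (rather than keep on device) every value named "output";
# everything else is rematerialized during the backward pass.
policy = jax.checkpoint_policies.save_and_offload_only_these_names(
    names_which_can_be_saved=[],
    names_which_can_be_offloaded=["output"],
    offload_src="device",
    offload_dst="pinned_host",
)
remat_block = jax.remat(apply_block, policy=policy)
```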
First of all, I wanted to ensure that offloading is working in the first place, starting from the baseline configuration.
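A plausible baseline, assuming the same policy helper as in the sketch above with both name lists left empty:

```python
# Assumed baseline: no named activations are saved or offloaded.
policy = jax.checkpoint_policies.save_and_offload_only_these_names(
    names_which_can_be_saved=[], names_which_can_be_offloaded=[],
    offload_src="device", offload_dst="pinned_host")
```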
I'm getting the following results:
Total size device = 20.26562874764204 GB, host = 0.0 GB
Quite a reasonable value. Then I wanted to check how much it would cost to save "h" on the GPU, so I set the policy to allow saving "h" on device,
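presumably something along these lines (again my guess at the exact form, using the same policy helper as above):

```python
# Assumed setting for this step: keep "h" resident on the GPUs, offload nothing.
policy = jax.checkpoint_policies.save_and_offload_only_these_names(
    names_which_can_be_saved=["h"], names_which_can_be_offloaded=[],
    offload_src="device", offload_dst="pinned_host")
```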
and getting
Total size device = 35.2968789935112 GB, host = 0.0 GB
This is also totally expected, as "h" is f32[32,64,8192,2048] sharded across 8 GPUs, which equals 16 GB per GPU (32 × 64 × 8192 × 2048 float32 elements ≈ 128 GiB in total, i.e. 16 GiB per GPU). OK, let's try to offload "h" and see what happens.
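Presumably by moving "h" from the saved list to the offloaded list:

```python
# Assumed setting for this step: offload "h" to pinned host memory instead of saving it.
policy = jax.checkpoint_policies.save_and_offload_only_these_names(
    names_which_can_be_saved=[], names_which_can_be_offloaded=["h"],
    offload_src="device", offload_dst="pinned_host")
```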
Total size device = 19.75000447779894 GB, host = 16.0 GB
Also totally expected: instead of keeping those 16 GB per GPU in device memory, we're offloading the activations to the host, so device memory is saved. Iterations also become a lot slower, which is expected as well.

Now that we're sure offloading is indeed working properly, I've tried to offload the "residual" tensor (the output of the scanned block).
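Presumably with the residual-stream name in the offload list, something like:

```python
# Assumed setting for this step: offload the scanned block's output instead of "h".
policy = jax.checkpoint_policies.save_and_offload_only_these_names(
    names_which_can_be_saved=[], names_which_can_be_offloaded=["residual"],
    offload_src="device", offload_dst="pinned_host")
```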
Aaaand nothing happens -
Total size device = 20.26562874764204 GB, host = 0.0 GB
No changes in memory usage, and iteration time is the same as with no offloading at all.

System info (python version, jaxlib version, accelerator, etc.)
JAX issue https://github.com/jax-ml/jax/issues/23869