RylanSchaeffer opened this issue 2 years ago
@romanngg , you were really helpful previously - any thoughts here? Thanks in advance :)
`nt.linearize` is essentially a Jacobian-vector product (`jax.jvp`), and the peak memory consumption of the linearized forward pass should be about 2x that of the ordinary forward pass. Then, I believe the costs of the backward passes (`jax.vjp`) of the linearized and non-linearized models should also differ by about 2x (see https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html#how-it-s-made-two-foundational-autodiff-functions or https://openreview.net/pdf?id=ym68T6OoO6L). If you have a way to diagnose peak memory consumption when you train your model, could you check that it's less than half of your GPU memory?
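For concreteness, here is a minimal sketch of the first-order Taylor expansion that `nt.linearize` computes via a single `jax.jvp`, plus one way to read peak device memory on a GPU backend (`memory_stats` is backend-dependent and may return `None`):

```python
import jax

def linearize(f, params_0):
  """First-order Taylor expansion of `f` around `params_0` via one jax.jvp."""
  def f_lin(params, *args, **kwargs):
    # Tangent direction: displacement of the current params from params_0.
    dparams = jax.tree_util.tree_map(lambda p, p0: p - p0, params, params_0)
    f_0, df = jax.jvp(lambda p: f(p, *args, **kwargs), (params_0,), (dparams,))
    return f_0 + df
  return f_lin

# Peak-memory check (GPU backends only; availability varies by jaxlib version):
stats = jax.local_devices()[0].memory_stats()
if stats is not None:
  print(f"peak bytes in use: {stats['peak_bytes_in_use']:,}")
```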
@romanngg thanks for getting back to me so soon! I'll check the max memory consumption, but I don't think that's the reason, because we could successfully "manually" perform a forward and backward pass of `f_lin` on a single GPU with batch size = 512. By "manually," I mean executing the following alone:
```python
l, g = utils.accumulate_gradient(
    jax.value_and_grad(loss_fn), opt.target, batch['image'], batch['label'],
    accum_steps)
```
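For reference, here is a hypothetical reimplementation of what a helper like `utils.accumulate_gradient` does (I believe the repo's actual version uses `jax.lax.scan`, but the idea is the same: split the batch into microbatches and average loss and grads, trading compute for peak memory):

```python
import jax
import jax.numpy as jnp

def accumulate_gradient(loss_and_grad_fn, params, images, labels, accum_steps):
  """Hypothetical sketch: average loss/grads over `accum_steps` microbatches."""
  if accum_steps <= 1:
    return loss_and_grad_fn(params, images, labels)
  micro = images.shape[0] // accum_steps
  l, g = loss_and_grad_fn(params, images[:micro], labels[:micro])
  for i in range(1, accum_steps):
    sl = slice(i * micro, (i + 1) * micro)
    l_i, g_i = loss_and_grad_fn(params, images[sl], labels[sl])
    l = l + l_i
    g = jax.tree_util.tree_map(jnp.add, g, g_i)
  scale = 1.0 / accum_steps
  return l * scale, jax.tree_util.tree_map(lambda x: x * scale, g)
```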
I suspect that there might be an odd interaction between `f_lin` as constructed by Neural Tangents and the code used in the vision transformer notebook (pasted above).
Another bizarre observation: if we try
```python
l, g = utils.accumulate_gradient(
    jax.value_and_grad(loss_fn), opt.target, batch['image'], batch['label'],
    accum_steps)
```
with the original (non-linearized) model outside the `update_fn()` (defined above), we get an OOM error for about 155 MiB, even though the GPU has plenty of additional free memory. This problem does not occur when using the linearized model.
Edit: Ignore that last observation. That problem vanished when we reduced the per-GPU batch size.
Here's a self-contained Colab that reproduces the issue: https://colab.research.google.com/drive/184moQLq3tjo-wEpc8gD7fXCFguAVDBOm#scrollTo=k4CjYqp5qLvj
We suspect that `pmap` might be causing the problem, because if we don't use it, the linearized model can train via `jax.value_and_grad(loss_fn)`, but once we try `pmap(jax.value_and_grad(loss_fn))`, we hit OOM.
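Here's the failing pattern as a minimal self-contained sketch (the toy `apply_fn`, shapes, and loss are hypothetical stand-ins for the ViT; with the real model on GPU, the last line is where we OOM):

```python
import jax
import jax.numpy as jnp
import neural_tangents as nt

# Toy stand-in for the ViT forward pass (hypothetical, for illustration only).
def apply_fn(params, x):
  return jnp.tanh(x @ params['w'])

params_0 = {'w': jnp.ones((512, 512))}
f_lin = nt.linearize(apply_fn, params_0)

def loss_fn(params, x, y):
  return jnp.mean((f_lin(params, x) - y) ** 2)

n_dev = jax.local_device_count()
x = jnp.ones((n_dev, 64, 512))
y = jnp.ones((n_dev, 64, 512))

# Single device: runs fine.
l, g = jax.value_and_grad(loss_fn)(params_0, x[0], y[0])

# Replicate params across devices and shard the batch along the leading axis;
# on GPU, this pmap-wrapped version is where the OOM shows up.
replicated = jax.device_put_replicated(params_0, jax.local_devices())
l, g = jax.pmap(jax.value_and_grad(loss_fn))(replicated, x, y)
```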
Another insight: on Colab, the GPU runtime hits the OOM, but the same code runs fine on the TPU runtime.
Looks like someone else hit a similar problem while using Neural Tangents, also potentially arising from `pmap`: https://github.com/google/jax/issues/8585#issuecomment-1061256273
We're trying to fine-tune a linearized Vision Transformer by adapting code from https://github.com/google-research/vision_transformer/blob/main/vit_jax.ipynb.
We're running into a really puzzling problem: when we load a model, we can train it, and after linearizing it, we can still train the original (pre-linearized) model. However, when we try to train the linearized model, we get an out-of-memory (OOM) error.
This error emerges regardless of whether we use 1 GPU or multiple, and regardless of whether the batch size is large (512) or small (1).
We manually verified that neither a forward pass nor a backward pass raises an error. We suspect that the error might arise from the following code (although we could be wrong!):
Their code:
That function is then called via:
The training loop where the memory error arises:
All of the above code is copied from the ViT repo. This is how we linearize the ViT model:
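(The code block itself is missing from this excerpt; a plausible reconstruction, assuming the Flax `model` object and loaded `params` from the notebook, with all names hypothetical:)

```python
import neural_tangents as nt

# `model` is the Flax VisionTransformer and `params` the loaded checkpoint
# weights, both defined earlier in the notebook (assumed names).
def apply_fn(params, images):
  return model.apply({'params': params}, images, train=False)

f_lin = nt.linearize(apply_fn, params)
```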