For LLaMA 7B on TPU v3-8, training runs normally with --optimizer.accumulate_gradient_steps=1, but with --optimizer.accumulate_gradient_steps=2 it hits an OOM error.
Does increasing optimizer.accumulate_gradient_steps raise HBM (device memory) usage?
Do you have any good solutions?

The error:
jax._src.traceback_util.UnfilteredStackTrace: jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: XLA:TPU compile permanent error. Ran out of memory in memory space hbm. Used 23.43G of 15.48G hbm. Exceeded hbm capacity by 7.95G.
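For context on why this can happen: gradient accumulation typically keeps a running-sum buffer the same size as the model's gradients alive across micro-steps, which is extra HBM on top of the params, optimizer state, and activations. A minimal illustrative sketch in plain JAX (the function and batch names here are made up for illustration, not this repo's actual API):

```python
import jax
import jax.numpy as jnp

# Toy linear model: loss over a batch dict {"x": inputs, "y": targets}.
def loss_fn(params, batch):
    pred = batch["x"] @ params
    return jnp.mean((pred - batch["y"]) ** 2)

grad_fn = jax.grad(loss_fn)

def accumulated_grad(params, micro_batches):
    # `acc` is an extra pytree the size of the model's gradients that
    # stays resident for the whole accumulation loop -- this persistent
    # buffer is the additional memory cost relative to taking a single
    # gradient step per batch.
    acc = jax.tree_util.tree_map(jnp.zeros_like, params)
    for mb in micro_batches:
        g = grad_fn(params, mb)
        acc = jax.tree_util.tree_map(jnp.add, acc, g)
    # Average over micro-steps so the result matches a full-batch gradient
    # (assuming equal-sized micro-batches).
    n = len(micro_batches)
    return jax.tree_util.tree_map(lambda a: a / n, acc)
```

If the per-device batch size is kept the same while accumulation is turned on, total memory goes up by roughly one gradient-sized buffer; a common workaround is to shrink the micro-batch size so that the accumulated effective batch stays constant.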