young-geng / EasyLM

Large language models (LLMs) made easy: EasyLM is a one-stop solution for pre-training, finetuning, evaluating, and serving LLMs in JAX/Flax.
Apache License 2.0

optimizer.accumulate_gradient_steps: will changes to this configuration increase memory usage? #36

Closed joytianya closed 1 year ago

joytianya commented 1 year ago

For LLaMA 7B on a TPU v3-8, training runs normally with --optimizer.accumulate_gradient_steps=1, but with --optimizer.accumulate_gradient_steps=2 it goes out of memory (OOM). Does changing this configuration increase memory usage? Do you have any good solutions?

python3 -m EasyLM.models.llama.llama_train \
    --mp_mesh_dim='4,1' \
    --optimizer.accumulate_gradient_steps=1 \
    --fsdp=True \
    ...
jax._src.traceback_util.UnfilteredStackTrace: jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: XLA:TPU compile permanent error. Ran out of memory in memory space hbm. Used 23.43G of 15.48G hbm. Exceeded hbm capacity by 7.95G.
young-geng commented 1 year ago

Gradient accumulation requires extra memory to store the accumulated gradient, which has the same size as the parameters.
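To make the memory cost concrete, here is a minimal sketch (not EasyLM's exact code) using optax.MultiSteps, a common way to implement gradient accumulation in JAX. The wrapper keeps an accumulated-gradient pytree (`acc_grads`) in its optimizer state with the same shapes as the parameters, which is exactly the extra buffer described above. The parameter shapes below are made up purely for illustration.

```python
import jax
import jax.numpy as jnp
import optax

# Hypothetical parameter pytree; shapes chosen only for illustration.
params = {"w": jnp.zeros((4096, 4096)), "b": jnp.zeros((4096,))}

base_opt = optax.adamw(learning_rate=1e-4)
# Apply the inner optimizer only every 2 micro-batches, accumulating in between.
opt = optax.MultiSteps(base_opt, every_k_schedule=2)
opt_state = opt.init(params)

# The wrapper's state carries an accumulated-gradient pytree with the same
# shapes as the parameters, i.e. one extra parameter-sized buffer in memory.
print(jax.tree_util.tree_map(jnp.shape, opt_state.acc_grads))
# {'b': (4096,), 'w': (4096, 4096)}
```

For a 7B-parameter model that extra pytree is roughly 14 GB in bf16 (about 28 GB in fp32) before any sharding, so enabling accumulation can by itself push a v3-8 core past its ~16 GB of HBM unless the accumulator is sharded the same way as the parameters.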