huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

large memory usage when resuming training from a checkpoint #11317

Closed: dorost1234 closed this issue 3 years ago

dorost1234 commented 3 years ago

Environment info

Who can help

@sgugger @patrickvonplaten, @patil-suraj

Information

Hi, I am training a t5-base model on the MNLI dataset with batch size = 128. Training works fine, but the moment I resume from a checkpoint I hit a memory issue: memory usage is much larger when training resumes.
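
For context, here is a stripped-down sketch of what the run amounts to. This is a hypothetical minimal reproduction (my actual run_seq2seq.py and the custom trainer under third_party/ are not shown, and the output path and save_steps value are made up), not the exact script:

```python
# Minimal reproduction sketch: fine-tune t5-base on MNLI as a text-to-text
# task, then resume from the last saved checkpoint.
from datasets import load_dataset
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
)

tokenizer = AutoTokenizer.from_pretrained("t5-base")
model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")

label_names = ["entailment", "neutral", "contradiction"]

def preprocess(example):
    # Cast the NLI pair into a T5-style text-to-text example.
    inputs = tokenizer(
        "mnli premise: " + example["premise"] + " hypothesis: " + example["hypothesis"],
        max_length=128,
        truncation=True,
    )
    inputs["labels"] = tokenizer(label_names[example["label"]]).input_ids
    return inputs

train_dataset = load_dataset("glue", "mnli", split="train").map(
    preprocess, remove_columns=["premise", "hypothesis", "label", "idx"]
)

args = Seq2SeqTrainingArguments(
    output_dir="t5-base-mnli",          # hypothetical output directory
    per_device_train_batch_size=128,
    save_steps=500,                      # hypothetical checkpoint interval
)

trainer = Seq2SeqTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    data_collator=DataCollatorForSeq2Seq(tokenizer, model=model),
)

# First run: trains fine within GPU memory.
# trainer.train()

# Second run: resuming from the last checkpoint triggers the CUDA OOM below.
trainer.train(resume_from_checkpoint=True)
```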

Expected behavior

Resuming training from a checkpoint should take roughly the same amount of memory as the original training run.
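
To make "the same amount of memory" measurable, a small helper like the one below (my own addition, not part of the scripts above) can be called right after checkpoint loading and after the first training step in both runs:

```python
import torch

def log_gpu_memory(tag: str) -> None:
    """Print currently allocated and reserved CUDA memory in GiB."""
    allocated = torch.cuda.memory_allocated() / 2**30
    reserved = torch.cuda.memory_reserved() / 2**30
    print(f"[{tag}] allocated={allocated:.2f} GiB, reserved={reserved:.2f} GiB")

# Example usage, once per run:
# log_gpu_memory("after checkpoint load")
# log_gpu_memory("after first training step")
```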

Error Stack

Traceback (most recent call last):
  File "run_seq2seq.py", line 671, in <module>
    main()
  File "run_seq2seq.py", line 629, in main
    train_result = trainer.train(resume_from_checkpoint=checkpoint)
  File "/users/dara/dev/codes/seq2seq/third_party/trainers/trainer.py", line 329, in train
    tr_loss += self.training_step(model, inputs)
  File "/users/dara/libs/anaconda3/envs/test1/lib/python3.7/site-packages/transformers/trainer.py", line 1486, in training_step
    loss = self.compute_loss(model, inputs)
  File "/users/dara/libs/anaconda3/envs/test1/lib/python3.7/site-packages/transformers/trainer.py", line 1518, in compute_loss
    outputs = model(**inputs)
  File "/users/dara/libs/anaconda3/envs/test1/lib/python3.7/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/users/dara/dev/codes/seq2seq/third_party/models/t5/modeling_t5.py", line 1762, in forward
    lang=lang
  File "/users/dara/libs/anaconda3/envs/test1/lib/python3.7/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/users/dara/dev/codes/seq2seq/third_party/models/t5/modeling_t5.py", line 1115, in forward
    task=task
  File "/users/dara/libs/anaconda3/envs/test1/lib/python3.7/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/users/dara/dev/codes/seq2seq/third_party/models/t5/modeling_t5.py", line 752, in forward
    output_attentions=output_attentions,
  File "/users/dara/libs/anaconda3/envs/test1/lib/python3.7/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/users/dara/dev/codes/seq2seq/third_party/models/t5/modeling_t5.py", line 653, in forward
    output_attentions=output_attentions,
  File "/users/dara/libs/anaconda3/envs/test1/lib/python3.7/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/users/dara/dev/codes/seq2seq/third_party/models/t5/modeling_t5.py", line 518, in forward
    hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None
  File "/users/dara/dev/codes/seq2seq/third_party/models/t5/modeling_t5.py", line 501, in project
    hidden_states = shape(proj_layer(key_value_states))
  File "/users/dara/libs/anaconda3/envs/test1/lib/python3.7/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/users/dara/libs/anaconda3/envs/test1/lib/python3.7/site-packages/torch/nn/modules/linear.py", line 94, in forward
    return F.linear(input, self.weight, self.bias)
  File "/users/dara/libs/anaconda3/envs/test1/lib/python3.7/site-packages/torch/nn/functional.py", line 1753, in linear
    return torch._C._nn.linear(input, weight, bias)
RuntimeError: CUDA out of memory. Tried to allocate 48.00 MiB (GPU 0; 23.70 GiB total capacity; 21.38 GiB already allocated; 41.69 MiB free; 22.18 GiB reserved in total by PyTorch)

Thanks for your help and suggestions.

LysandreJik commented 3 years ago

Similar issue to https://github.com/huggingface/transformers/issues/11294
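
For anyone hitting this before the linked issue is resolved, one possible mitigation for a resume-time memory spike (my assumption, not a fix confirmed in either thread) is to load the saved optimizer state onto the CPU first, so torch.load does not materialise the whole state on the GPU while the freshly initialised optimizer buffers already live there:

```python
import os
import torch
from transformers import AutoModelForSeq2SeqLM

# Workaround sketch under the assumptions above; the checkpoint path is hypothetical.
checkpoint_dir = "t5-base-mnli/checkpoint-500"

model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint_dir).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-5)

# "optimizer.pt" is the file name the Trainer uses when saving a checkpoint.
state = torch.load(os.path.join(checkpoint_dir, "optimizer.pt"), map_location="cpu")

# Optimizer.load_state_dict casts state tensors to the device of the matching
# parameters, so the state still ends up on the GPU, just without the extra
# GPU-resident copy that torch.load would otherwise create.
optimizer.load_state_dict(state)
```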

github-actions[bot] commented 3 years ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.