cc @kashif have you ever experienced this?
@Emerald01 could you share your ZeRO-2 config? Do you use CPU offloading? I have the same problem: it goes out of memory after some steps with Mixtral. My env: 8 A100 GPUs.
Hi!
Can you try clearing the CUDA cache between training steps?
You could modify the DPOTrainer source code to override the training_step() method (https://github.com/huggingface/transformers/blob/e9476832942a19cf99354776ef112babc83c139a/src/transformers/trainer.py#L2848) and call torch.cuda.empty_cache() together with gc.collect() after each step.
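A minimal sketch of what that override could look like, assuming you subclass DPOTrainer (the CacheClearingDPOTrainer name below is just illustrative, not part of TRL):

```python
import gc
from typing import Any, Dict, Union

import torch
from torch import nn
from trl import DPOTrainer


class CacheClearingDPOTrainer(DPOTrainer):
    """DPOTrainer variant that frees cached CUDA memory after every step."""

    def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
        # Run the regular forward/backward/optimizer logic first.
        loss = super().training_step(model, inputs)
        # Release cached allocator blocks and collect dangling Python references.
        torch.cuda.empty_cache()
        gc.collect()
        return loss
```

You would then construct this subclass with the same arguments you currently pass to DPOTrainer.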
@younesbelkada that works!
oh nice!
cc @muellerz do you know if Trainer properly handles torch.cuda.empty_cache() after each training step? Perhaps it is worth making a PR on the transformers side? Let me know if you want me to have a look as well.
@younesbelkada I believe transformers does not properly clear the cache after each training step. Following your suggestion, I added the empty-cache call and GC collection, and compared with the previous stepwise-growing memory, usage is now almost constant:
```python
def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
    # Run the regular training step, then release cached CUDA memory.
    loss_step = super().training_step(model, inputs)
    torch.cuda.empty_cache()
    gc.collect()
    return loss_step
```
The following is the current GPU memory usage when running the same script:
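(For anyone who wants to reproduce this curve, one way to log per-step GPU memory is a small Trainer callback along these lines; the CudaMemoryLogger name is just illustrative:)

```python
import torch
from transformers import TrainerCallback


class CudaMemoryLogger(TrainerCallback):
    """Print allocated/reserved/peak CUDA memory after each optimizer step."""

    def on_step_end(self, args, state, control, **kwargs):
        gib = 2**30
        print(
            f"step {state.global_step}: "
            f"allocated={torch.cuda.memory_allocated() / gib:.2f} GiB "
            f"reserved={torch.cuda.memory_reserved() / gib:.2f} GiB "
            f"peak={torch.cuda.max_memory_allocated() / gib:.2f} GiB"
        )
```

It can be registered with trainer.add_callback(CudaMemoryLogger()) before calling trainer.train().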
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Hi,
I am DPO-training a checkpoint of Mixtral-8x7B-Instruct from a previous supervised fine-tune.
I mainly followed this script https://github.com/huggingface/trl/blob/main/examples/research_projects/stack_llama_2/scripts/dpo_llama2.py with 8 H100 GPUs, Flash Attention, and DeepSpeed ZeRO-2. Everything looks good, but I notice that memory consumption grows stepwise.
Any idea why it is not constant? It looks like there is some memory garbage-collection issue.
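For reference, the setup is roughly of this shape (the paths and numbers below are placeholders for illustration, not the exact config from my run):

```python
import torch
from transformers import AutoModelForCausalLM, TrainingArguments

# Placeholder ZeRO-2 config; TrainingArguments also accepts a path to a JSON file.
ds_config = {
    "bf16": {"enabled": True},
    "zero_optimization": {
        "stage": 2,
        "overlap_comm": True,
        "contiguous_gradients": True,
    },
    "gradient_accumulation_steps": "auto",
    "train_micro_batch_size_per_gpu": "auto",
}

# Load the SFT checkpoint with Flash Attention 2 enabled.
model = AutoModelForCausalLM.from_pretrained(
    "path/to/mixtral-sft-checkpoint",  # hypothetical path
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
)

training_args = TrainingArguments(
    output_dir="dpo-mixtral",
    per_device_train_batch_size=1,
    gradient_checkpointing=True,
    bf16=True,
    deepspeed=ds_config,
)
```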