huggingface / trl

Train transformer language models with reinforcement learning.
http://hf.co/docs/trl
Apache License 2.0

memory usage of DPO trainer seems to grow stepwise over time #1377

Closed Emerald01 closed 7 months ago

Emerald01 commented 8 months ago

Hi,

I am DPO-training a checkpoint of Mixtral-8x7B-Instruct, starting from a previous supervised fine-tuning run.

I mainly followed this script https://github.com/huggingface/trl/blob/main/examples/research_projects/stack_llama_2/scripts/dpo_llama2.py with 8 H100 GPUs, flash attention, and DeepSpeed ZeRO-2. Everything looks good, but I notice that memory consumption grows stepwise.
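
For reference, the setup looks roughly like the sketch below (paths, hyperparameters, and the toy dataset are placeholders rather than my exact values; the DPOTrainer arguments follow the older trl API used by that script):

    import torch
    from datasets import Dataset
    from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
    from trl import DPOTrainer

    # SFT checkpoint of Mixtral-8x7B-Instruct; flash attention 2 + bf16 to save memory.
    model_path = "path/to/mixtral-sft-checkpoint"  # placeholder
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,
        attn_implementation="flash_attention_2",
    )
    tokenizer = AutoTokenizer.from_pretrained(model_path)

    # Toy preference dataset with the prompt/chosen/rejected columns DPOTrainer expects.
    train_dataset = Dataset.from_dict({
        "prompt": ["Question: ...\n\nAnswer: "],
        "chosen": ["a preferred answer"],
        "rejected": ["a dispreferred answer"],
    })

    training_args = TrainingArguments(
        output_dir="dpo-mixtral",
        per_device_train_batch_size=1,
        gradient_accumulation_steps=8,
        bf16=True,
        deepspeed="ds_zero2.json",  # ZeRO stage-2 config; launched across the 8 GPUs
    )

    trainer = DPOTrainer(
        model,
        ref_model=None,              # trl creates the frozen reference model when None
        args=training_args,
        beta=0.1,
        train_dataset=train_dataset,
        tokenizer=tokenizer,
        max_prompt_length=512,
        max_length=1024,
    )
    trainer.train()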

Any idea why it is not constant? It looks like there may be a memory garbage-collection issue.

[Screenshot, 2024-02-27: GPU memory usage growing stepwise over the course of training]
younesbelkada commented 8 months ago

cc @kashif have you ever experienced this?

saeedkhaki92 commented 8 months ago

@Emerald01 could you share your ZeRO-2 config? Do you use CPU offloading? I have the same problem: training goes out of memory after some steps with Mixtral. My env: 8 A100 GPUs.
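
By "ZeRO-2 config" I mean something along these lines (illustrative values only; the dict can be passed directly via TrainingArguments(deepspeed=...) or saved as the equivalent JSON file):

    # Illustrative DeepSpeed ZeRO stage-2 config; drop "offload_optimizer" to keep
    # optimizer states on GPU instead of offloading them to CPU.
    ds_zero2_config = {
        "bf16": {"enabled": "auto"},
        "zero_optimization": {
            "stage": 2,
            "offload_optimizer": {"device": "cpu", "pin_memory": True},
            "overlap_comm": True,
            "contiguous_gradients": True,
        },
        "gradient_accumulation_steps": "auto",
        "gradient_clipping": "auto",
        "train_batch_size": "auto",
        "train_micro_batch_size_per_gpu": "auto",
    }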

younesbelkada commented 8 months ago

Hi! Can you try clearing the CUDA cache between training steps? You could modify the DPOTrainer source code to override the training_step() method (https://github.com/huggingface/transformers/blob/e9476832942a19cf99354776ef112babc83c139a/src/transformers/trainer.py#L2848) and call torch.cuda.empty_cache() together with gc.collect() after each step.
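
A minimal sketch of what that could look like (the subclass name is just illustrative):

    import gc

    import torch
    from trl import DPOTrainer


    class CacheClearingDPOTrainer(DPOTrainer):
        """Illustrative subclass that frees cached CUDA memory after every training step."""

        def training_step(self, *args, **kwargs):
            loss = super().training_step(*args, **kwargs)  # usual forward/backward/optimizer step
            torch.cuda.empty_cache()  # hand cached blocks back to the CUDA allocator
            gc.collect()              # also collect Python-level garbage
            return loss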

Emerald01 commented 8 months ago

@younesbelkada that works!

younesbelkada commented 8 months ago

oh nice! cc @muellerz do you know if Trainer properly handles torch.cuda.empty_cache() after each training step? Perhaps worth making a PR on the transformers side? Let me know if you want me to have a look as well.

Emerald01 commented 8 months ago

@younesbelkada I believe transformers does not clear the cache after each training step. Following your suggestion, I added the cache emptying and garbage collection, and compared to the previous stepwise-growing memory, usage is now almost constant:

    # This override lives inside a DPOTrainer subclass (it needs `import gc`, `import torch`,
    # `import torch.nn as nn`, and `from typing import Any, Dict, Union` at module level).
    def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
        loss_step = super().training_step(model, inputs)  # regular DPO training step
        torch.cuda.empty_cache()  # release cached CUDA blocks after the step
        gc.collect()              # collect Python-level garbage as well
        return loss_step

The following is the current GPU memory usage when running the same script:

[Screenshot, 2024-03-06: GPU memory usage now roughly constant across training steps]
github-actions[bot] commented 7 months ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.