huggingface / trl

Train transformer language models with reinforcement learning.
http://hf.co/docs/trl
Apache License 2.0

Deepspeed Zero2 not working when using DPOTrainer #2062

Closed EQ3A2A closed 3 weeks ago

EQ3A2A commented 2 months ago

System Info

Information

Tasks

Reproduction

The accelerate config file I'm using

deepspeed_config.yaml

compute_environment: LOCAL_MACHINE
deepspeed_config:
 gradient_accumulation_steps: 1
 gradient_clipping: 1.0
 offload_optimizer_device: cpu
 offload_param_device: cpu
 zero3_init_flag: true
 zero_stage: 2
distributed_type: DEEPSPEED
machine_rank: 0
main_process_ip: null
main_process_port: null
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 8
use_cpu: false
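
Note that zero3_init_flag: true only applies to ZeRO Stage 3, so it has no effect with zero_stage: 2. For reference, roughly the same settings can be expressed programmatically with accelerate's DeepSpeedPlugin; this is only an illustrative sketch (not part of the original report), and the normal path for a Trainer-based script is the YAML file plus accelerate launch:

from accelerate import Accelerator, DeepSpeedPlugin

# Sketch: programmatic equivalent of the YAML config above.
deepspeed_plugin = DeepSpeedPlugin(
    zero_stage=2,
    gradient_accumulation_steps=1,
    gradient_clipping=1.0,
    offload_optimizer_device="cpu",
    offload_param_device="cpu",
)
accelerator = Accelerator(mixed_precision="bf16", deepspeed_plugin=deepspeed_plugin)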

The training script I'm using

train.py


import os

# Avoid tokenizer parallelism warnings when dataloader_num_workers > 0
os.environ["TOKENIZERS_PARALLELISM"] = "false"

from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DPOConfig, DPOTrainer

# Policy model and frozen reference model (eager attention is recommended for Gemma 2)
model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b-it", attn_implementation="eager")
ref_model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b-it", attn_implementation="eager")

tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b-it")

dataset = load_dataset("json", data_files="dpo_train.json")

training_args = DPOConfig(
    report_to="none",
    output_dir="/data/models/gemma_dpo_checkpoints",
    per_device_train_batch_size=1,  
    num_train_epochs=3,
    logging_dir='/data/logs',
    logging_steps=5,
    save_steps=100,
    max_length=1225,
    max_prompt_length=1225,
    save_total_limit=2,
    dataloader_num_workers=4,
    bf16=True,  
)

# load trainer
trainer = DPOTrainer(
    model=model,
    ref_model=ref_model,
    tokenizer=tokenizer,
    args=training_args,
    train_dataset=dataset["train"],
)

# train
trainer.train()

trainer.save_model("/data/models/gemma_dpo")

Run the script with accelerate

accelerate launch --config_file deepspeed_config.yaml train.py

Expected behavior

ZeRO Stage 2 is not applied: the DeepSpeed logs report Stage 0 instead.

(screenshot: DeepSpeed initialisation log reporting ZeRO Stage 0)
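
One way to see which ZeRO stage will actually be applied to the policy model (as opposed to what the first DeepSpeed log lines show) is to inspect the DeepSpeed plugin that accelerate built from the YAML file. A small sketch, assuming a recent transformers/accelerate where the Trainer exposes its Accelerator as trainer.accelerator:

# Sketch: print the ZeRO stage configured for the policy model.
# `trainer` is the DPOTrainer instance from the script above.
plugin = trainer.accelerator.state.deepspeed_plugin
print("ZeRO stage for the policy model:", plugin.deepspeed_config["zero_optimization"]["stage"])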
kechunFIVE commented 1 month ago

mark

TongLiu-github commented 1 month ago

Hi, did you solve this problem? Same problem here.

lewtun commented 1 month ago

Hello @EQ3A2A @TongLiu-github can you please share an example that reproduces the error with a public dataset I can test?

TongLiu-github commented 1 month ago

> Hello @EQ3A2A @TongLiu-github can you please share an example that reproduces the error with a public dataset I can test?

Thanks for the reply. I solved this problem from: https://github.com/huggingface/accelerate/issues/314#issue-1201142707
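
For readers hitting the same thing: judging from the follow-up below, the underlying problem in that linked issue was a distributed-init/NCCL timeout rather than the ZeRO stage itself. A minimal sketch of the usual mitigation, raising the process-group timeout via accelerate (the two-hour value is an arbitrary assumption):

from datetime import timedelta

from accelerate import Accelerator, InitProcessGroupKwargs

# Raise the NCCL/process-group timeout (the default is typically 30 minutes).
init_kwargs = InitProcessGroupKwargs(timeout=timedelta(hours=2))
accelerator = Accelerator(kwargs_handlers=[init_kwargs])

For Trainer-based scripts like the one above, the ddp_timeout field on TrainingArguments/DPOConfig plays a similar role.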

lewtun commented 1 month ago

Thanks @TongLiu-github - do I understand correctly that you were experiencing an NCCL timeout error instead?

The reason you are seeing Stage 0 in the logs is that we initialise the reference model in that stage unless the user has configured Stage 3: https://github.com/huggingface/trl/blob/2cad48d511fab99ac0c4b327195523a575afcad3/trl/trainer/dpo_trainer.py#L923

In the screenshot below, I compare DDP with ZeRO-3, and you can indeed see that the latter uses less memory.

(screenshot: memory usage comparison, DDP vs ZeRO-3, 2024-09-24)

If that resolves the issue, feel free to close it.
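
To make that behaviour concrete, the reference-model preparation in DPOTrainer does roughly the following. This is a simplified paraphrase of the linked _prepare_deepspeed code, not a verbatim copy, and details may differ across TRL versions:

import copy

import deepspeed

def prepare_ref_model_for_deepspeed(ref_model, accelerator):
    # Start from the DeepSpeed config that accelerate built from the YAML file.
    config = copy.deepcopy(accelerator.state.deepspeed_plugin.deepspeed_config)

    # Keep ZeRO-3 if the user configured it (the reference model must be sharded
    # the same way); otherwise fall back to Stage 0, since the frozen reference
    # model has no optimizer state to partition.
    if config["zero_optimization"]["stage"] != 3:
        config["zero_optimization"]["stage"] = 0

    engine, *_ = deepspeed.initialize(model=ref_model, config=config)
    engine.eval()
    return engine

The policy model itself is still prepared with whatever stage is in the accelerate config (Stage 2 here); only the frozen reference model is initialised at Stage 0, which is what the early log lines show.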

Joe-Hall-Lee commented 1 month ago

> The reason you are seeing Stage 0 in the logs is that we initialise the reference model in that stage unless the user has configured Stage 3 [...]

Hello @lewtun. So, if I understand correctly, this is just an issue with how the logs are displayed, and ZeRO-2 is actually enabled, right?

lewtun commented 3 weeks ago

Hi @Joe-Hall-Lee, yes that's correct: the DeepSpeed logs are showing the initialisation of the reference model, not the policy model.