huggingface / trl

Train transformer language models with reinforcement learning.
http://hf.co/docs/trl
Apache License 2.0

DPOTrainer ignores aux_loss for MoE during training because it fails to use router_aux_loss_coef in model config #2197

Closed muupan closed 1 month ago

muupan commented 1 month ago

System Info

Python 3.11.7 trl==0.11.2 transformers==4.45.1 accelerate==0.34.2

Reproduction

The documentation says router_aux_loss_coef is used as a coefficient for the auxiliary loss.

> This option is enabled by setting output_router_logits=True in the model config (e.g. MixtralConfig). To scale how much the auxiliary loss contributes to the total loss, use the hyperparameter router_aux_loss_coef=... (default: 0.001) in the model config.

https://huggingface.co/docs/trl/main/en/dpo_trainer#for-mixture-of-experts-models-enabling-the-auxiliary-loss
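In other words, the documented behavior amounts to adding a scaled auxiliary term to the mean preference loss. A minimal sketch in plain Python (no torch; the loss values are hypothetical numbers chosen for illustration) of what the coefficient does, and what happens when it is effectively 0:

```python
# Minimal sketch of the documented rule (hypothetical numbers, plain Python):
#   total_loss = mean(per-example preference losses) + router_aux_loss_coef * aux_loss
from statistics import mean


def total_loss(per_example_losses, aux_loss, router_aux_loss_coef):
    """Combine the per-example preference losses with the scaled router auxiliary loss."""
    return mean(per_example_losses) + router_aux_loss_coef * aux_loss


losses = [0.70, 0.68]  # hypothetical per-example DPO losses
aux = 2.0              # hypothetical router load-balancing loss

print(total_loss(losses, aux, 0.123))  # aux term adds 0.123 * 2.0 = 0.246
print(total_loss(losses, aux, 0.0))    # aux term adds nothing
```

With a coefficient of 0.0 the auxiliary loss has no effect on the gradient, which is exactly the situation described next.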

However, it seems router_aux_loss_coef is not actually used during training. Instead, the coefficient is always 0 there, meaning that the auxiliary loss never contributes to the training loss. I think this is a bug, as it contradicts the documentation.

Here is my code to reproduce the issue. I set model.config.output_router_logits = True and model.config.router_aux_loss_coef = 0.123 to enable the auxiliary loss.

from accelerate import Accelerator
from datasets import Dataset
from peft import LoraConfig
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DPOConfig, DPOTrainer

config = DPOConfig(
    output_dir="/tmp/output",
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    max_steps=10,
    eval_strategy="steps",
    eval_steps=5,
    save_steps=1000,
    logging_steps=1,
    bf16=True,
    report_to="none",
    disable_tqdm=True,
)

model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mixtral-8x7B-Instruct-v0.1",
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
    device_map={"": Accelerator().local_process_index},
    load_in_4bit=True,
)
model.config.output_router_logits = True
model.config.router_aux_loss_coef = 0.123

peft_config = LoraConfig(
    r=16,
    lora_alpha=16,
    lora_dropout=0.0,
    target_modules="all-linear",
    bias="none",
    task_type="CAUSAL_LM",
)

tokenizer = AutoTokenizer.from_pretrained(
    "mistralai/Mixtral-8x7B-Instruct-v0.1",
    trust_remote_code=True,
)
tokenizer.pad_token = tokenizer.eos_token

train_dataset = Dataset.from_dict(
    {
        "prompt": ["What is the capital of France?", "What is the capital of Italy?"],
        "chosen": ["Paris", "Rome"],
        "rejected": ["Berlin", "London"],
    }
)
eval_dataset = Dataset.from_dict(
    {
        "prompt": ["What is the capital of Japan?", "What is the capital of China?"],
        "chosen": ["Tokyo", "Beijing"],
        "rejected": ["Seoul", "Shanghai"],
    }
)

trainer = DPOTrainer(
    model=model,
    peft_config=peft_config,
    tokenizer=tokenizer,
    args=config,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    ref_model=None,
    max_length=100,
)

trainer.train()

Since there seems to be no way to know the actual coefficient used in the loss computation, I added print statements:

diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py
index 3c9ff46..33cfb82 100644
--- a/trl/trainer/dpo_trainer.py
+++ b/trl/trainer/dpo_trainer.py
@@ -1492,6 +1492,8 @@ class DPOTrainer(Trainer):
             metrics[f"{prefix}nll_loss"] = policy_nll_loss.detach().mean().cpu()

         if self.aux_loss_enabled:
+            print("type(model):", type(model))
+            print("aux_loss_coef", getattr(model.config, "router_aux_loss_coef", 0.0))
             return losses.mean() + getattr(model.config, "router_aux_loss_coef", 0.0) * aux_loss, metrics

         return losses.mean(), metrics
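The log below is consistent with a simple mechanism: the lookup in the diff uses getattr with a 0.0 default, so if model.config resolves to something other than the Hugging Face model config, the coefficient silently becomes 0.0 instead of raising an error. A minimal sketch with hypothetical stand-in classes (FakeModel and FakeEngine are not real DeepSpeed or transformers classes; we assume, consistent with the printed output, that the training wrapper exposes its own config attribute):

```python
# Hypothetical stand-ins (NOT real DeepSpeed/transformers classes) showing how
# getattr(model.config, "router_aux_loss_coef", 0.0) silently falls back to 0.0
# once the model is wrapped by an object that has its own `config` attribute.
from types import SimpleNamespace


class FakeModel:
    """Stands in for the PEFT-wrapped Mixtral model; carries the HF-style config."""

    def __init__(self):
        self.config = SimpleNamespace(output_router_logits=True, router_aux_loss_coef=0.123)


class FakeEngine:
    """Stands in for a training wrapper such as DeepSpeedEngine (assumed behavior)."""

    def __init__(self, module):
        self.module = module
        # The wrapper's own config (assumption): not the HF model config.
        self.config = {"zero_optimization": {"stage": 2}}

    def __getattr__(self, name):
        # Only called for attributes not found on the wrapper itself;
        # forward them to the wrapped module.
        return getattr(self.__dict__["module"], name)


def aux_loss_coef(model):
    # Same lookup as in the diff above.
    return getattr(model.config, "router_aux_loss_coef", 0.0)


inner = FakeModel()
print(aux_loss_coef(inner))              # 0.123 (evaluation path: unwrapped model)
print(aux_loss_coef(FakeEngine(inner)))  # 0.0 (training path: wrapped model)
```

Because the wrapper's own config attribute shadows the forwarded one, no AttributeError is raised and the default quietly wins.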

When I execute the following command on a machine with two H100s, I can see that the coefficient is 0 during training. The value I specified, 0.123, is used only during evaluation. It seems that the discrepancy comes from the fact that model is wrapped by DeepSpeedEngine during training, so router_aux_loss_coef is not in model.config.

> accelerate launch --use_deepspeed --zero_stage 2 debug_aux_loss.py
The following values were not passed to `accelerate launch` and had defaults used instead:
        `--num_processes` was set to a value of `2`
                More than one GPU was found, enabling multi-GPU training.
                If this was unintended please pass in `--num_processes=1`.
        `--num_machines` was set to a value of `1`
        `--mixed_precision` was set to a value of `'no'`
        `--dynamo_backend` was set to a value of `'no'`
To avoid this warning pass in values for each of the problematic parameters or run `accelerate config`.
[2024-10-08 13:58:58,842] [INFO] [real_accelerator.py:203:get_accelerator] Setting ds_accelerator to cuda (auto detect)
W1008 13:59:03.693000 139990296390528 torch/distributed/run.py:779]
W1008 13:59:03.693000 139990296390528 torch/distributed/run.py:779] *****************************************
W1008 13:59:03.693000 139990296390528 torch/distributed/run.py:779] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
W1008 13:59:03.693000 139990296390528 torch/distributed/run.py:779] *****************************************
[2024-10-08 13:59:08,829] [INFO] [real_accelerator.py:203:get_accelerator] Setting ds_accelerator to cuda (auto detect)
[2024-10-08 13:59:09,185] [INFO] [real_accelerator.py:203:get_accelerator] Setting ds_accelerator to cuda (auto detect)
[2024-10-08 13:59:10,178] [INFO] [comm.py:652:init_distributed] cdb=None
[2024-10-08 13:59:10,509] [INFO] [comm.py:652:init_distributed] cdb=None
[2024-10-08 13:59:10,509] [INFO] [comm.py:683:init_distributed] Initializing TorchBackend in DeepSpeed with backend nccl
The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.
The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 19/19 [00:37<00:00,  1.96s/it]
/mnt/shared/fujita/trl/.venv/lib/python3.11/site-packages/huggingface_hub/utils/_deprecation.py:100: FutureWarning: Deprecated argument(s) used in '__init__': max_length. Will not be supported from version '1.0.0'.

Deprecated positional argument(s) used in DPOTrainer, please use the DPOConfig to set these arguments instead.
  warnings.warn(message, FutureWarning)
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 19/19 [00:37<00:00,  1.98s/it]
/mnt/shared/fujita/trl/.venv/lib/python3.11/site-packages/huggingface_hub/utils/_deprecation.py:100: FutureWarning: Deprecated argument(s) used in '__init__': max_length. Will not be supported from version '1.0.0'.

Deprecated positional argument(s) used in DPOTrainer, please use the DPOConfig to set these arguments instead.
  warnings.warn(message, FutureWarning)
/mnt/shared/fujita/trl/trl/trainer/dpo_trainer.py:675: UserWarning: You passed `max_length` to the DPOTrainer, the value you passed will override the one in the `DPOConfig`.
  warnings.warn(
/mnt/shared/fujita/trl/trl/trainer/dpo_trainer.py:693: UserWarning: `max_prompt_length` is not set in the DPOConfig's init it will default to `128` by default, but you should do it yourself in the future.
  warnings.warn(
/mnt/shared/fujita/trl/trl/trainer/dpo_trainer.py:728: UserWarning: When using DPODataCollatorWithPadding, you should set `remove_unused_columns=False` in your TrainingArguments we have set it for you, but you should do it yourself in the future.
  warnings.warn(
/mnt/shared/fujita/trl/trl/trainer/dpo_trainer.py:675: UserWarning: You passed `max_length` to the DPOTrainer, the value you passed will override the one in the `DPOConfig`.
  warnings.warn(
/mnt/shared/fujita/trl/trl/trainer/dpo_trainer.py:693: UserWarning: `max_prompt_length` is not set in the DPOConfig's init it will default to `128` by default, but you should do it yourself in the future.
  warnings.warn(
/mnt/shared/fujita/trl/trl/trainer/dpo_trainer.py:728: UserWarning: When using DPODataCollatorWithPadding, you should set `remove_unused_columns=False` in your TrainingArguments we have set it for you, but you should do it yourself in the future.
  warnings.warn(
Tokenizing train dataset: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 143.21 examples/s]
Tokenizing eval dataset: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 469.79 examples/s]
max_steps is given, it will override any value given in num_train_epochs
Tokenizing train dataset: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 105.89 examples/s]
Tokenizing eval dataset: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 449.17 examples/s]
max_steps is given, it will override any value given in num_train_epochs
/mnt/shared/fujita/trl/.venv/lib/python3.11/site-packages/bitsandbytes/nn/modules.py:452: UserWarning: Input type into Linear4bit is torch.float16, but bnb_4bit_compute_dtype=torch.float32 (default). This will lead to slow inference or training speed.
  warnings.warn(
/mnt/shared/fujita/trl/.venv/lib/python3.11/site-packages/bitsandbytes/nn/modules.py:452: UserWarning: Input type into Linear4bit is torch.float16, but bnb_4bit_compute_dtype=torch.float32 (default). This will lead to slow inference or training speed.
  warnings.warn(
Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)
type(model): <class 'deepspeed.runtime.engine.DeepSpeedEngine'>
aux_loss_coef 0.0
Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)
type(model): <class 'deepspeed.runtime.engine.DeepSpeedEngine'>
aux_loss_coef 0.0
Could not estimate the number of tokens of the input, floating-point operations will not be computed
Could not estimate the number of tokens of the input, floating-point operations will not be computed
{'loss': 0.6928, 'grad_norm': 1.3213472366333008, 'learning_rate': 4.5e-05, 'rewards/chosen': -0.00011081695993198082, 'rewards/rejected': -0.0001672744838288054, 'rewards/accuracies': 1.0, 'rewards/margins': 5.645752389682457e-05, 'logps/rejected': -23.34796714782715, 'logps/chosen': -12.41358757019043, 'logits/rejected': 3.972805976867676, 'logits/chosen': 3.841728687286377, 'epoch': 1.0}
type(model): <class 'deepspeed.runtime.engine.DeepSpeedEngine'>
aux_loss_coef 0.0
type(model): <class 'deepspeed.runtime.engine.DeepSpeedEngine'>
aux_loss_coef 0.0
{'loss': 0.5622, 'grad_norm': 1.125776767730713, 'learning_rate': 4e-05, 'rewards/chosen': 0.22521057724952698, 'rewards/rejected': -0.0760623961687088, 'rewards/accuracies': 1.0, 'rewards/margins': 0.301272988319397, 'logps/rejected': -24.106918334960938, 'logps/chosen': -10.16037368774414, 'logits/rejected': 3.8778529167175293, 'logits/chosen': 3.7467005252838135, 'epoch': 2.0}
type(model): <class 'deepspeed.runtime.engine.DeepSpeedEngine'>
aux_loss_coef 0.0
type(model): <class 'deepspeed.runtime.engine.DeepSpeedEngine'>
aux_loss_coef 0.0
{'loss': 0.4342, 'grad_norm': 1.044356107711792, 'learning_rate': 3.5e-05, 'rewards/chosen': 0.3974466323852539, 'rewards/rejected': -0.1938224881887436, 'rewards/accuracies': 1.0, 'rewards/margins': 0.5912691354751587, 'logps/rejected': -25.28451919555664, 'logps/chosen': -8.438013076782227, 'logits/rejected': 3.7590129375457764, 'logits/chosen': 3.6288115978240967, 'epoch': 3.0}
type(model): <class 'deepspeed.runtime.engine.DeepSpeedEngine'>
aux_loss_coef 0.0
type(model): <class 'deepspeed.runtime.engine.DeepSpeedEngine'>
aux_loss_coef 0.0
{'loss': 0.3392, 'grad_norm': 1.0719728469848633, 'learning_rate': 3e-05, 'rewards/chosen': 0.46517130732536316, 'rewards/rejected': -0.6122890710830688, 'rewards/accuracies': 1.0, 'rewards/margins': 1.0774604082107544, 'logps/rejected': -28.70645523071289, 'logps/chosen': -6.773166179656982, 'logits/rejected': 3.015146255493164, 'logits/chosen': 3.35567569732666, 'epoch': 4.0}
type(model): <class 'deepspeed.runtime.engine.DeepSpeedEngine'>
aux_loss_coef 0.0
type(model): <class 'deepspeed.runtime.engine.DeepSpeedEngine'>
aux_loss_coef 0.0
{'loss': 0.2447, 'grad_norm': 0.8891029357910156, 'learning_rate': 2.5e-05, 'rewards/chosen': 0.5121074318885803, 'rewards/rejected': -0.938281238079071, 'rewards/accuracies': 1.0, 'rewards/margins': 1.4503886699676514, 'logps/rejected': -31.96637725830078, 'logps/chosen': -6.303804874420166, 'logits/rejected': 2.8517861366271973, 'logits/chosen': 3.225398063659668, 'epoch': 5.0}
type(model): <class 'peft.peft_model.PeftModelForCausalLM'>
aux_loss_coef 0.123
type(model): <class 'peft.peft_model.PeftModelForCausalLM'>
aux_loss_coef 0.123
{'eval_loss': 0.6902670860290527, 'eval_runtime': 1.0927, 'eval_samples_per_second': 1.83, 'eval_steps_per_second': 0.915, 'eval_rewards/chosen': 0.330005943775177, 'eval_rewards/rejected': -0.2509513795375824, 'eval_rewards/accuracies': 1.0, 'eval_rewards/margins': 0.580957293510437, 'eval_logps/rejected': -25.502666473388672, 'eval_logps/chosen': -8.488536834716797, 'eval_logits/rejected': 3.206810235977173, 'eval_logits/chosen': 2.9757118225097656, 'epoch': 5.0}
type(model): <class 'deepspeed.runtime.engine.DeepSpeedEngine'>
aux_loss_coef 0.0
type(model): <class 'deepspeed.runtime.engine.DeepSpeedEngine'>
aux_loss_coef 0.0
{'loss': 0.182, 'grad_norm': 0.7440688014030457, 'learning_rate': 2e-05, 'rewards/chosen': 0.549546480178833, 'rewards/rejected': -1.2760757207870483, 'rewards/accuracies': 1.0, 'rewards/margins': 1.8256222009658813, 'logps/rejected': -35.344322204589844, 'logps/chosen': -5.92941427230835, 'logits/rejected': 2.746168851852417, 'logits/chosen': 3.098029851913452, 'epoch': 6.0}
type(model): <class 'deepspeed.runtime.engine.DeepSpeedEngine'>
aux_loss_coef 0.0
type(model): <class 'deepspeed.runtime.engine.DeepSpeedEngine'>
aux_loss_coef 0.0
{'loss': 0.1477, 'grad_norm': 0.7196225523948669, 'learning_rate': 1.5e-05, 'rewards/chosen': 0.6175909042358398, 'rewards/rejected': -1.0372135639190674, 'rewards/accuracies': 1.0, 'rewards/margins': 1.6548044681549072, 'logps/rejected': -33.71842956542969, 'logps/chosen': -6.236570358276367, 'logits/rejected': 3.191372871398926, 'logits/chosen': 2.986262798309326, 'epoch': 7.0}
type(model): <class 'deepspeed.runtime.engine.DeepSpeedEngine'>
aux_loss_coef 0.0
type(model): <class 'deepspeed.runtime.engine.DeepSpeedEngine'>
aux_loss_coef 0.0
{'loss': 0.11, 'grad_norm': 0.6695989370346069, 'learning_rate': 1e-05, 'rewards/chosen': 0.6106638312339783, 'rewards/rejected': -1.3588154315948486, 'rewards/accuracies': 1.0, 'rewards/margins': 1.9694793224334717, 'logps/rejected': -36.9344482421875, 'logps/chosen': -6.305841445922852, 'logits/rejected': 3.0163962841033936, 'logits/chosen': 2.7691760063171387, 'epoch': 8.0}
type(model): <class 'deepspeed.runtime.engine.DeepSpeedEngine'>
aux_loss_coef 0.0
type(model): <class 'deepspeed.runtime.engine.DeepSpeedEngine'>
aux_loss_coef 0.0
{'loss': 0.0875, 'grad_norm': 0.6486740112304688, 'learning_rate': 5e-06, 'rewards/chosen': 0.4466918110847473, 'rewards/rejected': -1.7334327697753906, 'rewards/accuracies': 1.0, 'rewards/margins': 2.180124521255493, 'logps/rejected': -40.68062210083008, 'logps/chosen': -7.945561408996582, 'logits/rejected': 2.838531494140625, 'logits/chosen': 2.538378953933716, 'epoch': 9.0}
type(model): <class 'deepspeed.runtime.engine.DeepSpeedEngine'>
aux_loss_coef 0.0
type(model): <class 'deepspeed.runtime.engine.DeepSpeedEngine'>
aux_loss_coef 0.0
{'loss': 0.0922, 'grad_norm': 2.0433409214019775, 'learning_rate': 0.0, 'rewards/chosen': 0.3331775665283203, 'rewards/rejected': -1.9678242206573486, 'rewards/accuracies': 1.0, 'rewards/margins': 2.301001787185669, 'logps/rejected': -43.0245361328125, 'logps/chosen': -9.080703735351562, 'logits/rejected': 2.7528159618377686, 'logits/chosen': 2.37682843208313, 'epoch': 10.0}
type(model): <class 'peft.peft_model.PeftModelForCausalLM'>
aux_loss_coef 0.123
type(model): <class 'peft.peft_model.PeftModelForCausalLM'>
aux_loss_coef 0.123
{'eval_loss': 0.7325631976127625, 'eval_runtime': 1.1009, 'eval_samples_per_second': 1.817, 'eval_steps_per_second': 0.908, 'eval_rewards/chosen': -0.6046479344367981, 'eval_rewards/rejected': -0.8562896847724915, 'eval_rewards/accuracies': 1.0, 'eval_rewards/margins': 0.25164175033569336, 'eval_logps/rejected': -31.556049346923828, 'eval_logps/chosen': -17.83507537841797, 'eval_logits/rejected': 2.8031225204467773, 'eval_logits/chosen': 2.527698516845703, 'epoch': 10.0}
{'train_runtime': 42.0997, 'train_samples_per_second': 0.475, 'train_steps_per_second': 0.238, 'train_loss': 0.2892607443034649, 'epoch': 10.0}

Expected behavior

The value specified as model.config.router_aux_loss_coef should be used during both training and evaluation.

As a fix, I think it would be good to store the value of model.config.router_aux_loss_coef in DPOTrainer.__init__, just as is already done for model.config.output_router_logits: https://github.com/huggingface/trl/blob/adf58d80d01435516fdedf58d255c7dcf009fec4/trl/trainer/dpo_trainer.py#L812. I can send a PR.
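A minimal sketch of the proposed approach, using a hypothetical, heavily simplified SketchTrainer rather than the actual DPOTrainer code (the attribute name aux_loss_coef is an assumption). Both MoE settings are read once from the unwrapped model in __init__, so later wrapping cannot change the coefficient:

```python
# Hypothetical, simplified trainer illustrating the proposed fix (NOT TRL code):
# read both MoE settings from the unwrapped model's config once, in __init__,
# instead of reading model.config at loss time (when `model` may be a wrapper).
from types import SimpleNamespace


class SketchTrainer:
    def __init__(self, model):
        # Cache output_router_logits as before, and cache the coefficient the
        # same way so that training and evaluation use the same value.
        self.aux_loss_enabled = getattr(model.config, "output_router_logits", False)
        self.aux_loss_coef = getattr(model.config, "router_aux_loss_coef", 0.0)

    def combine(self, mean_loss, aux_loss):
        """Return the loss to backpropagate; no longer depends on any wrapper."""
        if self.aux_loss_enabled:
            return mean_loss + self.aux_loss_coef * aux_loss
        return mean_loss


model = SimpleNamespace(
    config=SimpleNamespace(output_router_logits=True, router_aux_loss_coef=0.123)
)
trainer = SketchTrainer(model)
print(trainer.aux_loss_coef)  # 0.123
print(trainer.combine(0.69, 2.0))
```

The design choice is simply to resolve configuration before the model is handed to Accelerate/DeepSpeed, mirroring how output_router_logits is already handled.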

It is possible that other trainers have the same issue, but I have not checked.

qgallouedec commented 1 month ago

Thanks for this detailed report.

I can see that you're using deepspeed. Can you share the version?

The suggested solution sounds reasonable.

> I can send a PR.

I'd be happy to review it, thanks a lot!

qgallouedec commented 1 month ago

> It is possible that other trainers have the same issue, but I have not checked.

I've just checked: we have the same problem with BCO, CPO, KTO and ORPO. Would you mind adding the same fix for those? The code is almost the same.

muupan commented 1 month ago

> I can see that you're using deepspeed. Can you share the version?

I use deepspeed==0.15.1

I'll send a PR shortly. If my PR for DPOTrainer looks OK, I can address the other trainers as well.