Thanks for this detailed report.
I can see that you're using deepspeed. Can you share the version?
The suggested solution sounds reasonable.
I can send a PR.
I'd be happy to review it, thanks a lot!
It is possible that other trainers have the same issue, but I have not checked.
I've just checked, and we have the same problem with BCO, CPO, KTO, and ORPO. Do you mind adding the same fix for those? The codebase is almost the same.
I use deepspeed==0.15.1.
I'll send a PR shortly. If my PR to DPOTrainer is OK, I can address the other trainers as well.
System Info
Python 3.11.7
trl==0.11.2
transformers==4.45.1
accelerate==0.34.2
Reproduction
The documentation says `router_aux_loss_coef` is used as a coefficient for the auxiliary loss. However, it seems `router_aux_loss_coef` is not actually used during training. Instead, the coefficient is always 0, meaning that the auxiliary loss is never applied, which I think is a bug as it contradicts the documentation.

Here is my code to reproduce the issue. I set `model.config.output_router_logits = True` and `model.config.router_aux_loss_coef = 0.123` to enable the auxiliary loss. Since there seems to be no way to know the actual coefficient used in the loss computation, I added a print statement.
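A minimal sketch of the setup just described, assuming the trl 0.11 `DPOTrainer` API; the checkpoint name, dataset contents, and output directory are placeholders, not the original script:

```python
from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DPOConfig, DPOTrainer

# Any MoE checkpoint that emits router logits works; this name is a placeholder.
model_name = "mistralai/Mixtral-8x7B-v0.1"
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Enable the auxiliary (router load-balancing) loss with a recognizable coefficient.
model.config.output_router_logits = True
model.config.router_aux_loss_coef = 0.123

# Tiny placeholder preference dataset in the prompt/chosen/rejected format.
train_dataset = Dataset.from_dict(
    {
        "prompt": ["Hello"],
        "chosen": [" hi there"],
        "rejected": [" go away"],
    }
)

trainer = DPOTrainer(
    model=model,
    args=DPOConfig(output_dir="out", per_device_train_batch_size=1, max_steps=1),
    train_dataset=train_dataset,
    tokenizer=tokenizer,
)
trainer.train()
```

The command itself would be an Accelerate launch with a DeepSpeed config, along the lines of `accelerate launch --config_file deepspeed_config.yaml repro.py` (both file names are placeholders).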
When I execute the following command on a machine with two H100s, I can see that the coefficient is 0 during training. The value I specified, 0.123, is used only during evaluation. It seems that the discrepancy comes from the fact that `model` is wrapped by `DeepSpeedEngine` during training, so `router_aux_loss_coef` is not in `model.config`.
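I have not traced the exact line, but the failure mode can be sketched like this: the coefficient is read off `model.config` with a default of 0, and once DeepSpeed wraps the model, `model.config` is DeepSpeed's own runtime config rather than the transformers `PretrainedConfig`, so the default silently wins.

```python
# Illustrative sketch, not the exact trl source: a lookup of this shape
# silently falls back to the default once DeepSpeed has wrapped the model.
coef = getattr(model.config, "router_aux_loss_coef", 0.0)
# Before wrapping: model.config is the transformers PretrainedConfig -> 0.123
# After wrapping:  model is a DeepSpeedEngine, whose .config is DeepSpeed's
#                  own runtime config, which has no such field       -> 0.0
```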
Expected behavior
The value specified as `model.config.router_aux_loss_coef` should be used during both training and evaluation.

As a fix, I think it is good to store the value of `model.config.router_aux_loss_coef` in `DPOTrainer.__init__`, just like `model.config.output_router_logits`: https://github.com/huggingface/trl/blob/adf58d80d01435516fdedf58d255c7dcf009fec4/trl/trainer/dpo_trainer.py#L812. I can send a PR; a rough sketch of the idea is below. It is possible that other trainers have the same issue, but I have not checked.
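A rough sketch of the fix I have in mind (the `aux_loss_coef` attribute name is my own placeholder, not existing trl code):

```python
# In DPOTrainer.__init__, next to where output_router_logits is stored
# (sketch only; aux_loss_coef is a placeholder name):
self.aux_loss_enabled = getattr(model.config, "output_router_logits", False)
# Cache the coefficient before Accelerate/DeepSpeed wraps the model, so the
# same value is used during both training and evaluation.
self.aux_loss_coef = getattr(model.config, "router_aux_loss_coef", 0.0)

# The loss computation would then use the cached value instead of reading
# model.config again:
# loss = loss + self.aux_loss_coef * aux_loss
```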