huggingface / trl

Train transformer language models with reinforcement learning.
http://hf.co/docs/trl
Apache License 2.0

MoE Models: option to add load balancing loss #1765

Closed: claralp closed this 6 days ago

claralp commented 1 week ago

This fixes #1544
It optionally enables adding the load balancing loss (also aux_loss) of Mixture of Experts models in DPO, KTO, CPO or ORPO. We are using this for a while now and in our experiments, the models that were fine-tuned using this option perform best at the moment. Maybe that applies to other project as well, so I wanted to open-source the idea and leave it to the users to enable it or not.

This option is enabled simply by setting output_router_logits=True in the MixtralConfig; the loss can optionally be scaled with router_aux_loss_coef=... (default is 0.001).
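For illustration, here is a minimal sketch of how this could be wired up with DPO (not taken from the PR itself): the model name, dataset, and exact trainer arguments are placeholders and may differ depending on your transformers/trl versions.

```python
# Minimal sketch: enable the MoE load balancing (aux) loss when fine-tuning
# a Mixtral model with DPO. Model name and dataset below are placeholders.
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DPOConfig, DPOTrainer

model_name = "mistralai/Mixtral-8x7B-v0.1"  # placeholder Mixtral checkpoint

# output_router_logits=True makes the model return the router logits so the
# load balancing loss can be computed; router_aux_loss_coef scales that loss
# (default 0.001) before it is added to the training loss.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    output_router_logits=True,
    router_aux_loss_coef=0.001,
)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Placeholder preference dataset with prompt/chosen/rejected columns.
train_dataset = load_dataset(
    "trl-internal-testing/hh-rlhf-helpful-base-trl-style", split="train"
)

trainer = DPOTrainer(
    model=model,
    args=DPOConfig(output_dir="mixtral-dpo"),
    train_dataset=train_dataset,
    tokenizer=tokenizer,
)
trainer.train()
```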

@kashif @lewtun

kashif commented 1 week ago

thanks @claralp, could you kindly also add something to the docs?

HuggingFaceDocBuilderDev commented 1 week ago

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

claralp commented 6 days ago

@kashif do you know why tests / tests (3.8, windows-latest) (pull_request) is failing? It does not seem to be related to any of the changes in this PR.

PhilipMay commented 6 days ago

> @kashif do you know why tests / tests (3.8, windows-latest) (pull_request) is failing? It does not seem to be related to any of the changes in this PR.

It is an HTTP timeout:

E   urllib3.exceptions.ReadTimeoutError: HTTPSConnectionPool(host='cdn-lfs.huggingface.co', port=443): Read timed out. (read timeout=10)

For me this smells like a flaky test.

kashif commented 6 days ago

I believe it's a flaky test too