meta-llama / llama-recipes

Scripts for fine-tuning Meta Llama3 with composable FSDP & PEFT methods, covering single- and multi-node GPU setups. Supports default and custom datasets for applications such as summarization and Q&A. Supports a number of candidate inference solutions, such as HF TGI and vLLM, for local or cloud deployment. Includes demo apps that showcase Meta Llama3 for WhatsApp & Messenger.

Masking loss from labels associated with user turns #209

Closed: jveronvialard closed this issue 2 weeks ago

jveronvialard commented 11 months ago

In Concatenator.__call__() (https://github.com/facebookresearch/llama-recipes/blob/main/src/llama_recipes/datasets/utils.py#L38), labels are created by simply copying the input_ids. Shouldn't the labels associated with the user turns also be masked, so that those tokens do not contribute to the loss? See, for example, https://github.com/bigcode-project/starcoder/blob/main/chat/dialogues.py#L232 and https://huggingface.co/blog/starchat-alpha#masking-user-labels.
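To make the suggestion concrete, here is a minimal sketch of user-turn masking, assuming each dialogue is already tokenized into a list of (role, token_ids) turns. The helper names build_example and pack, and the "user"/"assistant" role strings, are hypothetical and not the llama-recipes API. It relies on -100 being the default ignore_index of torch.nn.CrossEntropyLoss, which Hugging Face causal LM models use when computing the loss from labels. The key point is that packing must carry the masked labels along rather than regenerating them from input_ids.

```python
from typing import Dict, List, Tuple

# Default ignore_index of torch.nn.CrossEntropyLoss; tokens with this label add no loss.
IGNORE_INDEX = -100


def build_example(turns: List[Tuple[str, List[int]]]) -> Dict[str, List[int]]:
    """Hypothetical helper: concatenate pre-tokenized turns and mask user-turn labels."""
    input_ids: List[int] = []
    labels: List[int] = []
    for role, token_ids in turns:
        input_ids.extend(token_ids)
        if role == "user":
            # The model still sees user tokens as input, but is not trained to predict them.
            labels.extend([IGNORE_INDEX] * len(token_ids))
        else:
            labels.extend(token_ids)
    return {
        "input_ids": input_ids,
        "attention_mask": [1] * len(input_ids),
        "labels": labels,
    }


def pack(samples: List[Dict[str, List[int]]], chunk_size: int) -> List[Dict[str, List[int]]]:
    """Hypothetical packer: split samples into fixed-size chunks, keeping labels aligned.

    Unlike copying input_ids into labels after packing, every key is chunked
    together, so the user-turn mask survives. A trailing remainder shorter
    than chunk_size is dropped in this sketch.
    """
    buffer: Dict[str, List[int]] = {"input_ids": [], "attention_mask": [], "labels": []}
    chunks: List[Dict[str, List[int]]] = []
    for sample in samples:
        for key in buffer:
            buffer[key].extend(sample[key])
        while len(buffer["input_ids"]) >= chunk_size:
            chunks.append({key: value[:chunk_size] for key, value in buffer.items()})
            buffer = {key: value[chunk_size:] for key, value in buffer.items()}
    return chunks


if __name__ == "__main__":
    example = build_example([("user", [1, 2, 3]), ("assistant", [4, 5])])
    print(example["labels"])  # [-100, -100, -100, 4, 5]
    for chunk in pack([example, example], chunk_size=4):
        print(chunk["input_ids"], chunk["labels"])
```

The label shift for next-token prediction is left to the model, as Hugging Face causal LM models shift labels internally, so labels stay aligned one-to-one with input_ids here.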

mreso commented 11 months ago

Good catch, created a PR

jveronvialard commented 11 months ago

Nice, I took a look at the PR https://github.com/facebookresearch/llama-recipes/pull/211, LGTM