huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Clarify the label shifting behavior of Llama models when `labels` is given. #32944


keunwoochoi commented 3 months ago

Feature request

I believe `labels` in the training of causal LMs means the value to predict at time step n, i.e., the next token. In other words, I'd assume that if `labels` is given, it should already be shifted by one in the data loader w.r.t. `input_ids`.

However, in `LlamaForCausalLM.forward()`, I found that the labels are always shifted, silently:

https://github.com/huggingface/transformers/blob/f1d822ba337499d429f832855622b97d90ac1406/src/transformers/models/llama/modeling_llama.py#L1205-L1210


        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

...


        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

I found this quite unexpected, hence "silently". Since this is a causal LM, shouldn't it leave the labels unshifted by default? In `modeling_gpt2`, at least, the behavior is documented explicitly:

https://github.com/huggingface/transformers/blob/f1d822ba337499d429f832855622b97d90ac1406/src/transformers/models/gpt2/modeling_gpt2.py#L1309-1314

`modeling_gemma2` has the same behavior, with no explicit mention in the docstring:

https://github.com/huggingface/transformers/blob/f1d822ba337499d429f832855622b97d90ac1406/src/transformers/models/gemma2/modeling_gemma2.py#L978-L982

I think we should at least make the docstring mention this, if changing the behavior itself is too risky at this point.
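
To spell out the practical consequence: because `forward()` shifts internally, the `labels` you pass should be an *unshifted* copy of `input_ids`. A minimal sketch of checking this (the tiny test checkpoint is only an illustration, not part of this issue):

```python
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer

# Any causal LM checkpoint works; a tiny random model keeps the demo cheap.
checkpoint = "hf-internal-testing/tiny-random-LlamaForCausalLM"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint)

input_ids = tokenizer("the quick brown fox", return_tensors="pt")["input_ids"]

# Current expected usage: labels are an unshifted copy of input_ids;
# forward() shifts them internally, as shown in the snippet above.
out = model(input_ids=input_ids, labels=input_ids)

# Reproducing the internal computation by hand:
shift_logits = out.logits[..., :-1, :].contiguous()
shift_labels = input_ids[..., 1:].contiguous()
manual_loss = F.cross_entropy(
    shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
)
print(out.loss, manual_loss)  # the two values should match
```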

Motivation

I didn't expect this behavior and used my own data loader, which already does the shifting, since I believed that is what `labels` should mean. As a result, I ended up fine-tuning a model to predict the next-next token, which produced gibberish.
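
To make the failure mode concrete, here is a toy illustration of the index arithmetic (plain tensors, no model needed):

```python
import torch

input_ids = torch.tensor([[10, 11, 12, 13, 14]])

# What forward() expects: labels identical to input_ids.
labels_unshifted = input_ids.clone()

# What my data loader produced: labels already shifted by one,
# i.e. labels[t] = input_ids[t + 1].
labels_preshifted = input_ids[..., 1:]

# forward() then shifts the labels once more, so the target for logits[..., t, :] is:
print(labels_unshifted[..., 1:])   # tensor([[11, 12, 13, 14]]) -> next token (intended)
print(labels_preshifted[..., 1:])  # tensor([[12, 13, 14]])     -> next-next token (wrong)
```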

Your contribution

NielsRogge commented 3 months ago

Hi,

Sure, feel free to open a PR. Usually, users are expected to make `labels` a copy of the `input_ids`, with padding tokens (or other tokens which the model can ignore) replaced by -100. See the example notebook or script here for that.
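
Something along these lines, for example (a minimal sketch; the checkpoint name is just a placeholder):

```python
from transformers import AutoTokenizer

# Placeholder checkpoint; use whichever model you are fine-tuning.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # Llama tokenizers ship without a pad token

batch = tokenizer(
    ["a short example", "a somewhat longer example sentence"],
    padding=True,
    return_tensors="pt",
)

labels = batch["input_ids"].clone()
labels[batch["attention_mask"] == 0] = -100  # padding is ignored by the loss
batch["labels"] = labels  # pass as-is; the model shifts internally
```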

Feel free to open a PR to clarify this in the docs.