EleutherAI / lm-evaluation-harness

A framework for few-shot evaluation of language models.

Why is last token dropped in loglikelihood computation? Gives different result than when calculating loss. #942

Closed sorenmulli closed 1 year ago

sorenmulli commented 1 year ago

Question

In https://github.com/EleutherAI/lm-evaluation-harness/blob/3ccea2b2854dd3cc9ff5ef1772e33de21168c305/lm_eval/base.py#L342 (and refactor: https://github.com/EleutherAI/lm-evaluation-harness/blob/408115eaffc4eecc9584f543db573d708eef8ed6/lm_eval/models/huggingface.py#L708), the last input token is dropped before the model call.

This is motivated by this diagram:

# how this all works:
#          CTX      CONT
# inp    0 1 2 3|4 5 6 7 8 9   <- last token is deleted by inp[:, :-1]
# gpt2    \               \
# logits   1 2 3|4 5 6 7 8 9   <- the ctx half gets tossed out by the
# cont_toks      4 5 6 7 8 9      [:, -len(continuation_enc):, :self.vocab_size] slice

I must admit that I do not understand why this is: does anyone have some pointers as to why removing this token still yields correct probabilities? Surely the value of the last token matters for the overall likelihood.

Minimal Example

The computation below shows that I can only reproduce the result of _loglikelihood_tokens if I remove the [:-1]; otherwise there is a difference coming from the last token:

import torch
import torch.nn
from transformers import AutoModelForCausalLM, AutoTokenizer
from lm_eval.models.huggingface import AutoCausalLM

lm_key = "sshleifer/tiny-gpt2"
context = "we are the"
cont = " koala bears of the world"

model = AutoModelForCausalLM.from_pretrained(lm_key)
model.eval()
tokenizer = AutoTokenizer.from_pretrained(lm_key)

encodings = tokenizer(context, text_target=cont, return_tensors="pt")
input_ids = torch.cat((encodings.input_ids, encodings.labels), dim=1)
target_ids = input_ids.clone()
# Makes context ignored by loss function
target_ids[:, : encodings.input_ids.size(1)] = -100

with torch.no_grad():
    logits = model(input_ids).logits

# Move the vocab dimension to dim 1: CrossEntropyLoss expects (N, C, ...) with classes in dim 1
logits = logits.permute(0, 2, 1)
losses = torch.nn.CrossEntropyLoss(reduction="none")(logits, target_ids)
print(-losses.sum().item())
# Result: -64.88402557373047

lm_eval_model = AutoCausalLM(lm_key, device="cpu")
print(lm_eval_model.loglikelihood([(context, cont)])[0][0])
# Result: -65.07632446289062
# If I remove the `[:-1]` in _loglikelihood_tokens:
# -64.88401794433594

Similar issues

A similar question was asked in #337 by @jon-tow, who closed it with the message:

Update: I confused position indexing (next-token distribution)

haileyschoelkopf commented 1 year ago

Hi! I'll add a further note on this to the comment and the documentation, as this is a frequent question.

The reason for chopping off the last completion token is that autoregressive LLMs take in tokens up to position N and return a logit distribution for position N+1. Therefore, the logit the model assigns to token N is obtained by feeding in 0 1 2 3 ... (N - 1) and then taking the last logit position--this is the logit for the Nth token.

When we're feeding in

0 1 2 3 | 4 

what we want is the logits predicting 4. To get the logits for 4 conditioned on 0 1 2 3 we must feed 0 1 2 3 in without passing in 4. Then, the final logits index is the predicted distribution over tokens at the 4 position, which is what we wanted! The same applies for multi-token continuations.
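
To make the indexing concrete, here is a minimal sketch of scoring a single continuation token (just an illustration, not the harness code; the tiny stand-in model from the snippets above and the digit strings mirroring the diagram are only for demonstration):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

lm_key = "sshleifer/tiny-gpt2"  # stand-in model, as in the snippets in this thread
model = AutoModelForCausalLM.from_pretrained(lm_key)
model.eval()
tokenizer = AutoTokenizer.from_pretrained(lm_key)

ctx_ids = tokenizer(" 0 1 2 3", return_tensors="pt").input_ids    # context: 0 1 2 3
target_id = tokenizer(" 4", return_tensors="pt").input_ids[0, 0]  # continuation token: 4

with torch.no_grad():
    # Feed only the context -- the continuation token itself is NOT passed in
    logits = model(ctx_ids).logits  # shape (1, len(ctx), vocab_size)

# The last position is the distribution over the *next* token,
# i.e. the prediction for 4 conditioned on 0 1 2 3
log_probs = torch.log_softmax(logits[0, -1], dim=-1)
print(log_probs[target_id].item())  # log P(4 | 0 1 2 3)

Had we appended the 4 before the forward pass, the last logit position would instead be the model's guess for whatever comes after 4, which is exactly why the harness drops the final token with inp[:, :-1].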

Leaving open until I update the documentation. If this doesn't make sense, happy to clarify further!

sasaadi commented 1 year ago

For multi-token continuations, do we only drop the last token? If the input is 0 1 2 3 and the continuation is 4 5 6, do we condition on 0 1 2 3 4 5? Thanks

sorenmulli commented 1 year ago

Thank you very much, @haileyschoelkopf, for a swift reply and for a good explanation of the indexing!

IDK why I thought that calling cross-entropy loss on the logits would magically handle this for me; this shifting is of course also implemented in decoder model losses. For completeness, I have updated my little code snippet so that it gives the same result:

import torch
import torch.nn
from transformers import AutoModelForCausalLM, AutoTokenizer
from lm_eval.models.huggingface import AutoCausalLM

lm_key = "sshleifer/tiny-gpt2"
context = "we are the"
cont = " koala bears of the world"

model = AutoModelForCausalLM.from_pretrained(lm_key)
model.eval()
tokenizer = AutoTokenizer.from_pretrained(lm_key)

encodings = tokenizer(context, text_target=cont, return_tensors="pt")
input_ids = torch.cat((encodings.input_ids, encodings.labels), dim=1)
target_ids = input_ids.clone()
# Makes context ignored by loss function
target_ids[:, : encodings.input_ids.size(1)] = -100

with torch.no_grad():
    logits = model(input_ids).logits

# Move the vocab dimension to dim 1: CrossEntropyLoss expects (N, C, ...) with classes in dim 1
logits = logits.permute(0, 2, 1)
# Task: next-token prediction => shift so that logits at position i are scored
# against the token at position i+1 (after the permute, dim 2 is the sequence dimension)
target_ids = target_ids[:, 1:]
logits = logits[:, :, :-1]
losses = torch.nn.CrossEntropyLoss(reduction="none")(logits, target_ids)
print(-losses.sum().item())
# Result: -65.07633972167969
lm_eval_model = AutoCausalLM(lm_key, device="cpu")
print(lm_eval_model.loglikelihood([(context, cont)])[0][0])
# Result: -65.07632446289062

# Same results - yay!
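
Continuing from the snippet above (same model, encodings, and input_ids), the built-in Hugging Face loss should give the same number, since the label shift happens inside the model; a small sketch that undoes the mean reduction model(..., labels=...) applies over the scored tokens:

# Rebuild the unshifted labels (target_ids was sliced above) and mask the context,
# then let the model do the shift + cross-entropy internally
labels = input_ids.clone()
labels[:, : encodings.input_ids.size(1)] = -100
with torch.no_grad():
    out = model(input_ids, labels=labels)
n_scored = (labels[:, 1:] != -100).sum()  # number of continuation tokens actually scored
print(-(out.loss * n_scored).item())
# Should reproduce the ~-65.076 value above, up to floating-point noise
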
haileyschoelkopf commented 1 year ago

Glad this is helpful!!

For multi-token continuations, do we only drop the last token? If the input is 0 1 2 3 and the continuation is 4 5 6, do we condition on 0 1 2 3 4 5? Thanks

@sasaadi yes, we would feed 0 1 2 3 4 5 into the model, which then gives us output logits of shape (seqlen, vocab_size) = (6, vocab_size). The last sequence position of these logits is the model's prediction for the token at position 6 conditioned on everything up to 5, the second-to-last position gives the prediction for 5 conditioned on 0 1 2 3 4, and so on.

So if the continuation is 4 5 6, we want:

P(4 | 0 1 2 3)
P(5 | 0 1 2 3 4)
P(6 | 0 1 2 3 4 5)

And the loglikelihood of the completion is the loglikelihood of producing all 3 completion tokens in turn, starting from 0 1 2 3. So to get the probability we would multiply the probabilities of producing each completion token, or equivalently add the log-probabilities of producing each completion token, assuming we got the previous ones right.
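
As a sketch of how those per-token log-probabilities can be gathered and summed (again only an illustration with the tiny stand-in model and the digit tokens of the running example, not the harness implementation itself):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

lm_key = "sshleifer/tiny-gpt2"  # stand-in model
model = AutoModelForCausalLM.from_pretrained(lm_key)
model.eval()
tokenizer = AutoTokenizer.from_pretrained(lm_key)

ctx_ids = tokenizer(" 0 1 2 3", return_tensors="pt").input_ids  # context: 0 1 2 3
cont_ids = tokenizer(" 4 5 6", return_tensors="pt").input_ids   # continuation: 4 5 6

# Feed context + continuation minus its last token: 0 1 2 3 4 5
inp = torch.cat((ctx_ids, cont_ids), dim=1)[:, :-1]
with torch.no_grad():
    logits = model(inp).logits  # shape (1, 6, vocab_size)

# The last len(continuation) positions predict 4, 5 and 6 in turn
cont_logits = logits[:, -cont_ids.size(1):, :]
log_probs = torch.log_softmax(cont_logits, dim=-1)
per_token = torch.gather(log_probs, 2, cont_ids.unsqueeze(-1)).squeeze(-1)
print(per_token.sum().item())  # log P(4 5 6 | 0 1 2 3)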