Hi! I'll add a further note on this to the comment and the documentation, as this is a frequent question.
The reason behind chopping off the last completion token is that autoregressive LLMs take in tokens up to position N and return a logit distribution for position N+1. Therefore, the logit the model assigns to token N is obtained by feeding in 0 1 2 3 ... (N - 1) and then taking the last logit position--this is the logit for the Nth token.
When we're feeding in 0 1 2 3 | 4, what we want is the logits predicting 4. To get the logits for 4 conditioned on 0 1 2 3, we must feed 0 1 2 3 in without passing in 4. Then, the final logits index is the predicted distribution over tokens at the 4 position, which is what we wanted! The same applies for multi-token continuations.
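For concreteness, here is a minimal sketch of that indexing (not the harness's implementation; it assumes the same sshleifer/tiny-gpt2 toy model and strings used in the snippet below): feed only the context and read the next-token distribution off the final logit position.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

lm_key = "sshleifer/tiny-gpt2"
tokenizer = AutoTokenizer.from_pretrained(lm_key)
model = AutoModelForCausalLM.from_pretrained(lm_key)
model.eval()

context_ids = tokenizer("we are the", return_tensors="pt").input_ids  # tokens 0 ... (N - 1)
first_cont_id = tokenizer(" koala").input_ids[0]                      # token N (first continuation token)

with torch.no_grad():
    logits = model(context_ids).logits               # shape (1, N, vocab)
next_token_logprobs = logits[0, -1].log_softmax(-1)  # distribution over token N
print(next_token_logprobs[first_cont_id].item())     # log-prob of the first continuation token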
Leaving open until I update the documentation. If this doesn't make sense, happy to clarify further!
For multi-token continuations, do we only drop the last token? If the input is 0 1 2 3 and the continuation is 4 5 6, do we condition on 0 1 2 3 4 5? Thanks
Thank you very much, @haileyschoelkopf, for a swift reply! And for a good explanation of the indexing.
IDK why I thought that calling cross-entropy loss on the logits would magically handle this for me; this shifting is of course also implemented in decoder model losses. For completeness, I have updated my little code snippet such that it gives the same result:
import torch
import torch.nn
from transformers import AutoModelForCausalLM, AutoTokenizer
from lm_eval.models.huggingface import AutoCausalLM
lm_key = "sshleifer/tiny-gpt2"
context = "we are the"
cont = " koala bears of the world"
model = AutoModelForCausalLM.from_pretrained(lm_key)
model.eval()
tokenizer = AutoTokenizer.from_pretrained(lm_key)
encodings = tokenizer(context, text_target=cont, return_tensors="pt")
input_ids = torch.cat((encodings.input_ids, encodings.labels), dim=1)
target_ids = input_ids.clone()
# Makes context ignored by loss function
target_ids[:, : encodings.input_ids.size(1)] = -100
with torch.no_grad():
    logits = model(input_ids).logits
# Move vocab dimension last as we do classification over these
logits = logits.permute(0, 2, 1)
# Task: Next-token-prediction => shift tokens
target_ids = target_ids[:, 1:]
logits = logits[:, :, :-1]
losses = torch.nn.CrossEntropyLoss(reduction="none")(logits, target_ids)
print(-losses.sum().item())
# Result: -65.07633972167969
lm_eval_model = AutoCausalLM(lm_key, device="cpu")
print(lm_eval_model.loglikelihood([(context, cont)])[0][0])
# Result: -65.07632446289062
# Same results - yay!
Glad this is helpful!!
For multi-token continuations, do we only drop the last token? If the input is 0 1 2 3 and the continuation is 4 5 6, do we condition on 0 1 2 3 4 5? Thanks
@sasaadi yes, we would feed 0 1 2 3 4 5 into the model, which will then give us output logits of shape (seqlen, vocabsize) = (6, vocabsize). The last sequence position of these logits is the logit for the model to predict the 6 position conditioned on everything up to 5, and the second-to-last sequence position would give the prediction for 5 conditioned on 0 1 2 3 4, and so on.
So if the continuation is 4 5 6, we want:
P(6, conditional on 0 1 2 3 4 5)
P(5, conditional on 0 1 2 3 4)
P(4, conditional on 0 1 2 3)
and we don't care about how likely the model would be to generate the input/context. And the loglikelihood of the completion is the loglikelihood of producing all 3 completion tokens in turn, starting from 0 1 2 3. So to get the probability we'd multiply the probs of producing each completion token, or add the log-probabilities of producing each completion token, assuming we got the previous ones right.
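As a rough sketch of that bookkeeping (reusing the toy model and strings from the snippet above; this mirrors the logic rather than the harness's actual batched implementation):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

lm_key = "sshleifer/tiny-gpt2"
tokenizer = AutoTokenizer.from_pretrained(lm_key)
model = AutoModelForCausalLM.from_pretrained(lm_key)
model.eval()

ctx_ids = tokenizer("we are the", return_tensors="pt").input_ids                  # "0 1 2 3"
cont_ids = tokenizer(" koala bears of the world", return_tensors="pt").input_ids  # "4 5 6"
full = torch.cat((ctx_ids, cont_ids), dim=1)

with torch.no_grad():
    logits = model(full[:, :-1]).logits  # feed "0 1 2 3 4 5": the last continuation token is dropped
logprobs = logits.log_softmax(-1)
# The last len(continuation) positions predict the continuation tokens:
# the position holding "3" predicts "4", "4" predicts "5", "5" predicts "6".
cont_logprobs = logprobs[0, -cont_ids.size(1):].gather(1, cont_ids[0].unsqueeze(1))
print(cont_logprobs.sum().item())        # log p("4 5 6" | "0 1 2 3")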
Question
In https://github.com/EleutherAI/lm-evaluation-harness/blob/3ccea2b2854dd3cc9ff5ef1772e33de21168c305/lm_eval/base.py#L342 (and refactor: https://github.com/EleutherAI/lm-evaluation-harness/blob/408115eaffc4eecc9584f543db573d708eef8ed6/lm_eval/models/huggingface.py#L708), the last input token is dropped before the model call.
This is motivated by this diagram:
I must admit that I do not understand why this is: Does anyone have some pointers as to why removing this yields correct probabilities (surely the value of the last token matters for the overall likelihood?).
Minimal Example
The computation below shows that I can reproduce the result of _loglikelihood_tokens only if I remove the [:-1]; otherwise there is a difference coming from the last token:
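(The original snippet is not reproduced here; the following is only a rough reconstruction of the kind of comparison described, assuming the same sshleifer/tiny-gpt2 toy model and strings as in the updated snippet above. It contrasts the shifted alignment, which is what the harness's [:-1] plus last-position indexing amounts to, with scoring each continuation token against the logits at its own position.)

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

lm_key = "sshleifer/tiny-gpt2"
tokenizer = AutoTokenizer.from_pretrained(lm_key)
model = AutoModelForCausalLM.from_pretrained(lm_key)
model.eval()

ctx = tokenizer("we are the", return_tensors="pt").input_ids
cont = tokenizer(" koala bears of the world", return_tensors="pt").input_ids
full = torch.cat((ctx, cont), dim=1)
K = cont.size(1)

with torch.no_grad():
    logprobs = model(full).logits.log_softmax(-1)[0]

# Shifted alignment: logits at position i - 1 score the token at position i
shifted = logprobs[-K - 1:-1].gather(1, cont[0].unsqueeze(1)).sum()
# Unshifted alignment: logits at position i score the token at position i (off by one)
unshifted = logprobs[-K:].gather(1, cont[0].unsqueeze(1)).sum()
print(shifted.item(), unshifted.item())  # the shifted sum is the one that matches the harness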
Similar issues
A similar question was asked in #337, where @jon-tow, who asked the question, closed with the message: