meta-llama / llama

Inference code for Llama models

Logits for all positions? #294

Closed: mawilson1234 closed this 1 year ago

mawilson1234 commented 1 year ago

In model.Transformer.forward, the following line says it'll only compute the logits for the last position in h:

output = self.output(h[:, -1, :])  # only compute last logits

I'm interested in getting surprisal values for each word in a sentence, so I'd like logits for every position.

It looks like I first need to fix up the inputs by converting the pad_ids to eos_id, since pad_id is -1, which doesn't have an embedding, while eos_id is 2, which does (I'm not examining the logits for the eos tokens or anything after them; it's just so I can run batches of sentences with unequal lengths).
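
Concretely, the preprocessing I have in mind is something like this (just a sketch; replace_pad_with_eos is a helper I'm making up here, tokens stands in for the batched, padded token ids, and the pad_id/eos_id defaults are the Llama tokenizer's values):

import torch

def replace_pad_with_eos(tokens: torch.Tensor, pad_id: int = -1, eos_id: int = 2):
    # pad_id (-1) has no row in the embedding table, but eos_id (2) does,
    # so swap pads for eos and keep a mask of the real (non-pad) positions.
    real_mask = tokens != pad_id
    fixed = torch.where(real_mask, tokens, torch.full_like(tokens, eos_id))
    return fixed, real_mask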

After I do this, is it as simple as changing the line above to the following to get the logits for each position for each example in the batch? Just want to make sure I'm not missing anything obvious.

output = self.output(h)
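
For context, the downstream use would be roughly this (again just a sketch; model is the Transformer with the changed line, and tokens/real_mask are the outputs of the padding fix above):

import torch.nn.functional as F

logits = model(tokens, 0)   # (batch, seq_len, vocab_size) after the change, rather than (batch, vocab_size)
log_probs = F.log_softmax(logits.float(), dim=-1)

# Logits at position i predict the token at position i + 1, so shift the
# targets by one and drop the final position's predictions before gathering.
targets = tokens[:, 1:]
surprisal = -log_probs[:, :-1, :].gather(-1, targets.unsqueeze(-1)).squeeze(-1)

# Zero out surprisals at padded positions.
surprisal = surprisal * real_mask[:, 1:].float()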
ejsd1989 commented 1 year ago

@mawilson1234 I just wanted to check in to see whether you were able to resolve your question.

ra-MANUJ-an commented 1 year ago

@ejsd1989 @mawilson1234 Were either of you able to resolve the issue?

mawilson1234 commented 1 year ago

@ra-MANUJ-an Ah, forgot to respond to @ejsd1989's comment—yes, changing the line to output = self.output(h) worked. Thanks for checking! I'll go ahead and close this.