meta-llama / llama

Inference code for Llama models

Why the mask hstack in model.py? #938

Open hscspring opened 8 months ago

hscspring commented 8 months ago

Here is the code in model.py (line 482)

# When performing key-value caching, we compute the attention scores
# only for the new sequence. Thus, the matrix of scores is of size
# (seqlen, cache_len + seqlen), and the only masked entries are (i, j) for
# j > cache_len + i, since row i corresponds to token cache_len + i.
mask = torch.hstack([
    torch.zeros((seqlen, start_pos), device=tokens.device),
    mask
]).type_as(h)

Apart from the prompt, every subsequent forward pass processes a single generated token (seqlen=1). That means this mask operation is only used for the first input (the prompt), where start_pos is always zero, so the hstack here doesn't actually do anything.

Does anyone know what effect it is meant to have?

leigao97 commented 3 months ago

I also have the same question. Here is the commit that added this code. @flu0r1ne Can you please explain more on this? Any help would be appreciated.

flu0r1ne commented 3 months ago

It allows the underlying model's KV cache to be maintained between interleaved messages. Prior to this change, the KV cache had to be recomputed between messages (though not within the auto-regressive loop). The generation code in this repository does not use this property, but the previous behavior was a bug in the wrapper for the underlying model. If you hook into the model yourself (as I was doing), you can achieve a speedup. See #899.
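
To make the two cases concrete, here is a minimal, dependency-free sketch of the mask shape described in the quoted comment. It uses plain Python lists instead of tensors so that it runs without PyTorch. The helper name `build_mask` and the example sizes are hypothetical and are not taken from model.py. The list concatenation stands in for the `torch.hstack` call, and the nested comprehension stands in for an upper-triangular `-inf` mask over the new tokens.

```python
NEG_INF = float("-inf")


def build_mask(seqlen, start_pos):
    """Additive attention mask of shape (seqlen, start_pos + seqlen).

    Row i is the query token at absolute position start_pos + i. It may
    attend to every cached position and to new positions up to itself.
    (Hypothetical helper for illustration, not the repository's code.)
    """
    # Causal part over the new tokens only: -inf strictly above the diagonal.
    causal = [
        [NEG_INF if j > i else 0.0 for j in range(seqlen)]
        for i in range(seqlen)
    ]
    # Prepend start_pos zero columns so cached keys are never masked
    # (the plain-list analogue of the hstack with torch.zeros).
    return [[0.0] * start_pos + row for row in causal]


# Prompt pass: nothing is cached, so the padding is zero columns wide
# and the result is just the ordinary causal mask.
prompt_mask = build_mask(seqlen=3, start_pos=0)

# Assumed follow-up message: 2 new tokens on top of 3 cached ones.
# Entry (i, j) is masked only when j > start_pos + i.
followup_mask = build_mask(seqlen=2, start_pos=3)
for row in followup_mask:
    print(row)
```

With `start_pos=0` the padding contributes nothing, which matches the observation in the original question. With `start_pos > 0` and `seqlen > 1`, the mask needs `start_pos + seqlen` columns to line up with the score matrix, and the zero columns keep the cached positions visible to every new token.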