meta-llama / llama3

The official Meta Llama 3 GitHub site

Getting attention weights for generated text from llama3-8b-instruct #155

Open bear96 opened 2 months ago

bear96 commented 2 months ago

Hello,

I'm trying to visualize the attention weights Llama 3 produces while it generates text, but I've run into some complications. I slightly modified the Attention class so that it also returns its scores variable (which I believe holds the attention weights, since it is multiplied with the values to produce the attention output), and I save those values in the TransformerBlock class as an attribute. I also changed this step in the forward function, h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask), to

y, self.weights = self.attention(self.attention_norm(x), start_pos, freqs_cis, mask)
h = x + y

where self.weights holds the attention weights.
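Not from the repo itself, but for reference, here is a minimal self-contained sketch of what that scores tensor holds, assuming the standard scaled-dot-product attention used in model.py; the function and tensor names are illustrative, not taken from the repo:

import math
import torch
import torch.nn.functional as F

def attention_with_weights(xq, xk, xv, mask=None):
    # xq, xk, xv: (batch, n_heads, seq_len, head_dim)
    scores = torch.matmul(xq, xk.transpose(-2, -1)) / math.sqrt(xq.size(-1))
    if mask is not None:
        scores = scores + mask  # causal mask adds -inf above the diagonal
    weights = F.softmax(scores.float(), dim=-1).type_as(xq)  # each row sums to 1
    output = torch.matmul(weights, xv)
    return output, weights  # weights: (batch, n_heads, seq_len, seq_len)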

Now, in generation.py, in the generate function (starting from line 175), I modify the for loop like this:

attention_dict = dict()
for cur_pos in range(min_prompt_len, total_len):
    logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
    # taking the last transformer block
    attention_dict[cur_pos] = self.model.layers[-1].weights.float().cpu().numpy()

    if temperature > 0:
        probs = torch.softmax(logits[:, -1] / temperature, dim=-1)
        next_token = sample_top_p(probs, top_p)
    else:
        next_token = torch.argmax(logits[:, -1], dim=-1)

    next_token = next_token.reshape(-1)
    # only replace token if prompt has already been generated
    next_token = torch.where(
        input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token
    )
    tokens[:, cur_pos] = next_token
    if logprobs:
        token_logprobs[:, prev_pos + 1 : cur_pos + 1] = -F.cross_entropy(
            input=logits.transpose(1, 2),
            target=tokens[:, prev_pos + 1 : cur_pos + 1],
            reduction="none",
            ignore_index=pad_id,
        )
    eos_reached |= (~input_text_mask[:, cur_pos]) & (
        torch.isin(next_token, stop_tokens)
    )
    prev_pos = cur_pos
    if all(eos_reached):
        break

The idea is to get the attention weights from the last transformer block for each step of token generation, so that I can go back to any generated token and see how the attention weights are distributed along the generated sequence length.
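One shape detail worth keeping in mind (my reading of the reference implementation, so treat it as an assumption): because generation uses the KV cache, every forward call after the prompt step feeds a single new token, so the stored weights should have a query length of 1 while the key length grows with cur_pos. A quick sanity check on the attention_dict from the snippet above:

for pos, weights in attention_dict.items():
    # expected roughly (batch, n_heads, prompt_len, prompt_len) for the first step
    # and (batch, n_heads, 1, pos) for every later step
    print(pos, weights.shape)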

However, the problem I am facing is that Llama 3 has 32 heads. If I average across all 32 heads to reduce dimensionality for visualization purposes and then apply a softmax to the result, the attention weights I get at every generation step have the exact same distribution: the first few tokens and the last two tokens get much higher values, while every other token gets the exact same, minuscule weight. I have also tried max pooling across heads instead of averaging, but it yielded similar results.

This doesn't seem right, because the pattern also stays the same across different prompts, which suggests something is wrong with my approach. Could you please point me in the right direction?

Thanks!

Icamd commented 2 months ago

Hi, I am also trying to output the attention weights of Llama 3. Have you tried getting the attention weights from the model itself (for example, outputs = model.generate(tokens, output_attentions=True))?

bear96 commented 2 months ago

> Hi, I am also trying to output the attention weights of Llama 3. Have you tried getting the attention weights from the model itself (for example, outputs = model.generate(tokens, output_attentions=True))?

Hi @Icamd, I think you're using the Huggingface version? I have tried the same thing you suggest, but the attention weights I get have a strange shape. Usually attention weights have the shape (batch_size, num_heads, seq_length, seq_length), but in Huggingface Llama's case I get a mismatch in the batch_size axis. My guess is that since output_attentions is not actually a parameter in the model architecture shown in this repo, Huggingface does something internally to compute the attention weights and thus returns unexpected values. I could be wrong, of course. I also get some warnings whenever I try this with Huggingface. That's why I am using the PyTorch version of this model instead.

Icamd commented 2 months ago

> Hi, I am also trying to output the attention weights of Llama 3. Have you tried getting the attention weights from the model itself (for example, outputs = model.generate(tokens, output_attentions=True))?

> Hi @Icamd, I think you're using the Huggingface version? I have tried the same thing you suggest, but the attention weights I get have a strange shape. Usually attention weights have the shape (batch_size, num_heads, seq_length, seq_length), but in Huggingface Llama's case I get a mismatch in the batch_size axis. My guess is that since output_attentions is not actually a parameter in the model architecture shown in this repo, Huggingface does something internally to compute the attention weights and thus returns unexpected values. I could be wrong, of course. I also get some warnings whenever I try this with Huggingface. That's why I am using the PyTorch version of this model instead.

I think the Huggingface version's attention weights have the shape (output_token_number, layers, batch_size, heads, input_token_number, input_token_number), for example (50, 32, 1, 32, 251, 251), but I am not sure. I am still trying to visualize the attention between tokens to find out if there is any connection. However, I get strange model outputs using NousResearch/Meta-Llama-3-8B-Instruct in Google Colab :( (screenshot attached)
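For what it's worth, in the Transformers generate API the attentions come back as nested tuples rather than one big tensor, so the shape depends on how you stack them. A minimal sketch, assuming return_dict_in_generate is enabled and model / input_ids are whatever you already loaded:

out = model.generate(
    input_ids,
    max_new_tokens=50,
    output_attentions=True,
    return_dict_in_generate=True,
)
# out.attentions: one entry per generated token; each entry is a tuple with
# one tensor per layer of shape (batch, num_heads, query_len, key_len).
# With the KV cache, query_len is the prompt length for the first step
# and 1 for every later step, while key_len keeps growing.
first_step_last_layer = out.attentions[0][-1]
print(first_step_last_layer.shape)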

Icamd commented 2 months ago

@bear96 Hi! I found the paper "Analyzing the Structure of Attention in a Transformer Language Model", which mentions something called "null attention", i.e. attention focused on the first token. Maybe you can try masking out the first token's attention so it doesn't dominate the overall attention weights? (I'm not sure.)
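If it helps, a rough way to try that idea (purely illustrative; weights is assumed to be a slice of attention probabilities whose last axis is the key positions): zero out the column for the first token and renormalize each row so it still sums to 1.

import torch

def mask_first_token(weights, eps=1e-9):
    # weights: (n_heads, query_len, key_len) attention probabilities
    w = weights.clone()
    w[..., 0] = 0.0                              # drop the "null attention" column
    w = w / (w.sum(dim=-1, keepdim=True) + eps)  # renormalize each row
    return w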

bear96 commented 2 months ago

> @bear96 Hi! I found the paper "Analyzing the Structure of Attention in a Transformer Language Model", which mentions something called "null attention", i.e. attention focused on the first token. Maybe you can try masking out the first token's attention so it doesn't dominate the overall attention weights? (I'm not sure.)

I'll definitely check that out! Thanks!

bear96 commented 2 months ago

I believe I have solved the issue. I was averaging across all 32 heads and then applying a softmax to make the result look like probabilities, but that washed out the small variations in the attention weights and left an almost uniform distribution. I'm now visualizing the attention weights for individual heads instead. Due to the null attention effect that @Icamd cited, the first token gets extremely high attention weights while the remaining weights decay roughly exponentially, so I take the log of the weights to get better visualizations.
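In case it is useful to anyone else, a per-head, log-scale plot along those lines might look like the sketch below (matplotlib-based, purely illustrative; attn and tokens are assumed to be a (n_heads, key_len) weight slice for one generation step and the decoded token strings):

import numpy as np
import matplotlib.pyplot as plt

def plot_head_attention(attn, tokens, eps=1e-9):
    # attn: (n_heads, key_len) attention weights for a single generated token
    log_attn = np.log(attn + eps)  # log scale keeps small differences visible
    fig, ax = plt.subplots(figsize=(12, 6))
    im = ax.imshow(log_attn, aspect="auto", cmap="viridis")
    ax.set_xticks(range(len(tokens)))
    ax.set_xticklabels(tokens, rotation=90, fontsize=6)
    ax.set_ylabel("head")
    fig.colorbar(im, label="log attention weight")
    plt.tight_layout()
    plt.show()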

I am still not sure why null attention occurs, though. If someone knows more about this, please let me know!