huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Attention dropout causing problem in attention score distribution #31468

Open RicRicci22 opened 2 months ago

RicRicci22 commented 2 months ago

System Info

Transformers version: 4.41.2
Platform: Ubuntu 22.04.4 LTS
Python: 3.10.14

Who can help?

@younesbelkada @ArthurZucker

Reproduction

When using GPT-2, dropout is applied during training to the attention weights computed in each transformer layer (the softmax of the scaled Q@K^T scores). The problem is that dropout also rescales the surviving weights by the inverse of the keep probability, as stated in the PyTorch docs:

Furthermore, the outputs are scaled by a factor of 1/(1-p) during training.

As a result, the elements of each row no longer necessarily sum to 1 (which they do right after the softmax). Per se I don't think this is a major problem, but during inference dropout is disabled, so every row of the attention matrix does sum to 1, and the model at inference time is therefore always slightly out of distribution with respect to what it saw during training.
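
The effect can already be seen with a plain softmax/dropout pair, before even touching GPT-2 (a minimal, self-contained sketch):

import torch
import torch.nn as nn

torch.manual_seed(0)
dropout = nn.Dropout(p=0.1)
dropout.train()  # training mode: some entries are zeroed, the rest are scaled by 1/(1-p)

probs = torch.softmax(torch.rand(1, 8), dim=-1)
print(probs.sum())           # 1.0 right after the softmax
print(dropout(probs).sum())  # generally != 1.0 in training mode

dropout.eval()               # inference mode: dropout is a no-op
print(dropout(probs).sum())  # 1.0 again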

Do you think we need another normalization after the dropout? Here is an example script showing the behavior when the module is in training vs. evaluation mode:

import torch
import torch.nn as nn
from transformers import GPT2Config

class GPT2Attention(nn.Module):
    def __init__(self, config, is_cross_attention=False, layer_idx=None):
        super().__init__()
        self.config = config
        max_positions = config.max_position_embeddings
        self.register_buffer(
            "bias",
            torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
                1, 1, max_positions, max_positions
            ),
            persistent=False,
        )
        self.register_buffer("masked_bias", torch.tensor(-1e4), persistent=False)

        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        self.split_size = self.embed_dim
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )

        self.scale_attn_weights = config.scale_attn_weights
        self.is_cross_attention = is_cross_attention

        # Layer-wise attention scaling, reordering, and upcasting
        self.scale_attn_by_inverse_layer_idx = config.scale_attn_by_inverse_layer_idx
        self.layer_idx = layer_idx
        self.reorder_and_upcast_attn = config.reorder_and_upcast_attn

        # if self.is_cross_attention:
        #     self.c_attn = Conv1D(2 * self.embed_dim, self.embed_dim)
        #     self.q_attn = Conv1D(self.embed_dim, self.embed_dim)
        # else:
        #     self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim)
        # self.c_proj = Conv1D(self.embed_dim, self.embed_dim)

        self.attn_dropout = nn.Dropout(config.attn_pdrop)
        self.resid_dropout = nn.Dropout(config.resid_pdrop)
        self.is_causal = True

        self.pruned_heads = set()

    # def prune_heads(self, heads):
    #     if len(heads) == 0:
    #         return
    #     heads, index = find_pruneable_heads_and_indices(heads, self.num_heads, self.head_dim, self.pruned_heads)
    #     index_attn = torch.cat([index, index + self.split_size, index + (2 * self.split_size)])

    #     # Prune conv1d layers
    #     self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1)
    #     self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0)

    #     # Update hyper params
    #     self.split_size = (self.split_size // self.num_heads) * (self.num_heads - len(heads))
    #     self.num_heads = self.num_heads - len(heads)
    #     self.pruned_heads = self.pruned_heads.union(heads)

    def _attn(self, query, key, value, attention_mask=None, head_mask=None, correct_normalization=False):
        attn_weights = torch.matmul(query, key.transpose(-1, -2))

        if self.scale_attn_weights:
            attn_weights = attn_weights / torch.full(
                [], value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device
            )

        # Layer-wise attention scaling
        if self.scale_attn_by_inverse_layer_idx:
            attn_weights = attn_weights / float(self.layer_idx + 1)

        if not self.is_cross_attention:
            # if only "normal" attention layer implements causal mask
            query_length, key_length = query.size(-2), key.size(-2)
            causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
            mask_value = torch.finfo(attn_weights.dtype).min
            # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
            # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
            mask_value = torch.full([], mask_value, dtype=attn_weights.dtype, device=attn_weights.device)
            attn_weights = torch.where(causal_mask, attn_weights.to(attn_weights.dtype), mask_value)

        if attention_mask is not None:
            # Apply the attention mask
            attn_weights = attn_weights + attention_mask

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise
        attn_weights = attn_weights.type(value.dtype)
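        # NOTE: in training mode nn.Dropout zeroes some weights and rescales the rest by 1/(1 - p),
        # so after this call each row of attn_weights no longer sums to 1.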
        attn_weights = self.attn_dropout(attn_weights)
        # Mask heads if we want to
        if head_mask is not None:
            attn_weights = attn_weights * head_mask

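        # Proposed correction: renormalize each row so it sums to 1 again after dropout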
        if correct_normalization:
            attn_weights = attn_weights / (torch.sum(attn_weights, dim=-1, keepdim=True) + 1e-10)

        print(attn_weights[0,0,0,:])
        print(torch.sum(attn_weights[0,0,0,:]))
        attn_output = torch.matmul(attn_weights, value)

        return attn_output, attn_weights

def test_behavior(train=True, correct_normalization=False):
    config = GPT2Config()

    attention_layer = GPT2Attention(config)
    if not train:
        attention_layer.eval()

    dummy_query = torch.rand((2,12,8,8))
    dummy_key = torch.rand((2,12,8,8))
    dummy_value = torch.rand((2,12,8,8))
    _ = attention_layer._attn(dummy_query, dummy_key, dummy_value, correct_normalization=correct_normalization)

if __name__=="__main__":
    test_behavior(train=False, correct_normalization=False)

Here you can set the argument "train" to True or False to simulate training or evaluation mode. I also added a correction that you can switch on and off with the "correct_normalization" argument.

You can see that during training, without the correction, the sum is essentially never 1, while in evaluation mode the sum is 1. This is what I meant when I said that at inference time the model is always slightly out of distribution.

In this example I'm using the eager implementation of attention; I don't know whether the behavior is the same when using FlashAttention.

Looking forward to hearing from you.

Expected behavior

The expected behavior is that during both training and evaluation, each row of the attention matrix sums to 1 (even when using attention dropout).

ArthurZucker commented 1 month ago

Hey! Sorry, not sure I understand: dropout should not be active with model.eval(), no?

RicRicci22 commented 1 month ago

Hi @ArthurZucker! Yes, that is correct. What I was pointing out is that during training the dropout acting on the attention weights also rescales them by the inverse of the keep probability. For a standard layer, I understand that this keeps the output magnitude similar between training and evaluation.

However, in an attention layer this causes the attention weights to no longer sum to 1. During inference, since dropout is disabled, the attention weights do sum to 1, so there is a discrepancy between train and test that I think can cause some trouble.

It is as if the network is always making inferences on slightly out-of-distribution samples.

Not sure if I explained it better now!

ArthurZucker commented 1 month ago

That's for sure, but all models are trained that way 😄 I had never thought about this, but dropout in general would then be bad for inference. Feel free to do some benchmarks, I am curious!

dhruvbird commented 1 month ago

Yes, I have also noticed this problem and put together a notebook to demonstrate what is happening. https://colab.research.google.com/drive/10f5pqC4XO5grmP1soT-Yh12-JOFg_i3w?usp=sharing

Due to dropout's behaviour in training, it scales up the softmax outputs. This causes each row of probabilities to sum to something less than or greater than 1.0, rather than exactly 1.0, during training, whereas at test/inference time this behaviour is not seen because dropout becomes a no-op, so all probabilities add up to 1.0. I think this might be a problem, but I haven't seen it addressed systematically anywhere. This is not new and has been discussed before in the PyTorch issue tracker: https://github.com/pytorch/pytorch/issues/42929

The way I'd solve this is to apply dropout before running the softmax, so that after the softmax the probabilities add up to 1.0.
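
One way to read that suggestion (just a sketch, not anything currently in transformers; the function name and shapes below are made up for illustration): instead of applying nn.Dropout to the softmax output, "drop" individual attention logits by masking them to a very negative value before the softmax, so the softmax itself renormalizes whatever survives and every row still sums to 1.

import torch

def attn_probs_with_pre_softmax_dropout(scores, p=0.1, training=True):
    # scores: raw attention logits, shape (batch, heads, q_len, k_len)
    if training and p > 0.0:
        # "Drop" logits by pushing them to the most negative representable value,
        # so the softmax assigns them ~0 weight and renormalizes the remaining ones.
        drop_mask = torch.rand_like(scores) < p
        scores = scores.masked_fill(drop_mask, torch.finfo(scores.dtype).min)
    return torch.softmax(scores, dim=-1)

probs = attn_probs_with_pre_softmax_dropout(torch.randn(2, 12, 8, 8), p=0.1, training=True)
print(probs.sum(dim=-1))  # every row sums to 1 even with dropout active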

ArthurZucker commented 1 month ago

TBH if you want you can open a PR to see if this improves the performance of, let's say, Llama 3 on MMLU, for example! That would be relevant to say whether or not this has a potential impact!

dhruvbird commented 1 month ago

TBH if you want you can open a PR to see if this improves the performance of, let's say, Llama 3 on MMLU, for example! That would be relevant to say whether or not this has a potential impact!

Just changing the code and running inference probably won't help, and will most likely make things worse, since the model was trained in a specific way and inference should try to keep that the same. In my mind the only(?) way to actually test this theory is to train two models and compare them on specific benchmarks. I lack the GPU resources to do so, though. @ArthurZucker, if you have some resources, I'm happy to send a PR and you could help me validate this?