idiap / fast-transformers

PyTorch library for fast transformer implementations

Tips and tricks for training linear_att #84

gaceladri closed this issue 3 years ago

gaceladri commented 3 years ago

Hello,

I have ported your linear_attention.py to be compatible with HuggingFace. I have also modified the masking part to use the LengthMask.

The problem is that the model is very brittle and tends to diverge. It is very sensitive to hyperparameters and initialization.

Do you have any tips and tricks for training linear attention?

Thanks!

```python
import torch
from torch import nn
from fast_transformers.feature_maps import elu_feature_map


class LinearAttention(nn.Module):
    """Implement unmasked attention using dot product of feature maps in
    O(N D^2) complexity.
    Given the queries, keys and values as Q, K, V instead of computing
        V' = softmax(Q.mm(K.t()), dim=-1).mm(V),
    we make use of a feature map function Φ(.) and perform the following
    computation
        V' = normalize(Φ(Q).mm(Φ(K).t())).mm(V).
    The above can be computed in O(N D^2) complexity where D is the
    dimensionality of Q, K and V and N is the sequence length. Depending on the
    feature map, however, the complexity of the attention might be limited.
    Arguments
    ---------
        config: model config providing true_hidden_size, hidden_size,
                num_attention_heads and use_bottleneck_attention
        feature_map: callable, a factory that takes the per-head dimensionality
                     and returns the feature map applied to the last dimension
                     of a tensor (default: elu(x)+1)
        eps: float, a small number to ensure the numerical stability of the
             denominator (default: 1e-4)
    """

    def __init__(self, config, feature_map=None, eps=1e-4):
        super(LinearAttention, self).__init__()
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.true_hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        # The feature map is applied to (N, L, H, head_size) tensors, so it is
        # built for the per-head size rather than the full hidden size
        self.feature_map = (
            feature_map(self.attention_head_size) if feature_map else
            elu_feature_map(self.attention_head_size)
        )

        self.eps = eps
        self.query_projection = nn.Linear(config.true_hidden_size, self.all_head_size)
        self.key_projection = nn.Linear(config.true_hidden_size, self.all_head_size)
        self.value_projection = nn.Linear(config.true_hidden_size if config.use_bottleneck_attention else config.hidden_size, self.all_head_size)
        self.out_projection = nn.Linear(config.true_hidden_size, config.true_hidden_size)
        self.n_heads = config.num_attention_heads

    def transpose_for_scores(self, x):
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self, queries, keys, values, attn_mask, query_lengths,
                key_lengths):
        # attn_mask is a fast-transformers FullMask, key_lengths a LengthMask
        N, L, _ = queries.shape
        _, S, _ = keys.shape
        H = self.n_heads

        # Project the queries/keys/values
        queries = self.query_projection(queries).view(N, L, H, -1)
        keys = self.key_projection(keys).view(N, S, H, -1)
        values = self.value_projection(values).view(N, S, H, -1)

        # Apply the feature map to the queries and keys; new_feature_map is a
        # no-op for elu(x)+1 but lets random-feature maps draw fresh features
        self.feature_map.new_feature_map(queries.device)
        Q = self.feature_map.forward_queries(queries)
        K = self.feature_map.forward_keys(keys)

        # Apply the key padding mask and make sure that the attn_mask is
        # all_ones
        if not attn_mask.all_ones:
            raise RuntimeError(("LinearAttention does not support arbitrary "
                                "attention masks"))
        K = K * key_lengths.float_matrix[:, :, None, None]

        # Compute the KV matrix, namely the dot product of keys and values so
        # that we never explicitly compute the attention matrix and thus
        # decrease the complexity
        KV = torch.einsum("nshd,nshm->nhmd", K, values)

        # Compute the normalizer
        Z = 1 / (torch.einsum("nlhd,nhd->nlh", Q, K.sum(dim=1)) + self.eps)

        # Finally compute and return the new values
        V = torch.einsum("nlhd,nhmd,nlh->nlhm", Q, KV, Z).contiguous().view(N, L, -1)

        return self.out_projection(V)
```
```python
    # Needs: import torch; from fast_transformers.masking import FullMask, LengthMask
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        output_hidden_states=None,
        output_attentions=None,
        return_dict=None,
        output_layers=None,
        regression=False,
    ):
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        device = input_ids.device if input_ids is not None else inputs_embeds.device

        N = input_shape[0]
        L = input_shape[1]
        if input_ids is not None:
            x = input_ids
        elif inputs_embeds is not None:
            x = inputs_embeds
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")
        # All-ones attention mask: LinearAttention raises unless attn_mask.all_ones
        extended_attention_mask = FullMask(L, device=x.device)
        # extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 # (?)
        # Every sample is given length L here, so padding marked in
        # attention_mask is not removed from the keys
        head_mask = LengthMask(x.new_full((N,), L, dtype=torch.int64))
```
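As a sanity check of the complexity argument in the docstring, here is a minimal pure-PyTorch sketch (it does not use fast-transformers; `phi`, `quadratic_attention` and `linear_attention` are hypothetical helper names). It computes normalize(Φ(Q)Φ(K)ᵀ)V once by materializing the L x S matrix and once with the same einsums as the forward pass above; only the order of the products differs, so the results should agree.

```python
import torch
import torch.nn.functional as F


def phi(x):
    # Default feature map from the docstring: elu(x) + 1 (non-negative)
    return F.elu(x) + 1


def quadratic_attention(Q, K, V, eps=1e-6):
    # O(N^2 D): build the (L x S) similarity matrix for every head
    A = torch.einsum("nlhd,nshd->nhls", Q, K)
    A = A / (A.sum(dim=-1, keepdim=True) + eps)  # row-normalize
    return torch.einsum("nhls,nshm->nlhm", A, V)


def linear_attention(Q, K, V, eps=1e-6):
    # O(N D^2): same einsums as LinearAttention.forward, no L x S matrix
    KV = torch.einsum("nshd,nshm->nhmd", K, V)
    Z = 1 / (torch.einsum("nlhd,nhd->nlh", Q, K.sum(dim=1)) + eps)
    return torch.einsum("nlhd,nhmd,nlh->nlhm", Q, KV, Z)


torch.manual_seed(0)
N, L, S, H, D, M = 2, 7, 7, 3, 4, 5  # batch, queries, keys, heads, feature dim, value dim
Q = phi(torch.randn(N, L, H, D, dtype=torch.float64))
K = phi(torch.randn(N, S, H, D, dtype=torch.float64))
V = torch.randn(N, S, H, M, dtype=torch.float64)

out_quadratic = quadratic_attention(Q, K, V)
out_linear = linear_attention(Q, K, V)
print(torch.allclose(out_quadratic, out_linear))  # True
```

The same kind of comparison, with identical weights and inputs, is a cheap way to check a port against a reference implementation.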
angeloskath commented 3 years ago

Hi,

I am not sure I follow your modifications regarding masking. For instance, we cannot apply an N x N mask with linear attention, because the attention matrix is never computed explicitly. Also, what is the head mask in this case? Have you checked that, given the same weights and the same inputs, your HuggingFace linear attention and ours return the same result?
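To illustrate the point about masks, here is a minimal pure-PyTorch sketch (hypothetical helper names, not the fast-transformers API). A per-sample key-length mask zeroes whole columns of the implicit attention matrix, and since (Φ(q_l)·Φ(k_s))·m_s = Φ(q_l)·(m_s Φ(k_s)), it can be folded into K before the KV product, which is what `K * key_lengths.float_matrix[:, :, None, None]` does. A general N x N mask m_ls depends on both l and s, so it cannot be pushed into Q or K alone and the shared KV sum no longer applies.

```python
import torch
import torch.nn.functional as F


def phi(x):
    return F.elu(x) + 1  # elu(x) + 1 feature map


def key_mask_from_lengths(lengths, max_len, dtype):
    # (batch, max_len): 1 for real keys, 0 for padded positions
    return (torch.arange(max_len)[None, :] < lengths[:, None]).to(dtype)


def masked_quadratic(Q, K, V, key_mask, eps=1e-6):
    # Reference: build the full matrix, zero the padded key columns, normalize
    A = torch.einsum("nlhd,nshd->nhls", Q, K) * key_mask[:, None, None, :]
    A = A / (A.sum(dim=-1, keepdim=True) + eps)
    return torch.einsum("nhls,nshm->nlhm", A, V)


def masked_linear(Q, K, V, key_mask, eps=1e-6):
    # Fold the per-key mask into K, then use the O(N D^2) formulation
    K = K * key_mask[:, :, None, None]
    KV = torch.einsum("nshd,nshm->nhmd", K, V)
    Z = 1 / (torch.einsum("nlhd,nhd->nlh", Q, K.sum(dim=1)) + eps)
    return torch.einsum("nlhd,nhmd,nlh->nlhm", Q, KV, Z)


torch.manual_seed(0)
N, L, H, D, M = 2, 6, 2, 4, 3
Q = phi(torch.randn(N, L, H, D, dtype=torch.float64))
K = phi(torch.randn(N, L, H, D, dtype=torch.float64))
V = torch.randn(N, L, H, M, dtype=torch.float64)
lengths = torch.tensor([6, 4])
key_mask = key_mask_from_lengths(lengths, L, Q.dtype)

out = masked_linear(Q, K, V, key_mask)
print(torch.allclose(out, masked_quadratic(Q, K, V, key_mask)))  # True

# Values at padded key positions must not leak into the output
V_noisy = V.clone()
V_noisy[1, 4:] = 1e3
print(torch.allclose(out, masked_linear(Q, K, V_noisy, key_mask)))  # True
```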

In general, there are no special tricks used for training, and I have not seen any kind of instability. Could you provide more information? For example, the sequence length, the query size and possibly the range of values of the normalizer Z? As with any transformer, you could also use learning rate warmup and gradient clipping, but I wouldn't say I have experienced divergence otherwise.
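Both suggestions are cheap to try. The sketch below is a minimal, hypothetical example (not code from this repository): `normalizer_range` reports the smallest and largest value of the denominator Φ(Q)·ΣΦ(K) before `eps` is added, i.e. the quantity whose inverse is Z in the forward pass, and `train` shows linear learning-rate warmup with `torch.optim.lr_scheduler.LambdaLR` plus gradient clipping with `torch.nn.utils.clip_grad_norm_`. The `nn.Linear` model, the random data and the hyperparameters (100 warmup steps, max norm 1.0) are placeholders.

```python
import torch
import torch.nn.functional as F
from torch import nn


def normalizer_range(Q, K):
    # Denominator of linear attention before eps; Z = 1 / (denom + eps).
    # Values close to zero (or negative) make Z blow up.
    denom = torch.einsum("nlhd,nhd->nlh", Q, K.sum(dim=1))
    return denom.min().item(), denom.max().item()


def linear_warmup(warmup_steps):
    # Multiplier on the base learning rate: ramps linearly to 1, then stays at 1
    return lambda step: min(1.0, (step + 1) / warmup_steps)


def train(model, steps, base_lr=1e-3, warmup_steps=100, max_grad_norm=1.0):
    opt = torch.optim.AdamW(model.parameters(), lr=base_lr)
    sched = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda=linear_warmup(warmup_steps))
    grad_norms = []
    for _ in range(steps):
        x, y = torch.randn(16, 8), torch.randn(16, 1)  # placeholder batch
        loss = F.mse_loss(model(x), y)
        opt.zero_grad()
        loss.backward()
        # clip_grad_norm_ returns the total norm before clipping, useful for logging spikes
        grad_norms.append(nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm).item())
        opt.step()
        sched.step()
    return opt.param_groups[0]["lr"], grad_norms


torch.manual_seed(0)
Q = F.elu(torch.randn(2, 10, 4, 8)) + 1
K = F.elu(torch.randn(2, 10, 4, 8)) + 1
print(normalizer_range(Q, K))
lr, norms = train(nn.Linear(8, 1), steps=5)
print(lr)  # approximately 6e-05 (= 1e-3 * 6 / 100 after 5 scheduler steps)
```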

Cheers, Angelos

gaceladri commented 3 years ago

I took the masking from what I thought you do when using linear attention:


```python
N = input_shape[0]
L = input_shape[1]
if input_ids is not None:
    x = input_ids
elif inputs_embeds is not None:
    x = inputs_embeds
else:
    raise ValueError("You have to specify either input_ids or inputs_embeds")
extended_attention_mask = FullMask(L, device=x.device)
# extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 # (?)
head_mask = LengthMask(x.new_full((N,), L, dtype=torch.int64))
```

I assumed that head_mask should be your LengthMask class. I took it from somewhere in your code, thinking that this was how it should be done. Is that not the case?

Edit: I am modifying this model from HuggingFace: https://github.com/huggingface/transformers/blob/6e1ee47b361f9dc1b6da0104d77b38297042efae/src/transformers/models/mobilebert/modeling_mobilebert.py#L875

Edit 2: Sorry, I copied my code as is. I am going to remove these comments from my first post, because they refer to the original HuggingFace masking.

angeloskath commented 3 years ago

Possibly :-). Maybe the name confused me. The masking is either on the attention matrix (which, for linear attention, must be all ones) or per sample, which I consider to be the length of each sequence, namely how many keys each sample has (a batch_size x sequence_length mask).

If the head mask is simply passed as key_lengths, then it should be fine.
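Note that the snippet in the first post builds `LengthMask(x.new_full((N,), L, dtype=torch.int64))`, i.e. it gives every sample the full length L, so any padding marked in HuggingFace's `attention_mask` is not removed from the keys. A minimal sketch, assuming right-padded batches and a 0/1 `attention_mask` of shape (batch_size, sequence_length), of deriving per-sample lengths instead (the helper names are hypothetical):

```python
import torch


def lengths_from_attention_mask(attention_mask):
    # attention_mask: (batch, seq_len) with 1 for real tokens, 0 for padding.
    # With right padding, the number of ones is each sequence's length.
    return attention_mask.long().sum(dim=1)


def length_float_matrix(lengths, max_len):
    # (batch, max_len) float matrix with ones in the first `length` positions,
    # the shape that the forward pass multiplies into K via float_matrix
    positions = torch.arange(max_len, device=lengths.device)
    return (positions[None, :] < lengths[:, None]).float()


attention_mask = torch.tensor([[1, 1, 1, 1],
                               [1, 1, 0, 0]])
lengths = lengths_from_attention_mask(attention_mask)
print(lengths.tolist())  # [4, 2]
print(torch.equal(length_float_matrix(lengths, 4), attention_mask.float()))  # True
```

The resulting `lengths` tensor could then be wrapped in `LengthMask` the same way the snippet wraps its all-L tensor. With left padding, lengths alone cannot describe the mask.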

gaceladri commented 3 years ago

By the way, are you aware of any linear attention implementation in HuggingFace?

gaceladri commented 3 years ago

I definitely have something weird going on:

[Screenshot from 2021-04-15 16-29-57]

I modified my linear attention with Fixup to remove LayerNorm. I need to check what is happening...

gaceladri commented 3 years ago

I think I had to fix my Fixup xD

[Screenshot from 2021-04-15 19-28-10]

Now it is more consistent, but it is falling behind the original MobileBERT. I am using selu instead of elu; I will try a few of the other activations. Thanks for your support. Also, if you're aware of any other implementation to check against, please let me know.

Closing as it is not an issue.

Best, Adrian

Edit: with celu it seems much better.

[Screenshot from 2021-04-15 20-19-53]
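One property that may be worth checking when swapping activations, assuming they are used as Φ(x) = act(x) + 1 like the default elu(x) + 1: the normalizer is only guaranteed to stay non-negative if Φ is non-negative. The pure-PyTorch sketch below prints the minimum of each candidate over a range of inputs. selu saturates at about -1.758 for very negative inputs, so selu(x) + 1 can go negative, while celu with its default alpha = 1 is identical to elu, so celu(x) + 1, like elu(x) + 1, never drops below zero.

```python
import torch
import torch.nn.functional as F

x = torch.linspace(-20.0, 20.0, steps=4001)

feature_maps = {
    "elu(x) + 1": F.elu(x) + 1,
    "celu(x) + 1": F.celu(x) + 1,  # alpha defaults to 1.0
    "selu(x) + 1": F.selu(x) + 1,
}
for name, phi in feature_maps.items():
    print(f"{name:12s} min = {phi.min().item():+.4f}")

# With alpha = 1, celu(x) = max(0, x) + min(0, exp(x) - 1), the same as elu(x)
print(torch.allclose(F.celu(x), F.elu(x)))  # True
```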

angeloskath commented 3 years ago

Good to know! :-)

Regarding falling behind, that could simply be due to linear attention itself. Since the attention matrix is now low rank, learning is bound to be harder. The whole point is the tradeoff between model performance and wall-clock time.
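To make the low-rank point concrete: with a feature map of dimension D, the implicit matrix Φ(Q)Φ(K)ᵀ, and therefore its row-normalized version, has rank at most D, whereas the softmax attention matrix is generally full rank. A minimal pure-PyTorch sketch with arbitrary illustrative sizes (N = 64, D = 16) and random inputs:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
n, d = 64, 16  # sequence length and per-head dimension (illustrative)
q = torch.randn(n, d, dtype=torch.float64)
k = torch.randn(n, d, dtype=torch.float64)

# Linear attention matrix, built explicitly only for this check
sim = (F.elu(q) + 1) @ (F.elu(k) + 1).T
linear_attn = sim / sim.sum(dim=-1, keepdim=True)

# Standard scaled softmax attention matrix
softmax_attn = torch.softmax(q @ k.T / d ** 0.5, dim=-1)

linear_rank = torch.linalg.matrix_rank(linear_attn).item()
softmax_rank = torch.linalg.matrix_rank(softmax_attn).item()
print(linear_rank, softmax_rank)  # linear_rank <= 16; softmax_rank is typically 64
```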

If your sequences are long, then linear attention is going to be significantly faster, and since the performance difference after 20k steps is minuscule, linear is the better choice.

(I was replying when I saw your edit.) The fact that celu works better could be interesting; however, if linear attention is not faster than softmax, then there is little point in using it. So the bottom line is that unless you care about speed or memory, you are probably better off with softmax. If you do care, however (e.g. you are processing sequences 10k tokens long, or you only have an RTX 2060, or ...), then a small performance drop is to be expected. You can always increase the number of layers or heads to compensate.
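A back-of-the-envelope sketch of where the speed and memory savings come from, per head and per sample, ignoring lower-order terms such as the normalizer, the projections and constant factors (the function name and the head dimension of 64 are illustrative assumptions). Softmax attention does about 2·N²·D multiply-adds and stores an N x N matrix, while linear attention does about 2·N·D² and only keeps a D x D summary, so the compute ratio is roughly N / D.

```python
def attention_cost(n, d, bytes_per_elem=4):
    """Rough per-head, per-sample cost of one attention call (illustrative).

    softmax: Q K^T and A V are n*n*d multiply-adds each; stores an n x n matrix.
    linear:  K^T V and Q (K^T V) are n*d*d each; stores a d x d matrix.
    """
    return {
        "softmax_macs": 2 * n * n * d,
        "linear_macs": 2 * n * d * d,
        "softmax_matrix_bytes": n * n * bytes_per_elem,
        "linear_matrix_bytes": d * d * bytes_per_elem,
    }


cost = attention_cost(n=10_000, d=64)
print(cost["softmax_macs"] / cost["linear_macs"])  # 156.25, i.e. n / d
print(cost["softmax_matrix_bytes"] / 1e6, "MB vs", cost["linear_matrix_bytes"] / 1e6, "MB")
```

For short sequences, where N is close to or below D, the advantage disappears, which is why softmax remains the default choice unless speed or memory matter.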

Let me know if I can help in any way.

Cheers, Angelos