lucidrains / x-transformers

A concise but complete full-attention transformer with a set of promising experimental features from various papers
MIT License

[Question!] How to Inject Rotary Positional Embeddings in Linear Transformers #34

Open gaceladri opened 3 years ago

gaceladri commented 3 years ago

Hello Phil,

Would you mind showing how to inject the rotary positional embeddings into linear transformers? For reference, here is the LinearAttention module from fast-transformers:

import torch
from torch.nn import Module

from ..attention_registry import AttentionRegistry, Optional, Callable, Int, \
    EventDispatcherInstance
from ..events import EventDispatcher
from ..feature_maps import elu_feature_map

class LinearAttention(Module):
    """Implement unmasked attention using dot product of feature maps in
    O(N D^2) complexity.
    Given the queries, keys and values as Q, K, V instead of computing
        V' = softmax(Q.mm(K.t()), dim=-1).mm(V),
    we make use of a feature map function Φ(.) and perform the following
    computation
        V' = normalize(Φ(Q).mm(Φ(K).t())).mm(V).
    The above can be computed in O(N D^2) complexity where D is the
    dimensionality of Q, K and V and N is the sequence length. Depending on the
    feature map, however, the complexity of the attention might be limited.
    Arguments
    ---------
        feature_map: callable, a callable that applies the feature map to the
                     last dimension of a tensor (default: elu(x)+1)
        eps: float, a small number to ensure the numerical stability of the
             denominator (default: 1e-6)
        event_dispatcher: str or EventDispatcher instance to be used by this
                          module for dispatching events (default: the default
                          global dispatcher)
    """
    def __init__(self, query_dimensions, feature_map=None, eps=1e-6,
                 event_dispatcher=""):
        super(LinearAttention, self).__init__()
        self.feature_map = (
            feature_map(query_dimensions) if feature_map else
            elu_feature_map(query_dimensions)
        )
        self.eps = eps
        self.event_dispatcher = EventDispatcher.get(event_dispatcher)

    def forward(self, queries, keys, values, attn_mask, query_lengths,
                key_lengths):
        # Apply the feature map to the queries and keys
        self.feature_map.new_feature_map(queries.device)
        Q = self.feature_map.forward_queries(queries)
        K = self.feature_map.forward_keys(keys)

        # Apply the key padding mask and make sure that the attn_mask is
        # all_ones
        if not attn_mask.all_ones:
            raise RuntimeError(("LinearAttention does not support arbitrary "
                                "attention masks"))
        K = K * key_lengths.float_matrix[:, :, None, None]

        # Compute the KV matrix, namely the dot product of keys and values so
        # that we never explicitly compute the attention matrix and thus
        # decrease the complexity
        KV = torch.einsum("nshd,nshm->nhmd", K, values)

        # Compute the normalizer
        Z = 1/(torch.einsum("nlhd,nhd->nlh", Q, K.sum(dim=1))+self.eps)

        # Finally compute and return the new values
        V = torch.einsum("nlhd,nhmd,nlh->nlhm", Q, KV, Z)

        return V.contiguous()

Thanks!

lucidrains commented 3 years ago

@gaceladri Hi again! I have it built into Performers; you can try it out there and use it as an example: https://github.com/lucidrains/performer-pytorch
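
For anyone landing here later, a rough usage sketch along the lines of the performer-pytorch README; the rotary_position_emb flag name below is an assumption from memory, so check that repo's README for the exact argument.

import torch
from performer_pytorch import PerformerLM

model = PerformerLM(
    num_tokens = 20000,
    max_seq_len = 2048,
    dim = 512,
    depth = 6,
    heads = 8,
    causal = True,
    rotary_position_emb = True   # assumed flag name for enabling rotary embeddings
)

x = torch.randint(0, 20000, (1, 2048))
logits = model(x)   # (1, 2048, 20000)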

lucidrains commented 3 years ago

you would simply apply the rotary embeddings right after

# existing feature map calls in LinearAttention.forward
Q = self.feature_map.forward_queries(queries)
K = self.feature_map.forward_keys(keys)

# new: rotate the featurized queries and keys
Q, K = apply_rot_emb(Q, K, sinu_emb)

the sinusoidal embeddings must be calculated once up front, then passed to each attention block
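
A minimal sketch of what those helpers could look like, following the snippet above. The names get_sinu_emb and rotate_half, and this particular implementation of apply_rot_emb, are illustrative rather than the exact helpers shipped in performer-pytorch; the tensor layout matches the (batch, seq, heads, dim_head) einsums of the LinearAttention code.

import torch

def get_sinu_emb(seq_len, dim_head, device):
    # rotary frequencies as in the RoFormer paper; dim_head (per-head dim) must be even
    inv_freq = 1.0 / (10000 ** (torch.arange(0, dim_head, 2, device = device).float() / dim_head))
    t = torch.arange(seq_len, device = device).float()
    freqs = torch.einsum('i,j->ij', t, inv_freq)   # (seq_len, dim_head / 2)
    return torch.cat((freqs, freqs), dim = -1)     # (seq_len, dim_head)

def rotate_half(x):
    # pairs the first and second halves of the feature dimension for the rotation
    x1, x2 = x.chunk(2, dim = -1)
    return torch.cat((-x2, x1), dim = -1)

def apply_rot_emb(q, k, sinu_emb):
    # q, k: (batch, seq, heads, dim_head); assumes q and k share the same
    # sequence length as sinu_emb (self-attention), otherwise slice sinu_emb
    sin, cos = sinu_emb.sin(), sinu_emb.cos()
    sin, cos = map(lambda t: t[None, :, None, :], (sin, cos))  # broadcast over batch and heads
    q = q * cos + rotate_half(q) * sin
    k = k * cos + rotate_half(k) * sin
    return q, k

Per the comment above, get_sinu_emb would be called once at the top of the model for the full sequence length, and the result handed to every attention block's forward.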

gaceladri commented 3 years ago

Amazing! You are amazing! Thanks a lot, I will try it!!

lucidrains commented 3 years ago

@gaceladri make sure to turn off absolute positional embeddings when you try it! it conflicts with rotary for some unknown reason - more research needed
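
Concretely, that means the input embedding step carries only token embeddings, with no learned or sinusoidal absolute position table added on top; position information then comes solely from the rotary embeddings inside attention. A minimal hypothetical sketch (TokenOnlyEmbedding is illustrative, not from any of the libraries mentioned):

import torch.nn as nn

class TokenOnlyEmbedding(nn.Module):
    def __init__(self, num_tokens, dim):
        super().__init__()
        self.token_emb = nn.Embedding(num_tokens, dim)
        # note: intentionally no nn.Embedding(max_seq_len, dim) for absolute positions

    def forward(self, input_ids):
        return self.token_emb(input_ids)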

lucidrains commented 3 years ago

@gaceladri i had trouble making rotary work well with the linear attention in https://github.com/lucidrains/linear-attention, but i suspect it's because i'm using the softmax kernel there. it should work well with the elu kernel :crossed_fingers:
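
For reference, the elu kernel mentioned here is the fast-transformers default feature map described in the docstring above, Φ(x) = elu(x) + 1; a minimal standalone version:

import torch.nn.functional as F

def elu_kernel(x):
    # elu(x) + 1 is strictly positive, which keeps the normalizer Z in
    # LinearAttention.forward well behaved
    return F.elu(x) + 1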