lucidrains / rotary-embedding-torch

Implementation of Rotary Embeddings, from the RoFormer paper, in PyTorch
MIT License

LieRE: Generalizing Rotary Position Encodings. Beats RoPE-mixed by a large margin and is much faster (compute-wise) #26

Open · kabachuha opened this issue 3 months ago

kabachuha commented 3 months ago

Hi, @lucidrains !

There was promising research published this month (vs. RoPE-mixed (#25) from March): the so-called LieRE positional encodings generalize the kv-vector rotation to any number of dimensions (1D, 2D, 3D, etc.), and are much simpler than RoPE in formulation. Moreover, they result in much better model accuracy and 25%+ faster training than either axial RoPE or RoPE-mixed. I think their paper has been really underappreciated, and this approach will be revolutionary.

LieRE leads to marked improvements in performance (up to 6%), training efficiency (a 3.5x reduction), and data efficiency (30%) compared to the baselines of RoFormer, DeiT III, RoPE-Mixed and Vision-Llama.

The paper is here: https://arxiv.org/abs/2406.10322. The LieRE authors have released only pseudocode for now, but it looks extremely simple.

It looks easy, but I'm a bit confused about how to implement the block-diagonal skew-symmetric matrix with a minimal number of learnable components while preserving its structure (a stack of n 1D parameters + tril_indices + a block matrix?). Integrating block-sparse optimizations for fast rotations would also be nice to have.
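
Roughly what I have in mind for a single block (just a toy sketch; the names are mine):

import torch

# one skew-symmetric block of size b, parameterized by its strictly lower triangle
b = 4
theta = torch.randn(b * (b - 1) // 2)       # minimal number of learnable components
i, j = torch.tril_indices(b, b, offset=-1)  # indices strictly below the diagonal

A = torch.zeros(b, b)
A[i, j] = theta
A[j, i] = -theta                            # enforce A^T = -A

R = torch.linalg.matrix_exp(A)              # exp of a skew-symmetric matrix is a rotation
print(torch.allclose(R @ R.T, torch.eye(b), atol=1e-5))  # orthogonality check: True

# several such blocks could then be stacked up to head_dim x head_dim with torch.block_diag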

lucidrains commented 3 months ago

@kabachuha new rotary embeddings research! thank you for this, will check it out!

kabachuha commented 3 months ago

My current understanding:

import torch

def flat_to_skew(x, liere_block_size, axes_length, spacial_dims):
    # x: [num_lower_triangle_entries, axes_length, spacial_dims] flat parameters
    A = torch.zeros(liere_block_size, liere_block_size, axes_length, spacial_dims, device=x.device)
    i, j = torch.tril_indices(liere_block_size, liere_block_size, offset=-1)  # strictly below the diagonal
    for d in range(spacial_dims):
        A[i, j, :, d] = x[:, :, d]
        A[j, i, :, d] = -x[:, :, d]  # skew-symmetry
    return A

class AttentionLiereRotator(torch.nn.Module):
    def __init__(self, head_dim, liere_block_size, spacial_dims, axes_length, num_heads):
        super().__init__()
        assert head_dim % liere_block_size == 0 and liere_block_size <= head_dim
        self.liere_block_size = liere_block_size
        self.head_dim = head_dim
        self.spacial_dims = spacial_dims
        self.axes_length = axes_length
        self.num_heads = num_heads

        # trainable parameters (for skew matrices)
        self.vars = torch.nn.ParameterList(
            [torch.nn.Parameter(torch.randn([(liere_block_size*liere_block_size - liere_block_size)//2, axes_length, spacial_dims])) for _ in range(head_dim // liere_block_size)]
        )

        self.spacial_indices = torch.arange(0, axes_length).unsqueeze(1).repeat([1, self.spacial_dims])

    def forward(self, x: torch.Tensor, matrices=None):
        # x [B, X*Y*... (spacial dims), num_heads, head_dim (N * liere_block_size)]
        x = x.view(x.shape[0], self.axes_length * self.spacial_dims, self.num_heads, self.head_dim)  # flatten positions; only the head dim takes part in the rotation

        if matrices is None:

            # precompute the rotation matrices (returned so they can be cached and shared between q and k)
            # per block: contract the skew generator [block, block, p] with the position
            # indices [p], where p = axes_length * spacial_dims, giving one [block, block]
            # generator per block, which is then exponentiated into a rotation
            matrices = [
                flat_to_skew(v, self.liere_block_size, self.axes_length, self.spacial_dims).view(self.liere_block_size,self.liere_block_size,self.axes_length*self.spacial_dims) @ \
                self.spacial_indices.to(x.device, dtype=x.dtype).view(self.spacial_dims*self.axes_length) for v in self.vars
            ]

            # skew to rotation via exponent
            matrices = [torch.linalg.matrix_exp(A.float()) for A in matrices]
            # -- Fact: Matrix exponent of block diagonal matrix is also block diagonal consisting of matrix exponents of the blocks
            # -- source https://math.stackexchange.com/questions/3836462/matrix-exponential-of-a-block-diagonal-matrix
            # -- TODO: make it work with lower than fp32 precision (if possible in torch)

            # stacking into a bigger block-diagonal matrix (back to head_dim x head_dim), then converting to sparse
            matrices = torch.block_diag(*matrices)

            # batch
            matrices = matrices.unsqueeze(0).repeat(x.shape[0],1,1)

            # to sparse
            matrices = matrices.to_sparse()

        # rotating the vector through multiplication
        # -- making head_dim first
        x = x.permute(0, 3, 1, 2)

        # NOTE: have to upcast x too because of `"bmm_sparse_cuda" not implemented for 'Half'`
        with torch.autocast(device_type=x.device.type, enabled=False):
            dtype_store = x.dtype
            x = torch.bmm(matrices.float(), x.view(x.shape[0], self.head_dim, self.axes_length * self.spacial_dims * self.num_heads).float())
            x = x.view(x.shape[0], self.head_dim, self.axes_length * self.spacial_dims, self.num_heads).permute(0, 2, 3, 1).to(dtype_store)

        return x, matrices

UPD: fixed some formatting/reference mistakes and enabled matrix caching for using the same matrix for k/q rotation

UPD 2: made all matrix operations work

UPD 3: made it launchable with 1D LLaMA text generation
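
For reference, a minimal way to exercise the k/q matrix caching above (1D text case; the sizes are just example values, and the sparse bmm path may need a GPU depending on the torch build):

import torch

# example sizes for a tiny 1D (text) model
B, seq_len, num_heads, head_dim = 2, 128, 4, 64
device = 'cuda' if torch.cuda.is_available() else 'cpu'

rotator = AttentionLiereRotator(
    head_dim=head_dim, liere_block_size=8,
    spacial_dims=1, axes_length=seq_len, num_heads=num_heads,
).to(device)

q = torch.randn(B, seq_len, num_heads, head_dim, device=device)
k = torch.randn(B, seq_len, num_heads, head_dim, device=device)

q_rot, cached = rotator(q)               # builds the rotation matrices once
k_rot, _ = rotator(k, matrices=cached)   # reuses the cached matrices for the keys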

kabachuha commented 3 months ago

Training on a very toy example of Shakespeare with the code above:

(looks okayish, maybe it will look better when the model has more params)

[screenshot of training results]

tasansal commented 2 months ago

@lucidrains were you able to look at the paper?

lucidrains commented 2 months ago

@tasansal no i haven't had the time, will take a look soon

SophieOstmeier commented 1 month ago

Hello! Authors of the paper here. We're excited to see folks trying LieRE out.

Here's a minimal example of how we generated the skew-symmetric matrices:

import math
import torch
import torch.nn as nn

# assumed context (not part of the original snippet): positions holds the token
# coordinates, e.g. shape [batch, seq, input_dimensionality]; input_dimensionality
# and head_dim are defined elsewhere

generator_raw_params = nn.Parameter(
    torch.rand(
        input_dimensionality,
        head_dim,
        head_dim,
    ) * 2 * math.pi
)

# keep the strictly upper triangle, then antisymmetrize to get one skew-symmetric basis per input dimension
upper_triangle = torch.triu(generator_raw_params, diagonal=1)
skew_bases = upper_triangle - torch.transpose(upper_triangle, -1, -2)

# weight each basis by the corresponding position coordinate and sum over the input dimensions
in_basis_positions = (
    positions.reshape(list(positions.shape) + [1] * 2) * skew_bases
)
generator_pos = torch.sum(in_basis_positions, dim=-3)

# the matrix exponential of a skew-symmetric generator is a rotation matrix
rotation = torch.matrix_exp(generator_pos.to(dtype=torch.float32)).to(dtype=positions.dtype)

A longer code snippet is available at https://github.com/SophieOstmeier/LieRE_implementation.

It's very exciting to see someone using the method! Some notes that you may find interesting:

1. We didn't use a sparse matrix representation and instead used the right tensor shape and broadcasting to get the same effect (a rough sketch of this is below). The slowest configuration we tried was with block size 2 (GPUs don't like lots of small matrices). That said, using a full, dense matrix never slowed things down by a measurable amount, as the runtime was dominated by the quadratic component of the attention in our experiments.
2. We used the backbone in https://github.com/kentaroy47/vision-transformers-cifar10 for our experiments. We have noticed that the baseline numbers are slightly different when using the default configuration of x-transformers vs. vision-transformers-cifar10 (cls token, no patch norm, more ff dropout).
3. The performance comparisons were for training time to hit a fixed accuracy: LieRE hits the same accuracies faster than the other methods. Inference time for a same-sized model should be about the same.
4. We saw that LieRE helps more for larger models on CIFAR-100 (model size sweeps are expensive, so we didn't get a chance to sweep model sizes on the larger datasets). We hope to update the arXiv version with those experiments soon.
5. The noncommutativity means that LieRE is able to encode both absolute and relative position information. How are you breaking up the text into batch elements?
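
A rough sketch of how the broadcasting in point 1 can look (simplified, not our actual training code; it assumes positions has shape [batch, seq, input_dimensionality], so the rotation from the snippet above comes out as [batch, seq, head_dim, head_dim]):

import torch

def apply_rotation(rotation, x):
    # rotation: [batch, seq, head_dim, head_dim]; x (queries or keys): [batch, seq, num_heads, head_dim]
    # a single dense batched matmul; the per-position rotation is broadcast across heads
    return torch.einsum('bsij,bshj->bshi', rotation, x)

batch, seq, num_heads, head_dim = 2, 16, 4, 32
rotation = torch.linalg.matrix_exp(torch.zeros(batch, seq, head_dim, head_dim))  # identity rotations as a stand-in
q = torch.randn(batch, seq, num_heads, head_dim)
q_rot = apply_rotation(rotation, q)  # [batch, seq, num_heads, head_dim]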

Doraemonzzz commented 3 weeks ago

Hi, great job. I am the author of Lrpe. I would like to ask how the authors perceive the differences between LieRE and Lrpe. Let me briefly explain Lrpe here: Lrpe points out that the decomposable multiplicative relative position encoding $W_i$ is a unitary matrix, and derives that $W_t = P \Lambda^t P^{\mathbf{H}}$, where $P$ is a unitary matrix, $\Lambda$ is a unitary diagonal matrix, and ${\mathbf{H}}$ denotes the conjugate transpose. Under this condition, we have:

$$ q_s^{\mathbf{H}} W_s^{\mathbf{H}} W_t k_t = q_s^{\mathbf{H}} P \Lambda^{-s} P^{\mathbf{H}} P \Lambda^{t} P^{\mathbf{H}} k_t = (\Lambda^s P^{\mathbf{H}} q_s)^{\mathbf{H}} (\Lambda^t P^{\mathbf{H}} k_t). $$

More details can be found in the paper. It would be greatly appreciated if we could discuss the performance and theoretical differences between LieRE and Lrpe.
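
For concreteness, here is a tiny numerical check of the identity above with toy sizes (just to illustrate the algebra, not the Lrpe implementation):

import math
import torch

d, s, t = 8, 3, 7

P, _ = torch.linalg.qr(torch.randn(d, d, dtype=torch.cfloat))  # random unitary P
theta = torch.rand(d) * 2 * math.pi
Lam = lambda n: torch.diag(torch.exp(1j * n * theta))          # Lambda^n, diagonal and unitary
W = lambda n: P @ Lam(n) @ P.conj().T                          # W_n = P Lambda^n P^H

q = torch.randn(d, dtype=torch.cfloat)
k = torch.randn(d, dtype=torch.cfloat)

lhs = q.conj() @ W(s).conj().T @ W(t) @ k                      # q_s^H W_s^H W_t k_t
rhs = (Lam(s) @ P.conj().T @ q).conj() @ (Lam(t) @ P.conj().T @ k)
print(torch.allclose(lhs, rhs, atol=1e-5))                     # True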

Doraemonzzz commented 3 weeks ago

@lucidrains Hi lucidrains, if you find Lrpe valuable, I can submit a pull request.