kabachuha opened 3 months ago
@kabachuha New rotary embeddings research! Thank you for this, will check it out!
My current understanding:
import torch

def flat_to_skew(x, liere_block_size, axes_length, spacial_dims):
    # x holds the free entries: [(b*b - b)//2, axes_length, spacial_dims]
    A = torch.zeros(liere_block_size, liere_block_size, axes_length, spacial_dims, device=x.device, dtype=x.dtype)
    i, j = torch.tril_indices(liere_block_size, liere_block_size, offset=-1, device=x.device)  # strictly below the diagonal
    A[i, j] = x   # lower triangle
    A[j, i] = -x  # mirrored upper triangle with opposite sign -> skew-symmetric
    return A
class AttentionLiereRotator(torch.nn.Module):
    def __init__(self, head_dim, liere_block_size, spacial_dims, axes_length, num_heads):
        super().__init__()
        assert head_dim % liere_block_size == 0 and liere_block_size <= head_dim
        self.liere_block_size = liere_block_size
        self.head_dim = head_dim
        self.spacial_dims = spacial_dims
        self.axes_length = axes_length
        self.num_heads = num_heads
        # trainable parameters (free entries of the skew-symmetric blocks), one tensor per block
        self.vars = torch.nn.ParameterList(
            [torch.nn.Parameter(torch.randn([(liere_block_size*liere_block_size - liere_block_size)//2, axes_length, spacial_dims])) for _ in range(head_dim // liere_block_size)]
        )
        self.spacial_indices = torch.arange(0, axes_length).unsqueeze(1).repeat([1, self.spacial_dims])

    def forward(self, x: torch.Tensor, matrices=None):
        # x: [B, X*Y*... (spatial dims), num_heads, head_dim (N * liere_block_size)]
        x = x.view(x.shape[0], self.axes_length*self.spacial_dims, self.num_heads, self.head_dim)  # only the head dim enters the matrix product
        if matrices is None:
            # precompute the matrices once (they can be reused for both q and k)
            # matrix product dimensions without batch:
            # if p = spacial_dims*axes_length and dim = head_dim // liere_block_size,
            # then result = [p, dim] * exp{[dim, dim, p] * [p]}
            # -- from dimension reduction logic
            matrices = [
                flat_to_skew(v, self.liere_block_size, self.axes_length, self.spacial_dims).view(self.liere_block_size, self.liere_block_size, self.axes_length*self.spacial_dims) @
                self.spacial_indices.to(device=x.device, dtype=v.dtype).view(self.spacial_dims*self.axes_length)
                for v in self.vars
            ]
            # skew-symmetric -> rotation via the matrix exponential
            matrices = [torch.linalg.matrix_exp(A.float()) for A in matrices]
            # -- Fact: the matrix exponential of a block-diagonal matrix is block-diagonal, with the exponentials of the blocks on the diagonal
            # -- source https://math.stackexchange.com/questions/3836462/matrix-exponential-of-a-block-diagonal-matrix
            # -- TODO: make it work with lower than fp32 precision (if possible in torch)
            # stack into one block-diagonal head_dim x head_dim matrix, then make it sparse
            matrices = torch.block_diag(*matrices)
            # batch
            matrices = matrices.unsqueeze(0).repeat(x.shape[0], 1, 1)
            # to sparse
            matrices = matrices.to_sparse()
        # rotate the vectors through multiplication
        # -- move head_dim first
        x = x.permute(0, 3, 1, 2)
        # NOTE: x has to be upcast too, because `"bmm_sparse_cuda" not implemented for 'Half'`
        with torch.autocast(device_type=x.device.type, enabled=False):
            dtype_store = x.dtype
            x = torch.bmm(matrices.float(), x.reshape(x.shape[0], self.head_dim, self.axes_length*self.spacial_dims*self.num_heads).float())
            x = x.view(x.shape[0], self.head_dim, self.axes_length*self.spacial_dims, self.num_heads).permute(0, 2, 3, 1).to(dtype_store)
        return x, matrices
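The "Fact" in the comments of the code above (the matrix exponential of a block-diagonal matrix is the block-diagonal matrix of the blocks' exponentials) is easy to check numerically. A minimal sketch, using arbitrary block sizes 2 and 3 and float64 for tight tolerances; it also checks that the exponential of a skew-symmetric matrix is orthogonal, i.e. a rotation. `random_skew` is an illustrative helper, not part of the code above:

```python
import torch

torch.manual_seed(0)

def random_skew(n: int) -> torch.Tensor:
    """Return a random n x n skew-symmetric matrix (A^T = -A)."""
    m = torch.randn(n, n, dtype=torch.float64)
    return m - m.T

# Two small skew-symmetric blocks (sizes 2 and 3 are arbitrary choices)
a, b = random_skew(2), random_skew(3)

# exp of the block-diagonal matrix ...
exp_of_block = torch.linalg.matrix_exp(torch.block_diag(a, b))
# ... compared with the block-diagonal of the per-block exponentials
block_of_exp = torch.block_diag(torch.linalg.matrix_exp(a), torch.linalg.matrix_exp(b))
same = torch.allclose(exp_of_block, block_of_exp)

# exp of a skew-symmetric matrix is orthogonal: R^T R = I
eye = torch.eye(5, dtype=torch.float64)
orthogonal = torch.allclose(exp_of_block.T @ exp_of_block, eye)

print(same, orthogonal)
```

Because of this identity, exponentiating each small block separately gives the same rotation as exponentiating the full `head_dim x head_dim` matrix, and is typically cheaper, since the cost of a matrix exponential grows roughly cubically with matrix size.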
UPD: fixed some formatting/reference mistakes and enabled matrix caching for using the same matrix for k/q rotation
UPD 2: made all matrix operations work
UPD 3: made it launchable with 1D LLaMA text generation
Training on a very small toy Shakespeare example with the code above
(looks okay-ish; maybe it will look better when the model has more params)
@lucidrains were you able to look at the paper?
@tasansal No, I haven't had the time; will take a look soon.
Hello! Authors of the paper here. We're excited to see folks trying LieRE out.
Here's a minimal example of how we generated the skew-symmetric matrices.
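import math
import torch
from torch import nn

# assumed to be defined by the caller: input_dimensionality (number of position axes),
# head_dim, and positions with shape [..., input_dimensionality]
generator_raw_params = nn.Parameter(
    torch.rand(
        input_dimensionality,
        head_dim,
        head_dim,
    ) * 2 * math.pi
)
# keep the strictly upper triangle; subtracting its transpose gives skew-symmetric bases
upper_triangle = (
    torch.triu(generator_raw_params, diagonal=1)
)
skew_bases = upper_triangle - torch.transpose(upper_triangle, -1, -2)
# scale each basis by its coordinate: [..., input_dimensionality, 1, 1] * [input_dimensionality, head_dim, head_dim]
in_basis_positions = (
    positions.reshape(list(positions.shape) + [1] * 2) * skew_bases
)
# sum over the input dimensions -> one skew-symmetric generator per position
generator_pos = torch.sum(in_basis_positions, dim=-3)
# the matrix exponential maps each generator to a rotation (computed in fp32)
rotation = torch.matrix_exp(generator_pos.to(dtype=torch.float32)).to(dtype=positions.dtype)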
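A sketch of how per-token rotations built this way might be applied to queries and keys. This is an illustration under assumed shapes (a single head, `num_tokens` hypothetical 2D positions, made-up sizes), not the authors' code, and it uses float64 only so the checks are tight. Each token's query and key are multiplied by that token's own rotation, so the attention logit between tokens s and t becomes q_s^T R_s^T R_t k_t:

```python
import math
import torch

torch.manual_seed(0)

# Hypothetical sizes, for illustration only
num_tokens, input_dimensionality, head_dim = 6, 2, 4

# Same construction as the snippet above (float64 here for tight numerical checks)
raw = torch.rand(input_dimensionality, head_dim, head_dim, dtype=torch.float64) * 2 * math.pi
upper = torch.triu(raw, diagonal=1)
skew_bases = upper - upper.transpose(-1, -2)

# Hypothetical 2D token coordinates: [num_tokens, input_dimensionality]
positions = torch.rand(num_tokens, input_dimensionality, dtype=torch.float64)

# One rotation per token: R_n = exp(sum_d positions[n, d] * skew_bases[d])
generator_pos = (positions[..., None, None] * skew_bases).sum(dim=-3)
rotation = torch.matrix_exp(generator_pos)  # [num_tokens, head_dim, head_dim]

# Rotate each token's query and key by that token's rotation: q_n -> R_n q_n
q = torch.randn(num_tokens, head_dim, dtype=torch.float64)
k = torch.randn(num_tokens, head_dim, dtype=torch.float64)
q_rot = torch.einsum("nij,nj->ni", rotation, q)
k_rot = torch.einsum("nij,nj->ni", rotation, k)

# Attention logits: logits[s, t] = q_s^T R_s^T R_t k_t
logits = q_rot @ k_rot.T
```

Since every `R_n` is orthogonal, the rotation changes only the directions of q and k, never their norms.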
A longer code snippet is available at https://github.com/SophieOstmeier/LieRE_implementation.
It's very exciting to see someone using the method! Some notes that you may find interesting:
1) We didn't use a sparse matrix representation; instead, we used the right tensor shapes and broadcasting to get the same effect. The slowest configuration we tried was block size 2 (GPUs don't like lots of small matrices). That said, using a full, dense matrix never slowed things down measurably, as the runtime was dominated by the quadratic component of attention in our experiments.
2) We used the backbone in https://github.com/kentaroy47/vision-transformers-cifar10 for our experiments. We have noticed that the baseline numbers differ slightly between the default configuration of x-transformers and vision-transformers-cifar10 (cls token, no patch norm, more ff dropout).
3) The performance comparisons were for training time to reach a fixed accuracy: LieRE hits the same accuracies faster than the other methods. Inference time for a model of the same size should be about the same.
4) We saw that LieRE helps more for larger models on CIFAR-100 (model size sweeps are expensive, so we didn't get a chance to sweep model sizes on the larger datasets). We hope to update the arXiv version with those experiments soon.
5) The noncommutativity of the learned generators means that LieRE can encode both absolute and relative position information.
How are you breaking up the text into batch elements?
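Point 5 can be illustrated numerically. A minimal sketch with random matrices standing in for learned generators (`random_skew` and `rot` are hypothetical helpers, not from the LieRE code): with a single generator, R(s)^T R(t) = R(t - s), so only the offset matters; with two noncommuting generators, R(p)^T R(q) generally differs from R(q - p), so the score can also depend on the absolute positions.

```python
import torch

torch.manual_seed(0)

def random_skew(n: int) -> torch.Tensor:
    """Random n x n skew-symmetric matrix (stand-in for a learned generator)."""
    m = torch.randn(n, n, dtype=torch.float64)
    return m - m.T

def rot(p, bases):
    """R(p) = exp(sum_d p[d] * bases[d]) for a position p given as a tuple of floats."""
    gen = sum(pd * b for pd, b in zip(p, bases))
    return torch.linalg.matrix_exp(gen)

head_dim = 4  # hypothetical size
a1, a2 = random_skew(head_dim), random_skew(head_dim)

# 1D: a generator commutes with itself, so R(s)^T R(t) = R(t - s): only the offset matters
s, t = (0.3,), (1.1,)
relative_1d = torch.allclose(rot(s, [a1]).T @ rot(t, [a1]), rot((t[0] - s[0],), [a1]))

# 2D: a1 and a2 generally do not commute, so R(p)^T R(q) need not equal R(q - p)
p, q = (0.3, 0.7), (1.1, -0.4)
offset = tuple(qi - pi for pi, qi in zip(p, q))
relative_2d = torch.allclose(rot(p, [a1, a2]).T @ rot(q, [a1, a2]), rot(offset, [a1, a2]))

print(relative_1d, relative_2d)
```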
Hi, great job. I am the author of Lrpe. I would like to ask how the authors see the differences between LieRE and Lrpe. Let me briefly explain Lrpe here: Lrpe points out that a decomposable multiplicative relative position encoding $W_t$ is a unitary matrix, and derives that $W_t = P \Lambda^t P^{\mathbf{H}}$, where $P$ is a unitary matrix, $\Lambda$ is a diagonal unitary matrix, and ${\mathbf{H}}$ denotes the conjugate transpose. Under this condition, we have:
$$ q_s^{\mathbf{H}} W_s^{\mathbf{H}} W_t k_t = q_s^{\mathbf{H}} P \Lambda^{-s} P^{\mathbf{H}} P \Lambda^{t} P^{\mathbf{H}} k_t = q_s^{\mathbf{H}} P \Lambda^{t-s} P^{\mathbf{H}} k_t = (\Lambda^{s} P^{\mathbf{H}} q_s)^{\mathbf{H}} (\Lambda^{t} P^{\mathbf{H}} k_t). $$
More details can be found in the paper. It would be greatly appreciated if we could discuss the performance and theoretical differences between LieRE and Lrpe.
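The identity above can be sanity-checked numerically. A minimal sketch, assuming $P$ is a random unitary matrix (taken from a QR decomposition) and $\Lambda$ a diagonal matrix with unit-modulus entries $e^{i\theta}$; `W`, `score` and `factored` are illustrative names, not Lrpe's API, and fixed `q`, `k` are used so that only the positions vary:

```python
import torch

torch.manual_seed(0)
d = 4  # hypothetical head dimension

# Random unitary P from the QR decomposition of a random complex matrix
P, _ = torch.linalg.qr(torch.randn(d, d, dtype=torch.complex128))
P_H = P.conj().T

# Diagonal unitary Lambda: unit-modulus entries exp(i * theta)
theta = torch.rand(d, dtype=torch.float64)
lam = torch.polar(torch.ones(d, dtype=torch.float64), theta)

def W(t: int) -> torch.Tensor:
    """W_t = P Lambda^t P^H."""
    return P @ torch.diag(lam ** t) @ P_H

q = torch.randn(d, dtype=torch.complex128)
k = torch.randn(d, dtype=torch.complex128)

def score(s: int, t: int) -> torch.Tensor:
    """Left-hand side: q^H W_s^H W_t k."""
    return q.conj() @ W(s).conj().T @ W(t) @ k

def factored(s: int, t: int) -> torch.Tensor:
    """Right-hand side: (Lambda^s P^H q)^H (Lambda^t P^H k)."""
    return (lam ** s * (P_H @ q)).conj() @ (lam ** t * (P_H @ k))

identity_holds = torch.allclose(score(2, 5), factored(2, 5))
relative_only = torch.allclose(score(2, 5), score(4, 7))  # same offset t - s = 3
print(identity_holds, relative_only)
```

The second check reflects the middle term of the equation: only $\Lambda^{t-s}$ appears, so the score depends on the positions only through $t - s$.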
@lucidrains Hi lucidrains, if you find Lrpe valuable, I can submit a pull request.
Hi, @lucidrains !
Some promising research was published this month (vs. RoPE-mixed (#25) in March): the so-called LieRE positional encodings generalize the query/key vector rotation to any number of dimensions (1D, 2D, 3D, etc.) and are much simpler than RoPE in formulation. More than that, they result in much better model accuracy and 25%+ faster training than either axial RoPE or RoPE-mixed. I think their paper is really underappreciated, and this approach will be revolutionary.
The paper is here: https://arxiv.org/abs/2406.10322. The LieRE authors have only given pseudocode so far; however, it looks extremely simple.
It looks easy, but I'm a bit confused about how to implement the block-diagonal skew-symmetric matrix with minimal learnable components while preserving its structure (a stack of n 1D parameter tensors + tril_indices + a block matrix?). Integrating block-sparse optimizations for fast rotations would also be nice to have.
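One possible answer to this question, as a hedged sketch rather than the LieRE authors' implementation: store only the b(b-1)/2 free entries per block (via `tril_indices`), build skew-symmetric blocks from them, exponentiate each small block per position, and apply the block-diagonal rotation by viewing `head_dim` as `[num_blocks, block_size]` and using `einsum`, so no sparse tensors are needed (in the spirit of the broadcasting approach the LieRE authors describe in this thread). `BlockSkewRotation` and its shapes are hypothetical:

```python
import torch
from torch import nn


class BlockSkewRotation(nn.Module):
    """Hypothetical sketch: per-position rotations from block-diagonal skew-symmetric
    generators, applied with reshapes and einsum instead of sparse matrices."""

    def __init__(self, head_dim: int, block_size: int, spatial_dims: int):
        super().__init__()
        assert head_dim % block_size == 0
        self.block_size = block_size
        self.num_blocks = head_dim // block_size
        # Only the b*(b-1)/2 entries strictly below each block's diagonal are learnable
        n_free = block_size * (block_size - 1) // 2
        self.params = nn.Parameter(torch.randn(spatial_dims, self.num_blocks, n_free))
        rows, cols = torch.tril_indices(block_size, block_size, offset=-1)
        # Buffers follow the module across devices; not saved in the state dict
        self.register_buffer("rows", rows, persistent=False)
        self.register_buffer("cols", cols, persistent=False)

    def generators(self) -> torch.Tensor:
        """Skew-symmetric blocks, shape [spatial_dims, num_blocks, b, b]."""
        b = self.block_size
        lower = self.params.new_zeros(*self.params.shape[:2], b, b)
        lower[..., self.rows, self.cols] = self.params
        return lower - lower.transpose(-1, -2)

    def forward(self, x: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
        """x: [batch, heads, seq, head_dim]; positions: [seq, spatial_dims]."""
        # Generator per position and block: sum_d p_d * A_d -> [seq, num_blocks, b, b]
        gen = torch.einsum("sd,dnij->snij", positions.to(self.params.dtype), self.generators())
        rot = torch.linalg.matrix_exp(gen.float()).to(x.dtype)
        # View head_dim as [num_blocks, b] and rotate every block of every token
        xb = x.reshape(*x.shape[:-1], self.num_blocks, self.block_size)
        out = torch.einsum("snij,bhsnj->bhsni", rot, xb)
        return out.reshape(x.shape)


# Usage: a 4 x 4 grid of 2D positions, head_dim 8 split into blocks of size 2
rotary = BlockSkewRotation(head_dim=8, block_size=2, spatial_dims=2)
ys, xs = torch.meshgrid(torch.arange(4), torch.arange(4), indexing="ij")
positions = torch.stack([ys.flatten(), xs.flatten()], dim=-1).float()  # [16, 2]
q = torch.randn(1, 3, 16, 8)
q_rot = rotary(q, positions)
```

With `block_size=2` every 2x2 block is a plain planar rotation whose angle is a learned linear function of the position, which resembles RoPE-mixed; with larger blocks the generators for different axes generally stop commuting. The same module would be applied to the keys with the same `positions`.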