ml-explore / mlx

MLX: An array framework for Apple silicon
https://ml-explore.github.io/mlx/
MIT License

[Feature] Rotary Positional Embeddings for Nomic Embed #1088

Closed · zanussbaum closed this 1 week ago

zanussbaum commented 1 week ago

Describe the bug

Hi all, I wasn't able to get the same outputs using the nn.RoPE module provided in MLX for the Nomic Embed models. I tried both traditional=True and traditional=False, but it seems like there's a slight difference in the implementations.

I ended up rewriting the Nomic version of RoPE in Python. Let me know if there are better ways to contribute a Metal kernel; it would be great to see that added! IIRC our RoPE version is very similar to Flash Attention's and GPT-NeoX's.

I took a quick look at the C++ code but wasn't able to find where the implementations actually differ. Let me know how I can help!

To Reproduce


import mlx.core as mx
import mlx.nn as nn


class NomicRoPE(nn.Module):
    def __init__(self, dims, base=10000.0, scale=1.0, offset=0):
        super().__init__()
        self.dims = dims
        self.base = base
        self.scale = scale
        self.offset = offset

    def __call__(self, x, offset=0):
        # Expects x laid out as (batch, sequence_length, heads, dims).
        if x.ndim < 3:
            raise ValueError(
                f"[rope] Input must have at least 3 dimensions but got input with {x.ndim} dimensions."
            )

        scale = self.scale
        base = self.base
        dims = self.dims

        shape = x.shape
        N = x.shape[1] + offset

        # Compute sines and cosines of the rotation angles
        half_dims = dims // 2
        positions = mx.arange(offset, N, dtype=x.dtype) * scale
        freqs = mx.arange(0, half_dims, dtype=x.dtype)
        freqs = mx.exp(-freqs * mx.array(base).log() / half_dims)
        theta = mx.expand_dims(positions, 1) * mx.expand_dims(freqs, 0)  # (N, half_dims)
        coss = mx.cos(theta)
        sins = mx.sin(theta)

        def rotate_half(x):
            # GPT-NeoX-style rotation: split the feature axis in half,
            # swap the halves, and negate the second one.
            x1, x2 = mx.split(x, 2, axis=-1)
            return mx.concatenate((-x2, x1), axis=-1)

        def apply_rope(x, coss, sins):
            return x * coss + rotate_half(x) * sins

        # Duplicate the angles so they cover all `dims` features and add a
        # singleton heads axis for broadcasting: (N, half_dims) -> (N, 1, dims).
        repeated_cos = mx.tile(coss, (1, 2))
        coss = repeated_cos.reshape(coss.shape[:-1] + (1, 2 * coss.shape[-1]))

        repeated_sin = mx.tile(sins, (1, 2))
        sins = repeated_sin.reshape(sins.shape[:-1] + (1, 2 * sins.shape[-1]))

        out = apply_rope(x, coss, sins)

        return mx.reshape(out, shape)
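For context, here's a minimal sketch of how this module gets called; the shapes are purely illustrative, assuming the batch x sequence_length x heads x dims layout the module is written for:

import mlx.core as mx

# Illustrative shapes: batch=2, sequence_length=8, heads=4, dims=64.
rope = NomicRoPE(dims=64)
x = mx.random.normal((2, 8, 4, 64))

y = rope(x)                 # rotate starting at position 0
y_cont = rope(x, offset=8)  # continue as if x starts at position 8
print(y.shape)  # (2, 8, 4, 64)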

See the PR here: https://github.com/taylorai/mlx_embedding_models/pull/4

Expected behavior

nn.RoPE produces the same outputs as the Nomic RoPE implementation.



angeloskath commented 1 week ago

I think the implementations are the same; the only difference is that the implementation above expects the input to be batch x sequence_length x heads x dims, while nn.RoPE expects batch x heads x sequence_length x dims.

For instance, the following code shows that they produce the same results:

import mlx.core as mx
import mlx.nn as nn

class NomicRoPE(nn.Module):
    ...  # as defined above

# Treat a (100, 100) array as sequence_length x dims and add the
# batch/heads axes each implementation expects.
x = mx.ones((100, 100))
rr = nn.RoPE(100)
nr = NomicRoPE(100)

# nn.RoPE:    (batch, heads, sequence_length, dims)
# NomicRoPE:  (batch, sequence_length, heads, dims)
assert (rr(x[None, None]).squeeze() - nr(x[None, :, None]).squeeze()).abs().max() < 1e-4
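Equivalently, if your activations are already in batch x sequence_length x heads x dims order, you can just transpose around nn.RoPE rather than reimplementing it; a small sketch with made-up shapes:

import mlx.core as mx
import mlx.nn as nn

rope = nn.RoPE(64)

# Activations in (batch, sequence_length, heads, dims) order.
x = mx.random.normal((2, 8, 4, 64))

# Move to (batch, heads, sequence_length, dims), apply RoPE, move back.
y = rope(x.transpose(0, 2, 1, 3)).transpose(0, 2, 1, 3)
print(y.shape)  # (2, 8, 4, 64)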

Let me know if I made a mistake, otherwise feel free to close the issue :-)

zanussbaum commented 1 week ago

AH! That's definitely it, thank you. Is it possible to add something like this to the documentation? It wasn't obvious from looking here or at the code that it expects a different input shape (although I should have checked in hindsight :D).