patrick-kidger / equinox

Elegant easy-to-use neural networks + scientific computing in JAX. https://docs.kidger.site/equinox/
Apache License 2.0

Added the offset to RoPE embedding and fixed the pre-commit pyright #799

Closed · Artur-Galstyan closed 2 months ago

Artur-Galstyan commented 3 months ago

As mentioned in #704, this is the fix for the RoPE embeddings. I also added `pyright: ignore` comments to the test_nn.py file; I think someone forgot to run the pre-commit hook. I checked GitHub and it seems we don't run the Pyright check anymore?

Artur-Galstyan commented 3 months ago

I agree that this is not a bug per se, but it is still a shortcoming of our current version, especially when compared to a PyTorch counterpart, which includes the input_pos in its forward function.

> My concern here is that in order to use it, we'd need to make a backward-incompatible change to MHA. (To pass the index.) Since we've decided against implementing KV caching in MHA, I don't think this really comes up.

In the example I provided, it's even possible to make this work with the offset without changing the MHA implementation, by using functools.partial. TBH, I didn't think it would work; I assumed it would re-JIT on every call when using this:

```python
x = self.mha_attention(
    query=query,
    key_=key_,
    value=value,
    # bind the current index so process_heads keeps the signature MHA expects
    process_heads=functools.partial(process_heads, index=index),  # <--
)
```

But I guess the compiler is pretty smart!

Artur-Galstyan commented 3 months ago

Ok, so I've removed the max_seq_length argument. It's a breaking change with no backwards compatibility, but more importantly, it actually provided no benefit.

<details>
<summary>Unfold, if you want to know why</summary>

Initially, I added it because I wanted to avoid re-JITting. But in the current implementation, we are not enforcing the input array to the RoPE module to be of length max_seq_length (and then applying a mask to effectively only "use" the intended seq_length). The input array can have any seq_length, which means the module will always be re-JITted whenever the seq_length of the input array changes. Thus, the max_seq_length argument is obsolete.

In other words, if we were to ever include the max_seq_length argument, we would also need to include not just the offset but also the cutoff position, and then apply a mask to the input array. But that would be an even breakier change 😄. (Though, personally, I don't mind breaking changes as long as they are an improvement and properly communicated.) A toy sketch of the trade-off follows below.
</details>
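For illustration, here's a toy sketch of that shape-specialization trade-off. The names (`embed`, `MAX_SEQ_LENGTH`) are hypothetical stand-ins, not the actual RoPE module:

```python
import jax
import jax.numpy as jnp

@jax.jit
def embed(x):
    # stand-in for the RoPE computation; jit specializes on x.shape
    return x * 2.0

embed(jnp.ones(16))  # trace 1: shape (16,)
embed(jnp.ones(32))  # trace 2: the shape changed, so jit retraces

# Reusing a single trace would mean padding every input to max_seq_length
# and masking off the tail (the "even breakier" design described above):
MAX_SEQ_LENGTH = 32

def embed_padded(x):
    pad = MAX_SEQ_LENGTH - x.shape[0]           # concrete at trace time
    padded = jnp.pad(x, (0, pad))
    mask = jnp.arange(MAX_SEQ_LENGTH) < x.shape[0]
    return embed(padded) * mask                 # one trace for all lengths
```
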
Artur-Galstyan commented 2 months ago

I actually caught another bug in the implementation, which I hadn't noticed before. Our previous implementation grouped the sin/cos like this:

[cos1, cos2, cos3, cos4, sin1, sin2, sin3, sin4] (all grouped)

but it should have been interleaved like this:

[cos1, sin1, cos2, sin2, cos3, sin3, cos4, sin4]

I used the implementation from lucidrains as a reference for the expected values and updated the test accordingly.
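
For concreteness, here's a minimal sketch of the two layouts, assuming the standard RoPE angles (the `rope_angles` helper is hypothetical, not equinox's actual code):

```python
import jax.numpy as jnp

def rope_angles(seq_len, dim, base=10000.0):
    # standard RoPE frequencies: theta_i = base**(-2i/dim)
    inv_freq = 1.0 / (base ** (jnp.arange(0.0, dim, 2.0) / dim))
    positions = jnp.arange(float(seq_len))
    return jnp.outer(positions, inv_freq)  # shape (seq_len, dim // 2)

angles = rope_angles(seq_len=4, dim=8)
cos, sin = jnp.cos(angles), jnp.sin(angles)

# grouped: [cos1, cos2, cos3, cos4, sin1, sin2, sin3, sin4]
grouped = jnp.concatenate([cos, sin], axis=-1)

# interleaved: [cos1, sin1, cos2, sin2, cos3, sin3, cos4, sin4]
interleaved = jnp.stack([cos, sin], axis=-1).reshape(angles.shape[0], -1)
```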

Two problems are left to fix:

1) Pyright complains at the TODO spots because ArrayLike might be a complex number, giving `Operator ">" not supported for types "complex" and "int"` :(

2) It keeps re-JITting when using plain Python integers for the offset, and when using arrays we get a ConcretizationTypeError instead. There is an almost-MVP if you're curious. I'll invest more time into this; I feel like there should be a good solution. (A sketch of the dilemma follows below.)
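
Here's a minimal sketch of that dilemma, under assumed stand-in names (`table`, `take_window`), not the actual RoPE code:

```python
import functools
import jax
import jax.numpy as jnp
from jax import lax

table = jnp.arange(100.0)  # stand-in for a precomputed sin/cos table

def take_window(x, offset):
    return x[offset : offset + 4]

# Static int offset: every functools.partial is a distinct static value,
# so jit compiles a fresh version for each offset.
for off in range(3):
    jax.jit(functools.partial(take_window, offset=off))(table)

# Traced array offset: basic slicing needs concrete indices, so this
# fails with the concretization error mentioned above:
# jax.jit(take_window)(table, jnp.asarray(1))

# lax.dynamic_slice accepts traced start indices, so a single trace
# covers every offset value:
@jax.jit
def take_window_dynamic(x, offset):
    return lax.dynamic_slice(x, (offset,), (4,))

for off in range(3):
    take_window_dynamic(table, jnp.asarray(off))  # compiles once
```

Since `eqx.filter_jit` treats non-array leaves as static, a plain-integer offset presumably hits the same retracing behaviour, which is why something like `dynamic_slice` over a precomputed table looks promising.
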

patrick-kidger commented 2 months ago

I think both interleaved and non-interleaved are acceptable. In fact, when we first wrote this, I checked it against the ESM2 implementation here to ensure correctness. (Either convention just pairs up the feature dimensions differently; as long as queries and keys use the same convention, the attention scores agree.)

We can't switch this now, for backward-compatibility reasons.

patrick-kidger commented 2 months ago

Bother, this got autoclosed because the target branch got merged. Feel free to re-open.

Artur-Galstyan commented 2 months ago

No worries, I'll fix it once I get back from vacation.