Closed Artur-Galstyan closed 2 months ago
I agree that this is not a bug per se, but it is still a shortcoming in our current version, especially when compared to a PyTorch counterpart which includes the input_pos in its forward function.
My concern here is that in order to use it, we'd need to make a backward-incompatible change to MHA (to pass the index). Since we've decided against implementing KV caching in MHA, I don't think this really comes up.
In the example I provided, it's even possible to make it work with the offset without changing the MHA implementation, by using functools.partial. TBH, I didn't think it would work; I expected it to re-JIT on every call when using this:
x = self.mha_attention(
    query=query,
    key_=key_,
    value=value,
    process_heads=functools.partial(process_heads, index=index),  # <--
)
But I guess the compiler is pretty smart!
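A minimal sketch of why this need not re-JIT, assuming the index reaches the jitted function as a JAX array rather than a Python int. The names attention_like, attend and the toy process_heads below are hypothetical stand-ins, not the MHA code from this thread. The functools.partial object is built inside the traced function, so it only closes over a tracer and is never part of the jit cache key:

```python
import functools

import jax
import jax.numpy as jnp

# Incremented only while JAX traces `attend` (Python side effects run at trace time).
trace_counter = {"n": 0}


def process_heads(query_heads, key_heads, index):
    # Hypothetical stand-in: shift both inputs by a position-dependent amount.
    return query_heads + index, key_heads + index


def attention_like(query, key_, process_heads):
    # Hypothetical stand-in for an attention layer that accepts a callback.
    q, k = process_heads(query, key_)
    return q @ k.T


@jax.jit
def attend(query, key_, index):
    trace_counter["n"] += 1
    # `index` is a tracer here; the partial is created during tracing.
    return attention_like(
        query,
        key_,
        process_heads=functools.partial(process_heads, index=index),
    )


q = jnp.ones((2, 4))
k = jnp.ones((2, 4))
out0 = attend(q, k, jnp.asarray(0))
out3 = attend(q, k, jnp.asarray(3))  # same shapes and dtypes, so no retrace
```

After both calls trace_counter["n"] is still 1. Marking index as a static argument instead would retrace for every new value.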
Ok, so I've removed the max_seq_length argument, because it's a breaking change with no backwards compatibility and, more importantly, because it actually provides no benefit.
max_seq_length (and then applying a mask to effectively only "use" the intended seq_length). The input array can have any seq_length, which means the module will always be re-JITted if the seq_length of the input array changes. Thus, the max_seq_length argument is obsolete. In other words, if we were to ever include the max_seq_length argument, we would also need to include not just the offset but also the cutoff position, and then apply a mask on the input array. But that would be an even breakier change 😄 (Though, personally, I don't mind breaking changes as long as they are an improvement and properly communicated.)
I actually caught another bug in the implementation, which I hadn't noticed before. Our previous implementation grouped the sin/cos like this:
[cos1, cos2, cos3, cos4, sin1, sin2, sin3, sin4] (all grouped)
but it should have been interleaved like this:
[cos1, sin1, cos2, sin2, cos3, sin3, cos4, sin4]
I used the implementation from lucidrains as a reference for the expected values and updated the test accordingly.
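For reference, the two orderings can be sketched like this. It is a toy illustration of the layouts only, with made-up helper names (grouped, interleaved), not the implementation in this PR or in the lucidrains reference:

```python
import jax.numpy as jnp


def grouped(cos, sin):
    # [cos1, ..., cosN, sin1, ..., sinN]
    return jnp.concatenate([cos, sin], axis=-1)


def interleaved(cos, sin):
    # [cos1, sin1, cos2, sin2, ...]: stack into pairs, then flatten the last two axes.
    return jnp.stack([cos, sin], axis=-1).reshape(*cos.shape[:-1], -1)


# Placeholder values so the ordering is easy to read off.
cos = jnp.array([1.0, 2.0, 3.0, 4.0])
sin = jnp.array([10.0, 20.0, 30.0, 40.0])

g = grouped(cos, sin)      # [1, 2, 3, 4, 10, 20, 30, 40]
i = interleaved(cos, sin)  # [1, 10, 2, 20, 3, 30, 4, 40]
```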
Two problems are left to fix:
1) Pyright complains at the TODO spots because ArrayLike might be a complex number: Operator ">" not supported for types "complex" and "int" :(
2) It keeps re-jitting when using integers for the offset, and when using arrays we get a ConcretizationTypeError. There is an almost MVP if you're curious. I'll invest more time into this - I feel like there should be a good solution.
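One possible direction for the second problem, sketched under assumptions and not the solution this PR ended up with: take the sequence length from the static input shape and add the traced offset afterwards, so nothing that needs a concrete value ever sees the tracer. The function name rotary_angles and the base of 10000 are illustrative choices:

```python
import jax
import jax.numpy as jnp

trace_counter = {"n": 0}


@jax.jit
def rotary_angles(x, offset):
    trace_counter["n"] += 1
    seq_len, dim = x.shape  # static under jit
    # arange only sees the static seq_len; the traced offset is added afterwards.
    positions = jnp.arange(seq_len) + offset
    inv_freq = 1.0 / (10000.0 ** (jnp.arange(0, dim, 2) / dim))
    return jnp.outer(positions, inv_freq)  # shape (seq_len, dim // 2)


x = jnp.ones((3, 4))
angles0 = rotary_angles(x, jnp.asarray(0))
angles5 = rotary_angles(x, jnp.asarray(5))  # new offset value, same trace
```

Passing the offset as an array keeps a single trace here. Anything that needs the offset as a concrete Python value (for example using it as a static slice bound) would still hit the same error.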
I think both interleaved and non-interleaved are acceptable. In fact, when we first wrote this, I checked it against the ESM2 implementation here to ensure correctness.
We can't switch this now, for backward-compatibility reasons.
Bother, this got autoclosed because the target branch got merged. Feel free to re-open.
No worries, I'll fix it once I get back from vacation.
As mentioned in #704, this is the fix for the RoPE embeddings. I also added pyright: ignore to the test_nn.py file. I think perhaps someone forgot to run the PCH hook. I checked GH and it seems we don't run the Pyright check anymore?