patrick-kidger / equinox

Elegant easy-to-use neural networks + scientific computing in JAX. https://docs.kidger.site/equinox/

Added RoPE Offset + Test #860

Open · Artur-Galstyan opened 2 months ago

Artur-Galstyan commented 2 months ago

This PR follows #799 and adds support for an `offset` argument in RoPE (intended semantics sketched below).

I've rebased my changes on the latest main branch and added one more test.
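To illustrate the intended semantics of the new argument: applying RoPE to a suffix of a sequence with `offset=k` should match the corresponding slice of RoPE applied to the full sequence. Here's a minimal sketch of that property (not the PR's actual test; it assumes `offset` accepts a scalar integer array, as in the MHA example below):

```python
import jax
import jax.numpy as jnp
from equinox.nn._embedding import RotaryPositionalEmbedding

rope = RotaryPositionalEmbedding(32, 8)  # (embedding_size, max_seq_length)
x = jax.random.normal(jax.random.key(0), (8, 32))

full = rope(x, offset=jnp.array(0))      # rotate positions 0..7
tail = rope(x[4:], offset=jnp.array(4))  # rotate positions 4..7

# RoPE acts position-wise, so the offset suffix matches the full result.
assert jnp.allclose(full[4:], tail)
```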

For your convenience, here's some code to test it with MHA:

```python
import functools

import equinox as eqx
import jax
import jax.numpy as jnp
from equinox.nn._attention import MultiheadAttention
from equinox.nn._embedding import RotaryPositionalEmbedding
from jaxtyping import Array, Float, Int

embedding_size = 32
max_seq_length = 8
seq_length = 4
num_heads = 2
query_size = 64


class TransformerBlock(eqx.Module):
    rope_embeddings: RotaryPositionalEmbedding
    mha_attention: MultiheadAttention

    def __init__(self, embedding_size, max_seq_length, num_heads, query_size):
        self.rope_embeddings = RotaryPositionalEmbedding(embedding_size, max_seq_length)
        self.mha_attention = MultiheadAttention(
            num_heads=num_heads, query_size=query_size, key=jax.random.key(0)
        )

    def __call__(self, query, key_, value, index):
        def process_heads(
            query_heads: Float[Array, "seq_length num_heads qk_size"],
            key_heads: Float[Array, "seq_length num_heads qk_size"],
            value_heads: Float[Array, "seq_length num_heads vo_size"],
            index: Int[Array, ""],
        ) -> tuple[
            Float[Array, "seq_length num_heads qk_size"],
            Float[Array, "seq_length num_heads qk_size"],
            Float[Array, "seq_length num_heads vo_size"],
        ]:
            # index is the autoregressive index of the current token
            rope_partial = functools.partial(self.rope_embeddings, offset=index)
            # vmap over the head axis (axis 1), so RoPE sees (seq_length, qk_size)
            query_heads = jax.vmap(rope_partial, in_axes=1, out_axes=1)(query_heads)
            key_heads = jax.vmap(rope_partial, in_axes=1, out_axes=1)(key_heads)
            return query_heads, key_heads, value_heads

        x = self.mha_attention(
            query=query,
            key_=key_,
            value=value,
            process_heads=functools.partial(process_heads, index=index),
        )
        return x


transformer_block = TransformerBlock(
    embedding_size, max_seq_length, num_heads, query_size
)
transformer_block = eqx.filter_jit(transformer_block)

q = jnp.ones(shape=(seq_length, query_size))
k = jnp.ones(shape=(seq_length, query_size))
v = jnp.ones(shape=(seq_length, query_size))

# Call with increasing offsets, as in autoregressive decoding.
out = transformer_block(q, k, v, 0)
out = transformer_block(q, k, v, 1)
out = transformer_block(q, k, v, 2)
out = transformer_block(q, k, v, 3)
out = transformer_block(q, k, v, 4)
out = transformer_block(q, k, v, 5)
```
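Each call above should return an output of shape `(seq_length, query_size)`, i.e. `(4, 64)` here; only the rotary phases applied inside `process_heads` change as `index` increases.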