danjenson opened this issue 1 month ago
@kaixih PTAL.
I just created a PR to fix this issue. Currently, the API requires both query_seq_lengths and key_value_seq_lengths; this PR relaxes that. Can you take a look to see if it works for you?
On the user side, you can also try explicitly providing query_seq_lengths as a tensor filled with the max sequence length.
The following works when I provide the max len for query lengths:
#!/usr/bin/env python3
import jax.numpy as jnp
from jax import random, nn

B, L, H, D = 8, 128, 4, 64
rng = random.key(42)
x = random.normal(rng, (B, L, H, D // H), dtype=jnp.bfloat16)
valid_lens = jnp.array([24, 125, 53, 28, 77, 96, 13, 114], jnp.int32)


def vanilla_attention(qs, ks, vs, valid_lens):
    scores = jnp.einsum("BQHD,BKHD->BHQK", qs, ks) / jnp.sqrt(D // H)
    if valid_lens is not None:
        # Mask out keys beyond each sequence's valid length.
        mask = jnp.arange(L) < valid_lens[:, None]
        mask = mask[:, None, None, :]  # broadcast across H, Q in [B, H, Q, K]
        scores = jnp.where(mask, scores, -jnp.inf)
    attn = nn.softmax(scores, axis=-1)
    return jnp.einsum("BHQK,BKHD->BQHD", attn, vs).reshape(B, L, D)


def xla_attention(qs, ks, vs, valid_lens):
    if valid_lens is None:
        valid_lens = jnp.repeat(L, B)
    ctx = nn.dot_product_attention(
        qs,
        ks,
        vs,
        query_seq_lengths=jnp.repeat(L, B),  # all queries are valid
        key_value_seq_lengths=valid_lens,
        implementation="xla",
    )
    return ctx.reshape(B, L, D)


def cudnn_attention(qs, ks, vs, valid_lens):
    if valid_lens is None:
        valid_lens = jnp.repeat(L, B)
    ctx = nn.dot_product_attention(
        qs,
        ks,
        vs,
        query_seq_lengths=jnp.repeat(L, B),  # all queries are valid
        key_value_seq_lengths=valid_lens,
        implementation="cudnn",
    )
    return ctx.reshape(B, L, D)


van_attn = vanilla_attention(x, x, x, valid_lens)
xla_attn = xla_attention(x, x, x, valid_lens)
cud_attn = cudnn_attention(x, x, x, valid_lens)
print(jnp.allclose(van_attn, xla_attn, rtol=0.01, atol=0.01))  # True
print(jnp.allclose(van_attn, cud_attn, rtol=0.01, atol=0.01))  # True
print(jnp.allclose(cud_attn, xla_attn, rtol=0.01, atol=0.01))  # True

van_attn = vanilla_attention(x, x, x, None)
xla_attn = xla_attention(x, x, x, None)
cud_attn = cudnn_attention(x, x, x, None)
print(jnp.allclose(van_attn, xla_attn, rtol=0.01, atol=0.01))  # True
print(jnp.allclose(van_attn, cud_attn, rtol=0.01, atol=0.01))  # True
print(jnp.allclose(cud_attn, xla_attn, rtol=0.01, atol=0.01))  # True
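For intuition, the key-length masking that vanilla_attention relies on can be sketched in plain NumPy: scores at padded key positions are set to -inf before the softmax, so those keys receive exactly zero attention weight and cannot leak into the output.

```python
import numpy as np

# Toy scores for one head: 2 queries attending over 4 keys.
scores = np.array([[1.0, 2.0, 3.0, 4.0],
                   [4.0, 3.0, 2.0, 1.0]])
valid_len = 2  # only the first two keys are real; the rest are padding

# Mask out padded keys with -inf before the softmax.
mask = np.arange(scores.shape[-1]) < valid_len   # [True, True, False, False]
masked = np.where(mask, scores, -np.inf)

# Numerically stable softmax over the key axis.
e = np.exp(masked - masked.max(axis=-1, keepdims=True))
weights = e / e.sum(axis=-1, keepdims=True)

print(weights[:, valid_len:])  # padded keys get exactly zero weight
```

Without the `-inf` substitution, the padded positions keep nonzero weight, which is exactly the data-leak symptom described in this issue.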
@danjenson Is it a typical use case for you to provide only key_value_seq_lengths?
Constantly -- usually I want an answer for every "query", but each query may only attend to specific data/keys when answering that question.
Description
Perhaps I am using this function incorrectly, but I get data leaks when using key_value_seq_lengths. It appears as though both the xla and cudnn implementations in jax nightly do not support this argument. Here is some reproducible code:
System info (python version, jaxlib version, accelerator, etc.)