jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

`jax.nn.dot_product_attention` does not respect `key_value_seq_lengths` #23349

Open danjenson opened 1 month ago

danjenson commented 1 month ago

Description

Perhaps I am using this function incorrectly, but I see data leaking past the valid key/value lengths when using `key_value_seq_lengths`. It appears that neither the xla nor the cudnn implementation in the jax nightly build respects this argument. Here is a reproducible example:

#!/usr/bin/env python3
import jax.numpy as jnp
from jax import random, nn

B, L, H, D = 8, 128, 4, 64
rng = random.key(42)
x = random.normal(rng, (B, L, H, D // H), dtype=jnp.bfloat16)
valid_lens = jnp.array([24, 125, 53, 28, 77, 96, 13, 114], jnp.int32)

def vanilla_attention(qs, ks, vs, valid_lens):
    scores = jnp.einsum("BQHD,BKHD->BHQK", qs, ks) / jnp.sqrt(D // H)
    if valid_lens is not None:
        mask = jnp.arange(L) < valid_lens[:, None]
        mask = mask[:, None, None, :]  # broadcast across H, Q in [B, H, Q, K]
        scores = jnp.where(mask, scores, -jnp.inf)
    attn = nn.softmax(scores, axis=-1)
    return jnp.einsum("BHQK,BKHD->BQHD", attn, vs).reshape(B, L, D)

def xla_attention(qs, ks, vs, valid_lens):
    ctx = nn.dot_product_attention(
        qs, ks, vs, key_value_seq_lengths=valid_lens, implementation="xla"
    )
    return ctx.reshape(B, L, D)

def cudnn_attention(qs, ks, vs, valid_lens):
    ctx = nn.dot_product_attention(
        qs, ks, vs, key_value_seq_lengths=valid_lens, implementation="cudnn"
    )
    return ctx.reshape(B, L, D)

van_attn = vanilla_attention(x, x, x, valid_lens)
xla_attn = xla_attention(x, x, x, valid_lens)
cud_attn = cudnn_attention(x, x, x, valid_lens)
print(jnp.allclose(van_attn, xla_attn, rtol=1.0, atol=1.0))  # False
print(jnp.allclose(van_attn, cud_attn, rtol=1.0, atol=1.0))  # False
print(jnp.allclose(cud_attn, xla_attn, rtol=0.01, atol=0.01))  # True

van_attn = vanilla_attention(x, x, x, None)
xla_attn = xla_attention(x, x, x, None)
cud_attn = cudnn_attention(x, x, x, None)
print(jnp.allclose(van_attn, xla_attn, rtol=0.01, atol=0.01))  # True
print(jnp.allclose(van_attn, cud_attn, rtol=0.01, atol=0.01))  # True
print(jnp.allclose(cud_attn, xla_attn, rtol=0.01, atol=0.01))  # True
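
As a sanity check (not part of the original report), the helper below continues the script above and zeroes out key/value positions at or beyond each example's valid length. With correct masking this should not change the output, so a `True` result indicates that padded positions leak into the result. The helper name is made up for illustration.

def leaks_padding(attn_fn, qs, ks, vs, lens):
    # Zero out key/value positions at or beyond each example's valid length.
    pad = (jnp.arange(L) >= lens[:, None])[:, :, None, None]  # [B, K, 1, 1]
    out_ref = attn_fn(qs, ks, vs, lens)
    out_pad = attn_fn(qs, jnp.where(pad, 0, ks), jnp.where(pad, 0, vs), lens)
    # True means padded positions influenced the result, i.e. the lengths were ignored.
    return not jnp.allclose(out_ref, out_pad, rtol=0.01, atol=0.01)

print(leaks_padding(vanilla_attention, x, x, x, valid_lens))  # False: mask respected
print(leaks_padding(xla_attention, x, x, x, valid_lens))      # expected True if lengths are ignored
print(leaks_padding(cudnn_attention, x, x, x, valid_lens))    # expected True if lengths are ignored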

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.32.dev20240830
jaxlib: 0.4.31
numpy:  1.26.4
python: 3.12.2 (main, Mar  2 2024, 09:51:01) [GCC 13.2.0]
jax.devices (1 total, 1 local): [CudaDevice(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='ghost', release='6.6.47_1', version='#1 SMP PREEMPT_DYNAMIC Mon Aug 19 16:42:31 UTC 2024', machine='x86_64')

$ nvidia-smi
Fri Aug 30 18:11:31 2024
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.107.02             Driver Version: 550.107.02     CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA GeForce RTX 4090        Off |   00000000:01:00.0 Off |                  Off |
| 32%   39C    P2             78W /  480W |     393MiB /  24564MiB |      3%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|    0   N/A  N/A     28449      C   ...ions/3.12.2/envs/jax/bin/python3.12        386MiB |
+-----------------------------------------------------------------------------------------+
superbobry commented 1 month ago

@kaixih PTAL.

kaixih commented 1 month ago

I just created a PR to fix this issue. Basically, the current API requires both query_seq_lengths and key_value_seq_lengths. This PR relaxes it. Can you take a look at it to see if it works?

On the user side, you can also work around this by explicitly providing `query_seq_lengths` as a tensor filled with the maximum sequence length.
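
For example, a minimal sketch of that workaround, reusing `x`, `B`, `L`, and `valid_lens` from the reproduction above (the explicit full-length `query_seq_lengths` is the only change):

ctx = nn.dot_product_attention(
    x, x, x,
    query_seq_lengths=jnp.full((B,), L, dtype=jnp.int32),  # every query position is valid
    key_value_seq_lengths=valid_lens,
    implementation="cudnn",
)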

danjenson commented 1 month ago

The following works when I explicitly provide the maximum length for `query_seq_lengths`:

#!/usr/bin/env python3
import jax.numpy as jnp
from jax import random, nn

B, L, H, D = 8, 128, 4, 64
rng = random.key(42)
x = random.normal(rng, (B, L, H, D // H), dtype=jnp.bfloat16)
valid_lens = jnp.array([24, 125, 53, 28, 77, 96, 13, 114], jnp.int32)

def vanilla_attention(qs, ks, vs, valid_lens):
    scores = jnp.einsum("BQHD,BKHD->BHQK", qs, ks) / jnp.sqrt(D // H)
    if valid_lens is not None:
        mask = jnp.arange(L) < valid_lens[:, None]
        mask = mask[:, None, None, :]  # broadcast across H, Q in [B, H, Q, K]
        scores = jnp.where(mask, scores, -jnp.inf)
    attn = nn.softmax(scores, axis=-1)
    return jnp.einsum("BHQK,BKHD->BQHD", attn, vs).reshape(B, L, D)

def xla_attention(qs, ks, vs, valid_lens):
    if valid_lens is None:
        valid_lens = jnp.repeat(L, B)
    ctx = nn.dot_product_attention(
        qs,
        ks,
        vs,
        query_seq_lengths=jnp.repeat(L, B),
        key_value_seq_lengths=valid_lens,
        implementation="xla",
    )
    return ctx.reshape(B, L, D)

def cudnn_attention(qs, ks, vs, valid_lens):
    if valid_lens is None:
        valid_lens = jnp.repeat(L, B)
    ctx = nn.dot_product_attention(
        qs,
        ks,
        vs,
        query_seq_lengths=jnp.repeat(L, B),
        key_value_seq_lengths=valid_lens,
        implementation="cudnn",
    )
    return ctx.reshape(B, L, D)

van_attn = vanilla_attention(x, x, x, valid_lens)
xla_attn = xla_attention(x, x, x, valid_lens)
cud_attn = cudnn_attention(x, x, x, valid_lens)
print(jnp.allclose(van_attn, xla_attn, rtol=0.01, atol=0.01))  # True
print(jnp.allclose(van_attn, cud_attn, rtol=0.01, atol=0.01))  # True
print(jnp.allclose(cud_attn, xla_attn, rtol=0.01, atol=0.01))  # True

van_attn = vanilla_attention(x, x, x, None)
xla_attn = xla_attention(x, x, x, None)
cud_attn = cudnn_attention(x, x, x, None)
print(jnp.allclose(van_attn, xla_attn, rtol=0.01, atol=0.01))  # True
print(jnp.allclose(van_attn, cud_attn, rtol=0.01, atol=0.01))  # True
print(jnp.allclose(cud_attn, xla_attn, rtol=0.01, atol=0.01))  # True
kaixih commented 1 month ago

@danjenson Can you let us know whether providing only `key_value_seq_lengths` (and not `query_seq_lengths`) is a typical use case for you?

danjenson commented 1 month ago

Constantly. Usually I want an answer to every query, but each query is only allowed to use specific data/keys when answering that question.
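
For illustration, a hypothetical sketch of that pattern (variable names are made up; shapes reuse `B`, `L`, `H`, `D` from the snippets above): every query position should be answered, while each example's keys/values come from a variable-length context padded to a common length, so only the key/value lengths vary.

queries = random.normal(random.key(0), (B, L, H, D // H), dtype=jnp.bfloat16)
context = random.normal(random.key(1), (B, L, H, D // H), dtype=jnp.bfloat16)  # padded to L
ctx_lens = jnp.array([24, 125, 53, 28, 77, 96, 13, 114], jnp.int32)  # valid context length per example

out = nn.dot_product_attention(
    queries, context, context,
    query_seq_lengths=jnp.full((B,), L, dtype=jnp.int32),  # answer every query position
    key_value_seq_lengths=ctx_lens,  # but attend only to each example's valid context
)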