nshepperd / flash_attn_jax

JAX bindings for Flash Attention v2
BSD 3-Clause "New" or "Revised" License

High difference between values from vanilla attention and flash_mha #3

Closed: VachanVY closed this issue 3 months ago

VachanVY commented 3 months ago

Hi, I noticed that my language model wasn't converging (it did converge with vanilla_att), so I ran the code below:

import jax
import jax.numpy as jnp
import jax.random as jrand
from flash_attn_jax import flash_mha

B = 128; h = 6; T = 256; dim = 288//6; shape=(B, h, T, dim)
q = jrand.uniform(jrand.PRNGKey(2323), shape=shape, dtype=jnp.float16) # (B, h, T, dim)
k = jrand.uniform(jrand.PRNGKey(323232), shape=shape, dtype=jnp.float16)
v = jrand.uniform(jrand.PRNGKey(2323221111), shape=shape, dtype=jnp.float16)

def vanilla_att(q, k, v, is_causal=True):
    att_wei = (q @ jnp.matrix_transpose(k))/(dim**0.5) # (B, h, T, T) <= (B, h, T, dim) @ (B, h, T, dim).transpose(2, 3)
    # causal mask
    if is_causal:
        att_wei + jnp.triu(jnp.full(shape=(1, 1, T, T), fill_value=-jnp.inf), k=1)[:, :, :T, :T] # (B, h, T, T)
    att_wei = jax.nn.softmax(att_wei, axis=-1) # (B, h, T, T)
    # apply attention weights to v
    att_out = att_wei @ v # (B, h, T, T) @ (B, h, T, dv) => (B, h, T, dv)
    return att_out

diff = abs(flash_mha(q, k, v, is_causal=True, softmax_scale=dim**-0.5)-vanilla_att(q, k, v))
>>> jnp.max(diff), jnp.min(diff), jnp.mean(diff)
(Array(0.5703, dtype=float16),
 Array(0., dtype=float16),
 Array(0.146, dtype=float16))
nshepperd commented 3 months ago

Oh! I see the problem. This confused me for a bit. Flash attention accepts inputs in NTHD order (aka [n, l, h, d], as I wrote in the readme). You're comparing it against an NHTD vanilla attention. You need to add some transposes (or remove some) in your model.
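The layout difference can be checked without a GPU. Below is a minimal sketch, assuming equal query and key lengths, of a plain-JAX attention that takes the same [n, l, h, d] layout and the same `is_causal` / `softmax_scale` keywords as the `flash_mha` calls in this thread. `reference_mha_nthd` is a hypothetical name, not part of flash_attn_jax, and its default scale of 1/sqrt(d) is an assumption of this sketch.

```python
import jax
import jax.numpy as jnp


def reference_mha_nthd(q, k, v, is_causal=False, softmax_scale=None):
    """Plain-JAX attention on inputs laid out as [n, l, h, d] (NTHD).

    Hypothetical reference helper (not part of flash_attn_jax). Assumes the
    query and key sequence lengths are equal.
    """
    d = q.shape[-1]
    scale = d ** -0.5 if softmax_scale is None else softmax_scale
    # scores[n, h, i, j] = <q[n, i, h, :], k[n, j, h, :]> * scale
    scores = jnp.einsum("nihd,njhd->nhij", q, k) * scale
    if is_causal:
        t = q.shape[1]
        # Query i may only attend to keys j <= i.
        keep = jnp.tril(jnp.ones((t, t), dtype=bool))
        scores = jnp.where(keep, scores, -jnp.inf)
    weights = jax.nn.softmax(scores, axis=-1)
    # Weighted sum of values, returned in the same [n, l, h, d] layout.
    return jnp.einsum("nhij,njhd->nihd", weights, v)
```

An NHTD model can call this (or `flash_mha`) by applying `x.transpose((0, 2, 1, 3))` to the inputs and the same permutation to the output; it swaps axes 1 and 2, so it is its own inverse.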

VachanVY commented 3 months ago

After transposing:

out = flash_mha(
    q.transpose((0,2,1,3)), # (B, T, h, dim)
    k.transpose((0,2,1,3)), # (B, T, h, dim)
    v.transpose((0,2,1,3)), # (B, T, h, dim)
    is_causal=True, softmax_scale=dim**-0.5
)  # (B, T, h, dim)
diff = abs(out.transpose((0,2,1,3)) - vanilla_att(q, k, v))  # (B, h, T, dim) <= (B, T, h, dim)

>>> jnp.max(diff), jnp.min(diff), jnp.mean(diff)
(Array(0.551, dtype=float16),
 Array(0., dtype=float16),
 Array(0.02173, dtype=float16))
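The mean error dropped from 0.146 to 0.0217, but the max barely moved. One way to localize a mismatch like this is to reduce the error per query position. If only one side applies the causal mask, the error at the last position is zero (the last query may attend to every key either way) and tends to be largest near the start of the sequence. A minimal sketch, with `max_error_per_query` as a hypothetical debugging helper:

```python
import jax.numpy as jnp


def max_error_per_query(out_a, out_b):
    """Max |out_a - out_b| for each query position of two (B, h, T, d) outputs.

    Hypothetical debugging helper. Both inputs are cast to float32 before
    subtracting, so the difference itself adds no float16 rounding.
    """
    diff = jnp.abs(out_a.astype(jnp.float32) - out_b.astype(jnp.float32))
    return diff.max(axis=(0, 1, 3))  # shape (T,)
```

It would be called with the transposed `flash_mha` output and `vanilla_att(q, k, v)`, both in (B, h, T, dim) order.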
nshepperd commented 3 months ago

Oh, you also have a bug here:

 att_wei + jnp.triu(jnp.full(shape=(1, 1, T, T), fill_value=-jnp.inf), k=1)[:, :, :T, :T] # (B, h, T, T)

this should be

att_wei += jnp.triu(jnp.full(shape=(1, 1, T, T), fill_value=-jnp.inf), k=1)[:, :, :T, :T] # (B, h, T, T)

Your test passes on my desktop with this fix.
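JAX arrays are immutable, so a bare `att_wei + mask` statement builds a new array and discards it; the masked scores only take effect once they are assigned back. Below is a sketch of the corrected reference under a hypothetical name, `vanilla_att_fixed`. It also reads `T` and `dim` from the input shape instead of globals and uses `jnp.where` in place of adding a `-inf` matrix; both forms mask the same entries.

```python
import jax
import jax.numpy as jnp


def vanilla_att_fixed(q, k, v, is_causal=True):
    """NHTD attention: (B, h, T, d) in and out, with the causal mask applied.

    Hypothetical corrected version of vanilla_att from this thread.
    """
    T, d = q.shape[-2], q.shape[-1]
    att_wei = (q @ jnp.swapaxes(k, -1, -2)) * d ** -0.5  # (B, h, T, T)
    if is_causal:
        # Keep j <= i; future positions become -inf before the softmax.
        keep = jnp.tril(jnp.ones((T, T), dtype=bool))
        att_wei = jnp.where(keep, att_wei, -jnp.inf)  # result is assigned back
    att_wei = jax.nn.softmax(att_wei, axis=-1)  # (B, h, T, T)
    return att_wei @ v  # (B, h, T, d)
```

A quick sanity check for any causal implementation: the first query attends only to the first key, so `out[..., 0, :]` should equal `v[..., 0, :]`, and changing a later value vector should leave earlier outputs untouched.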

VachanVY commented 3 months ago

whoops! ya sry, thanks.

And why is flash_mha not supported for float32?

nshepperd commented 3 months ago

The authors of the paper didn't implement it for float32, probably mainly because it would need 2x more shared memory per SM and would also be slower (shared memory is in pretty short supply, afaik). You could probably use it in a float32 model by casting to bf16 and back?
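A minimal sketch of the cast-to-bf16-and-back idea, assuming the attention kernel is passed in as a callable (for example `flash_mha`). `attention_fp32_via_bf16` is a hypothetical wrapper, not part of flash_attn_jax. bfloat16 keeps float32's exponent range but has only about 8 bits of mantissa, so expect errors on the order of 1e-2 relative to a full float32 attention.

```python
import jax.numpy as jnp


def attention_fp32_via_bf16(attn_fn, q, k, v, **kwargs):
    """Run a half-precision-only attention kernel inside a float32 model.

    Hypothetical wrapper: casts q, k, v to bfloat16, calls attn_fn, and casts
    the result back to q's original dtype. astype is differentiable, so
    gradients flow back to the float32 inputs under jax.grad.
    """
    out_dtype = q.dtype
    q, k, v = (x.astype(jnp.bfloat16) for x in (q, k, v))
    out = attn_fn(q, k, v, **kwargs)
    return out.astype(out_dtype)
```

With flash_attn_jax this would be called as `attention_fp32_via_bf16(flash_mha, q, k, v, is_causal=True, softmax_scale=dim**-0.5)`, with q, k, v in the [n, l, h, d] layout.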

nshepperd commented 3 months ago

Thanks for your interest in my repo btw!

VachanVY commented 3 months ago

Also, should I learn C++ CUDA, or can I implement Flash Attention from scratch using a Python version of CUDA? And what's the best resource for learning it (Python or C++)? Thanks!

(BTW are you a student?)

nshepperd commented 3 months ago

Join the CUDA MODE Discord and watch the recorded lectures! I'm not sure I can say what you should learn, but I think C++ CUDA is quite interesting and a good way to understand the hardware.

I'm not a student. I'm, errr, a researcher (neet, lol).