Closed VachanVY closed 3 months ago
Oh! I see the problem. This confused me for a bit. Flash attention accepts inputs in NTHD order (aka [n, l, h, d], as I wrote in the readme). You're comparing it against an NHTD vanilla attention, so you need to add some transposes (or remove some) in your model.
After transposing:
```python
diff = abs(flash_mha(
    q.transpose((0, 2, 1, 3)),  # (B, T, h, dim)
    k.transpose((0, 2, 1, 3)),  # (B, T, h, dim)
    v.transpose((0, 2, 1, 3)),  # (B, T, h, dim)
    is_causal=True, softmax_scale=dim**-0.5
).transpose((0, 2, 1, 3))       # (B, h, T, dim) <= (B, T, h, dim)
 - vanilla_att(q, k, v))        # (B, h, T, dim)
```

```python
>>> jnp.max(diff), jnp.min(diff), jnp.mean(diff)
(Array(0.551, dtype=float16),
 Array(0., dtype=float16),
 Array(0.02173, dtype=float16))
```
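To make the layout point concrete, here's a self-contained sketch. `nthd_att` is a stand-in for `flash_mha`'s NTHD interface (the real kernel isn't needed to show the transposes); `vanilla_att` is a minimal NHTD reference:

```python
import jax
import jax.numpy as jnp

def vanilla_att(q, k, v):
    # reference attention in NHTD layout: (B, h, T, d)
    T, d = q.shape[-2], q.shape[-1]
    wei = (q @ k.swapaxes(-1, -2)) * d**-0.5          # (B, h, T, T)
    wei += jnp.triu(jnp.full((T, T), -jnp.inf), k=1)  # causal mask
    return jax.nn.softmax(wei, axis=-1) @ v           # (B, h, T, d)

def nthd_att(q, k, v):
    # stand-in for flash_mha's NTHD interface: (B, T, h, d) in and out
    out = vanilla_att(q.transpose((0, 2, 1, 3)),
                      k.transpose((0, 2, 1, 3)),
                      v.transpose((0, 2, 1, 3)))
    return out.transpose((0, 2, 1, 3))

B, h, T, d = 2, 4, 8, 16
keys = jax.random.split(jax.random.PRNGKey(0), 3)
q, k, v = (jax.random.normal(kk, (B, h, T, d)) for kk in keys)

# transpose NHTD -> NTHD, call, transpose the result back: matches the reference
diff = jnp.abs(nthd_att(q.transpose((0, 2, 1, 3)),
                        k.transpose((0, 2, 1, 3)),
                        v.transpose((0, 2, 1, 3))).transpose((0, 2, 1, 3))
               - vanilla_att(q, k, v))
```

With the transposes in place the two agree (up to float precision); leave them out and you're feeding the kernel a head axis where it expects the sequence axis.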
Oh, you also have a bug here:

```python
att_wei + jnp.triu(jnp.full(shape=(1, 1, T, T), fill_value=-jnp.inf), k=1)[:, :, :T, :T]  # (B, h, T, T)
```

this should be

```python
att_wei += jnp.triu(jnp.full(shape=(1, 1, T, T), fill_value=-jnp.inf), k=1)[:, :, :T, :T]  # (B, h, T, T)
```

The `+` version computes the masked weights and then throws them away, so `att_wei` is left unmasked.
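A quick way to see it: JAX arrays are immutable, so `att_wei + mask` builds a new masked array that is immediately discarded unless you assign it back. A minimal sketch (names are illustrative):

```python
import jax.numpy as jnp

T = 4
att_wei = jnp.zeros((1, 1, T, T))
mask = jnp.triu(jnp.full(shape=(1, 1, T, T), fill_value=-jnp.inf), k=1)

att_wei + mask  # new array built, then discarded: att_wei is still all zeros
assert not jnp.isneginf(att_wei).any()

att_wei += mask  # rebinds att_wei to the masked result
assert jnp.isneginf(att_wei[0, 0, 0, 1])  # future positions are now -inf
```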
Your test passes on my desktop with this fix.
whoops! ya sry, thanks.
And why is flash_mha not supported for float32?
The authors of the paper didn't implement it for float32, probably mainly because it would need 2x more SM (shared) memory and would also be slower (that memory is in pretty short supply, afaik). You could probably use it in a float32 model by casting to bf16 and back?
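A sketch of what that cast-and-cast-back wrapper could look like. `attn_via_bf16` and the plain-JAX `sdpa` stand-in are illustrative names, not part of the repo; in practice `attn_fn` would be `flash_mha`:

```python
import jax
import jax.numpy as jnp

def sdpa(q, k, v):
    # plain softmax attention, used here as a stand-in for flash_mha
    d = q.shape[-1]
    return jax.nn.softmax((q @ k.swapaxes(-1, -2)) * d**-0.5, axis=-1) @ v

def attn_via_bf16(q, k, v, attn_fn=sdpa):
    # hypothetical wrapper: cast float32 activations down for a
    # bf16-only kernel, then cast the output back up
    out = attn_fn(q.astype(jnp.bfloat16),
                  k.astype(jnp.bfloat16),
                  v.astype(jnp.bfloat16))
    return out.astype(jnp.float32)

keys = jax.random.split(jax.random.PRNGKey(0), 3)
q, k, v = (jax.random.normal(kk, (2, 4, 8, 16)) for kk in keys)

out = attn_via_bf16(q, k, v)
err = jnp.max(jnp.abs(out - sdpa(q, k, v)))  # bf16 round-trip error
```

Whether the round-trip error is acceptable depends on the model, but it's in the same ballpark as ordinary mixed-precision training.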
Thanks for your interest in my repo btw!
And should I learn C++ CUDA, or can I implement Flash Attention from scratch using the Python version of CUDA? And what's the best learning resource (for the Python or the C++ version)? Thanks!
(BTW are you a student?)
Join the CUDA MODE Discord and watch the recorded lectures! I'm not sure I can say what you should learn, but I think C++ CUDA is quite interesting and a good way to understand the hardware.
I'm not a student. I'm, errr, a researcher (neet, lol).
Hi, I observed that my language model wasn't converging (it did converge with `vanilla_att`), so I ran the below