SHI-Labs / NATTEN

Neighborhood Attention Extension. Bringing attention to a neighborhood near you!
https://shi-labs.com/natten/

Is NATTEN Fused NA v0.17 faster than Flash Attention 2? #143

Open goldhuang opened 2 months ago

goldhuang commented 2 months ago

I don't see any speedup in the backward pass with NATTEN, even when the kernel size I pass to na3d() is only about half the input size. I'm not sure whether this is expected. Could anyone clarify or confirm? Thanks!

alihassanijr commented 2 months ago

Could you please share your exact problem size (tensor shape, kernel size, dilation, causal mask)?

FNA differs from FAv2 in some key ways, notably in memory layout and parallelism pattern, so it is not expected to outperform FAv2 for every problem size. That said, it shouldn't be that much slower either.

In addition, every extra mode (spatial dimension) comes with some additional, unavoidable overhead, at least on older architectures. This means NA1D will always have faster paths and less overhead than NA2D, and NA3D has more overhead than both. More details about this are available in the paper. For some problem sizes, the compute saved by restricting attention to a neighborhood just isn't enough to hide that extra overhead.

goldhuang commented 2 months ago

@alihassanijr Thanks!

I'm applying NATTEN to a self-attention layer with an input shape of around (4, 12, 12, 20, 30, 64), dilation=1, and no causal mask. I don't see a speedup in the backward pass, even with kernel size (7, 7, 11) using na3d or kernel size 1441 using na1d, with natten.use_fused_na(True, True) and natten.use_kv_parallelism_in_fused_na(True) enabled. Is that in line with what you'd expect?
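A rough way to see how much work neighborhood attention skips for this problem is to count attention weights: with dilation=1, no causal mask, and each kernel size no larger than the input, every query attends to exactly prod(kernel_size) keys (NA shifts the window inward at the borders rather than shrinking it), versus N keys in full self-attention. The sketch below assumes the shape above is (batch, X, Y, Z, heads, head_dim), i.e. a heads-last layout with spatial extent (12, 12, 20). That reading is an assumption, chosen because 1441 is roughly half of 12 * 12 * 20 = 2880. The helper attention_fraction is hypothetical and not part of NATTEN.

```python
import math


def attention_fraction(spatial, kernel):
    """Fraction of the full N x N attention matrix that neighborhood attention computes.

    Assumes dilation=1 and no causal mask, so every query attends to exactly
    prod(kernel) keys. Batch and heads are ignored: they scale both NA and
    full attention by the same factor.
    """
    if len(spatial) != len(kernel):
        raise ValueError("spatial and kernel must have the same number of dimensions")
    if any(k > s for k, s in zip(kernel, spatial)):
        raise ValueError("each kernel size must not exceed the input size")
    num_tokens = math.prod(spatial)  # N = X * Y * Z
    window = math.prod(kernel)       # keys attended to per query
    return window / num_tokens


# Assumed reading of (4, 12, 12, 20, 30, 64) as (batch, X, Y, Z, heads, head_dim).
spatial_3d = (12, 12, 20)
print(f"na3d (7, 7, 11): {attention_fraction(spatial_3d, (7, 7, 11)):.3f}")
print(f"na1d (1441,):    {attention_fraction((math.prod(spatial_3d),), (1441,)):.3f}")
```

Under that assumption it prints about 0.187 for na3d and 0.500 for na1d. These are counts of computed attention weights, not runtimes, so they ignore the per-mode overhead mentioned earlier in the thread and only bound the compute that can be saved.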

Could you share some cases where NATTEN speeds up training?

alihassanijr commented 2 months ago

When you say you're not seeing any speedup in the backward pass, do you mean during training? Or are you profiling the backward pass separately?

Also, I'm not sure what your baseline is here. Is it Flash Attention V2, NATTEN itself, or something else? Flash Attention V2 and FMHA are somewhat comparable in terms of functionality, but IMO definitely not in speed, because FAv2 has an extra level of parallelism.

If you don't see a speed difference between using NA3D and NA1D (with flattened window size), then there's very likely something wrong.
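One reading of "flattened window size" is an NA1D kernel covering the same number of keys per query as the 3D window, i.e. the product of the 3D kernel sizes. The sketch below assumes that reading; the helper flattened_window is hypothetical. A window of the same size over a flattened sequence is not the same neighborhood as the 3D one; it only matches the amount of attention work per query.

```python
import math


def flattened_window(kernel):
    """Keys attended to per query by a multi-dimensional NA window (dilation=1)."""
    return math.prod(kernel)


k1 = flattened_window((7, 7, 11))
print(k1)         # 539
print(1441 / k1)  # how many times more keys per query the na1d run with 1441 uses
```

Under this reading, na1d with kernel size 1441 attends to about 2.7 times as many keys per query as na3d with (7, 7, 11), so those two runs are not an equal-work comparison.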

> Could you share some cases where NATTEN speeds up training?

We'll update our arXiv paper in a few weeks with numbers for the backward pass and training. We do see up to about a 40% improvement in training time, but that's just for our classification models, which aren't necessarily bound by attention. At the op level, the backward pass is about 9X, 4X, and 5X as fast as naive NA in 1D, 2D, and 3D respectively.
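Why op-level speedups of several X shrink end to end follows from Amdahl's law: if the accelerated op takes a fraction f of a training step and becomes s times faster, the whole step becomes 1 / ((1 - f) + f / s) times faster. The sketch below uses made-up fractions purely for illustration; they are not measurements from this thread, and end_to_end_speedup is a hypothetical helper.

```python
def end_to_end_speedup(attn_fraction, op_speedup):
    """Amdahl's law: overall speedup when only one part of a step gets faster.

    attn_fraction: share of the baseline step time spent in the accelerated op (0..1).
    op_speedup:    speedup of that op on its own, e.g. 5.0 for "5X".
    """
    if not 0.0 <= attn_fraction <= 1.0:
        raise ValueError("attn_fraction must be in [0, 1]")
    if op_speedup <= 0:
        raise ValueError("op_speedup must be positive")
    return 1.0 / ((1.0 - attn_fraction) + attn_fraction / op_speedup)


# Hypothetical attention shares of step time, with a 5X faster attention op.
for f in (0.2, 0.5, 0.8):
    print(f, round(end_to_end_speedup(f, 5.0), 2))
```

With a 5X faster op, these hypothetical shares give roughly 1.19x, 1.67x, and 2.78x overall, which is why models that aren't bound by attention see much smaller training-time gains than the op-level numbers.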