HazyResearch / safari

Convolutions for Sequence Modeling
Apache License 2.0

Question about long sequence lengths with Hyena #10

Status: Open. VivekPanyam opened this issue 1 year ago.

VivekPanyam commented 1 year ago

Hello!

In the Hyena paper, section 4.4 says "Hyena speedups reach 100x at sequence length 64K."

The figure referenced by that section (figure 4.3) stops at a sequence length short of 10k and the optimized implementation in this repo appears to be limited to an 8k sequence length.

There are a few other references to a 100x speedup over FlashAttention in the paper (and in blog posts). Are these measured speedups or extrapolated from smaller sequence lengths?

I've experimented with the implementation in standalone_hyena.py but it appears to be ~3x slower than FlashAttention at sequence lengths > 32k tokens.

Do you have an estimate for when the fftconv implementation in this repo will support longer sequence lengths (or a pointer to another Hyena codebase if the speedups in the paper were measured)?
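For context on the fftconv mentioned above: Hyena's long convolutions are computed with FFTs, which costs O(L log L) per channel instead of the O(L^2) of a direct convolution. Below is a minimal NumPy sketch of an FFT-based causal convolution. It assumes a filter as long as the input and zero-padding to 2L, mirroring the `rfft(k, n=fft_size)` call visible in the warning output later in this thread. The function name `fft_causal_conv` is hypothetical and is not the repo's API or its fused kernel.

```python
import numpy as np


def fft_causal_conv(u: np.ndarray, k: np.ndarray) -> np.ndarray:
    """Causal convolution y[t] = sum_{s<=t} k[s] * u[t - s] along the last axis.

    u: (..., L) input; k: (..., L) filter, broadcastable against u
    (for example one filter per channel).
    """
    seqlen = u.shape[-1]
    # Zero-padding to 2L makes the FFT's circular convolution equal to a
    # linear one, so the first L outputs are exactly the causal convolution.
    fft_size = 2 * seqlen
    u_f = np.fft.rfft(u, n=fft_size)
    k_f = np.fft.rfft(k, n=fft_size)
    y = np.fft.irfft(u_f * k_f, n=fft_size)
    return y[..., :seqlen]


# Usage: batch of 2, 4 channels, length 1024, one long filter per channel.
rng = np.random.default_rng(0)
u = rng.standard_normal((2, 4, 1024))
k = rng.standard_normal((4, 1024))
y = fft_causal_conv(u, k)
# Check one channel against a direct O(L^2) convolution.
assert np.allclose(y[0, 0], np.convolve(u[0, 0], k[0])[:1024])
```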

Thanks for the great work!

Zymrael commented 1 year ago

The runtime numbers in the paper do not use the optimized fftconv kernel precisely because of the temporary 8k limitation.

> The figure referenced by that section (figure 4.3) stops at a sequence length short of 10k

Figure 4.3 (left) goes up to 100k, I think you're referencing the one on the right (which is only a zoomed-in portion of the left figure)?

> I've experimented with the implementation in standalone_hyena.py but it appears to be ~3x slower than FlashAttention at sequence lengths > 32k tokens.

Can you give more details on your benchmarking workload? Hyena should already be much faster at 32k tokens, I suspect there might be other factors at play.
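The expectation that Hyena wins at long lengths comes from how the sequence-mixing cost scales. Exact attention, including FlashAttention (which reduces memory traffic, not matmul FLOPs), grows as L^2, while FFT convolution grows as L log L. The sketch below is a rough FLOP model of the mixing step only, with assumed constants: about 5·n·log2(n) FLOPs per FFT, three FFTs per long convolution, and two long convolutions for an order-2 operator. It ignores projections, softmax, and gating, so only the trend is meaningful. The function names are hypothetical.

```python
import math


def attention_mixing_flops(batch: int, seqlen: int, dim: int) -> float:
    # QK^T and (softmax)V are each an L x D by D x L-sized matmul,
    # at 2 FLOPs per multiply-add. Softmax and projections are ignored.
    return 2 * 2 * batch * seqlen**2 * dim


def fftconv_mixing_flops(batch: int, seqlen: int, dim: int, num_convs: int = 2) -> float:
    # Assumption: per channel, each long convolution runs 3 FFTs of size
    # n = 2L (input, filter, inverse), each costed at ~5 * n * log2(n).
    n = 2 * seqlen
    return num_convs * batch * dim * 3 * 5 * n * math.log2(n)


def flop_ratio(seqlen: int, dim: int, batch: int = 1) -> float:
    """Modeled attention / fftconv FLOP ratio; grows roughly as L / log L."""
    return attention_mixing_flops(batch, seqlen, dim) / fftconv_mixing_flops(batch, seqlen, dim)


for seqlen in (2048, 8192, 32768, 65536):
    print(f"L={seqlen:>6}: modeled attention/fftconv FLOP ratio ~ {flop_ratio(seqlen, 768):.0f}x")
```

Under these assumed constants the modeled ratio grows from about 11x at 2k to about 257x at 64k. Measured runtime ratios are much smaller because, as Section 4.4 of the paper notes, Hyena's hardware utilization is lower than FlashAttention's.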

VivekPanyam commented 1 year ago

> Figure 4.3 (left) goes up to 100k, I think you're referencing the one on the right (which is only a zoomed-in portion of the left figure)?

Yeah, my bad. I saw 10^5 and read it as 10,000 for some reason. :facepalm:

> Can you give more details on your benchmarking workload?

My benchmarking baseline was a model with:

I compared it to a model that replaces FlashAttention with a HyenaOperator (d_model of 512, l_max set to the sequence length above, and everything else at the default values).

I ran into memory issues using the HyenaOperator with an embedding dim of 1024 so I had to drop to 512.

Even with the much smaller embedding dim, the network with the HyenaOperator was ~3x slower than the one with FlashAttention.

Do you think the memory usage issues are just an artifact of the standalone implementation?

> Hyena should already be much faster at 32k tokens, I suspect there might be other factors at play.

Are you saying that the code in standalone_hyena.py should be much faster than FlashAttention at that sequence length?

The benchmark above was just a quick test to get some rough numbers, so there were some other differences between the baseline and the Hyena test:

Do you have any suggestions for things to try?

Thanks!

Zymrael commented 1 year ago

Thanks for all the info! I pushed a small benchmarking script here for both forward and backward passes. What numbers do you see when you run it? On my end (on a single A100) I see Hyena as 5x/6x faster at batch size 1 and seqlen 32k. If you use the same script and benchmark at batch size 64, you should get the FlashAttention runtime numbers of the original paper.
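The numbers in this thread come from timing forward and backward passes at a fixed batch size and sequence length. Below is a generic timing helper in that spirit. It is a sketch, not the linked script. On a GPU, `sync` would be `torch.cuda.synchronize`: CUDA kernels launch asynchronously, so without a sync the timer mostly measures launch overhead. The helper name `time_fn` is hypothetical.

```python
import statistics
import time
from typing import Callable, Optional


def time_fn(fn: Callable[[], object], repeats: int = 10, warmup: int = 3,
            sync: Optional[Callable[[], None]] = None) -> float:
    """Return the mean wall-clock seconds of fn() over `repeats` timed calls."""
    # Warm-up calls absorb one-off costs (allocator growth, FFT plan creation, autotuning).
    for _ in range(warmup):
        fn()
    if sync is not None:
        sync()  # make sure warm-up work has finished before timing starts
    times = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        if sync is not None:
            sync()  # wait for queued device work before stopping the clock
        times.append(time.perf_counter() - start)
    return statistics.mean(times)
```

A forward-only measurement would time a call to the layer. A forward-plus-backward measurement would time the layer call followed by a `.backward()` on a scalar reduction of its output.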

Regarding memory: yes, at the moment, without the custom kernel, Hyena's memory scaling is slightly worse than FlashAttention's, though both are linear. Doing a bit more recomputation on the backward pass helps; we're working on these optimizations.
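To see how a linear memory cost still adds up, consider a single frequency-domain buffer. An rfft of a real signal of length n returns n // 2 + 1 complex values, so padding to n = 2L gives roughly L complex values per channel. The helper below is a back-of-envelope sketch. The 2L padding follows the `rfft(k, n=fft_size)` call in this thread's warning output. How many such buffers the standalone implementation keeps alive for the backward pass is not stated here and is left as an assumption. The name `rfft_buffer_bytes` is hypothetical.

```python
def rfft_buffer_bytes(batch: int, channels: int, seqlen: int,
                      bytes_per_complex: int = 8) -> int:
    """Size of one rfft output buffer for a (batch, channels, seqlen) real input.

    bytes_per_complex: 8 for complex64, 4 for complex32 ("ComplexHalf").
    Assumption: the input is zero-padded to n = 2 * seqlen, as in FFT convolution.
    """
    n = 2 * seqlen
    return batch * channels * (n // 2 + 1) * bytes_per_complex


# Batch 16, d_model 512, seq len 32k, complex64: about 2 GiB per buffer.
gib = rfft_buffer_bytes(16, 512, 32768) / 2**30
print(f"{gib:.2f} GiB")
```

With several buffers of this size per convolution saved for autograd, memory grows linearly in batch size and L but with a large constant. Recomputing these buffers on the backward pass, rather than storing them, trades extra compute for that memory.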

VivekPanyam commented 1 year ago

I was just about to run a benchmark I wrote when you posted your comment :)

I modified your script to import from standalone_hyena and I can roughly reproduce your results on an A100. FlashAttention (fwd + bwd) takes ~3.8x longer than Hyena (fwd + bwd) at a seq len of 32k and batch size of 1.

Full output:

```
2048
/home/ubuntu/standalone_hyena.py:17: UserWarning: ComplexHalf support is experimental and many operators don't support it yet. (Triggered internally at ../aten/src/ATen/EmptyTensor.cpp:31.)
  k_f = torch.fft.rfft(k, n=fft_size) / fft_size
4096
8192
16384
32768
65536
131072
---
{2048: 0.0002955092999854969, 4096: 0.0005243146999873716, 8192: 0.0016268614999717101, 16384: 0.0051274498999646315, 32768: 0.017562666200001333, 65536: 0.06848136959997646, 131072: 0.2707951788000173}
{2048: 0.000688433300001634, 4096: 0.001220289100001537, 8192: 0.003820151900026758, 16384: 0.011008315900016895, 32768: 0.041841238200004224, 65536: 0.16533355219999066, 131072: 0.6602845148999676}
---
{2048: 0.0008685043000241421, 4096: 0.0009082306999971479, 8192: 0.0017019433000314166, 16384: 0.002847423499997603, 32768: 0.00549896880002052, 65536: 0.01111297390002619, 131072: 0.02508695080000507}
{2048: 0.0017592959000012343, 4096: 0.0017505167999843252, 8192: 0.003214542400019127, 16384: 0.005148207100000945, 32768: 0.009867695800039655, 65536: 0.02085177230001136, 131072: 0.044432145000018866}
```

That said, the speed difference at 64k is still "only" 7.3x vs the 100x from the paper. Any thoughts on what could be causing that?
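The 3.8x and 7.3x figures can be reproduced from the output above if the first pair of dicts holds FlashAttention forward and backward timings, the second pair holds Hyena forward and backward timings, and fwd+bwd is taken as their sum. That labeling is inferred from the numbers, not printed by the script, so treat it as an assumption.

```python
# Timings (seconds) copied from the batch-size-1 output above, 32k and 64k only.
# Assumed labeling: first two dicts = FlashAttention fwd/bwd, last two = Hyena fwd/bwd.
flash_fwd = {32768: 0.017562666200001333, 65536: 0.06848136959997646}
flash_bwd = {32768: 0.041841238200004224, 65536: 0.16533355219999066}
hyena_fwd = {32768: 0.00549896880002052, 65536: 0.01111297390002619}
hyena_bwd = {32768: 0.009867695800039655, 65536: 0.02085177230001136}


def fwd_bwd_speedup(seqlen: int) -> float:
    """FlashAttention (fwd + bwd) time divided by Hyena (fwd + bwd) time."""
    flash = flash_fwd[seqlen] + flash_bwd[seqlen]
    hyena = hyena_fwd[seqlen] + hyena_bwd[seqlen]
    return flash / hyena


for seqlen in (32768, 65536):
    print(seqlen, round(fwd_bwd_speedup(seqlen), 2))
```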

Thanks again!

Zymrael commented 1 year ago

Awesome! It's all a game of batch sizes, try running at batch sizes 16, 32 and 64 and you should see the speedup get larger.

VivekPanyam commented 1 year ago

Hmm that doesn't seem to work.

At a batch size of 16 and sequence length of 32k, FlashAttention takes 3.48 times longer than Hyena (see details below).

Thoughts:

  1. Currently, a single HyenaOperator with a batch size of 16 and seq len of 32k uses almost all GPU memory (more than 35 GB on a 40 GB A100) when running the benchmark. Did you have to use checkpointing to test larger configurations?

  2. At that sequence length, Hyena with a batch size of 16 is ~18.7x slower than Hyena with a batch size of one. This seems to imply that batching is worse than just serially processing each sequence.

As far as I can tell, there isn't really a good reason to use a large batch size here vs a batch size of 1 + gradient accumulation.
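The 3.48x figure and the 18.7x slowdown in point 2 can be checked from the outputs in the Details section below. This assumes each output lists the FlashAttention forward and backward dicts first and the Hyena forward and backward dicts second, with fwd+bwd as their sum. That labeling is inferred, not printed. Dividing the batch-16 slowdown by 16 gives the per-sequence cost of batching.

```python
# Seq len 32768 timings (seconds) copied from the Details outputs.
# Assumed labeling: FlashAttention fwd/bwd dicts first, Hyena fwd/bwd second.
hyena_bs1 = 0.005500168399998983 + 0.009863483899994207
hyena_bs16 = 0.069258173299977 + 0.21756787399999666
flash_bs16 = 0.2817839461999938 + 0.7167293173999951

batch_slowdown = hyena_bs16 / hyena_bs1      # total time for 16 sequences vs 1
per_sequence_cost = batch_slowdown / 16      # > 1 means batching is slower per sequence
flash_vs_hyena_bs16 = flash_bs16 / hyena_bs16

print(round(batch_slowdown, 2), round(per_sequence_cost, 3), round(flash_vs_hyena_bs16, 2))
```

A per-sequence cost of about 1.17 means that, in this setup, a batch of 16 took roughly 17% longer per sequence than running sequences one at a time.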

Any ideas on what's going on?


Details:

GPU: A100 40GB SXM4

Versions:

Torch 2.0.0
flash-attn 0.2.8
einops 0.6.0

Hyena implementation permalink

Benchmark permalink

Batch size of 1

Benchmark script with batch size of 1:

```
2048
/home/ubuntu/standalone_hyena.py:17: UserWarning: ComplexHalf support is experimental and many operators don't support it yet. (Triggered internally at ../aten/src/ATen/EmptyTensor.cpp:31.)
  k_f = torch.fft.rfft(k, n=fft_size) / fft_size
4096
8192
16384
32768
65536
131072
---
{2048: 0.0003016873999968084, 4096: 0.0004113602999950672, 8192: 0.0012788771000032284, 16384: 0.0046683007999945405, 32768: 0.017532283400009875, 65536: 0.06853519320000032, 131072: 0.2708130353000115}
{2048: 0.0005378262000022005, 4096: 0.0009546344999989742, 8192: 0.0030730292000043847, 16384: 0.010984528699987095, 32768: 0.041858487700005755, 65536: 0.16538351169999715, 131072: 0.6592029120000007}
---
{2048: 0.0008691948000205229, 4096: 0.0009022319999985485, 8192: 0.0014434583000138446, 16384: 0.002845392300014282, 32768: 0.005500168399998983, 65536: 0.011115612699995836, 131072: 0.02511233400000492}
{2048: 0.001425099199991564, 4096: 0.0014490745000102835, 8192: 0.0027972210999905656, 16384: 0.005150951700011319, 32768: 0.009863483899994207, 65536: 0.0208421756000007, 131072: 0.044432989299980366}
```

Batch size of 16

Benchmark script with batch size of 16 (limited to seq len of 32k because Hyena runs out of memory at larger seq lengths):

```
2048
/home/ubuntu/standalone_hyena.py:17: UserWarning: ComplexHalf support is experimental and many operators don't support it yet. (Triggered internally at ../aten/src/ATen/EmptyTensor.cpp:31.)
  k_f = torch.fft.rfft(k, n=fft_size) / fft_size
4096
8192
16384
32768
---
{2048: 0.0020986096000342514, 4096: 0.0062191390999942085, 8192: 0.02071911389998604, 16384: 0.0747385351000048, 32768: 0.2817839461999938}
{2048: 0.004406064900013007, 4096: 0.014036858499957816, 8192: 0.049613316399972976, 16384: 0.1854088593999677, 32768: 0.7167293173999951}
---
{2048: 0.004067870600010792, 4096: 0.008555072500030292, 8192: 0.017019115599987346, 16384: 0.0345663336999678, 32768: 0.069258173299977}
{2048: 0.009594092299994372, 4096: 0.019340128099975117, 8192: 0.04256208110000444, 16384: 0.09568457360001048, 32768: 0.21756787399999666}
```

VivekPanyam commented 1 year ago

Something else I noticed is that the paper says "Hyena speedups reach 100x at sequence length 64K" and references Figure 4.3, but if you look at the LaTeX for Figure 4.3, it's actually only an 11.4x difference.

I know the paper is still a draft so is the figure (or text) outdated? Or are we interpreting the meaning of "speedup" differently?

Thanks!


Figure 4.3 from the paper:


```latex
\addplot [line width=1pt, indianred]
table {%
1024 0.9
2048 1.16
4096 1.47
8192 1.5
16384 2.84
32768 5.41
65536 11.32
};
\addplot [line width=1pt, cornflowerblue]
table {%
1024 0.4
2048 1.25
4096 2.16
8192 6.17
16384 21.74
32768 90.71
};
\addplot [line width=1pt, lightseagreen, dashed]
table {%
1024 0.29
2048 0.3
4096 0.63
8192 2.1
16384 8.33
32768 32.85
65536 129.07
};
```

(FlashAttention at 64k) / (Hyena at 64k) = 129.07/11.32 = ~11.4

Section 4.4 says (emphasis mine):

> We benchmark runtime of an order 2 Hyena operator compared to attention and FlashAttention layers (Dao et al., 2022b). Hyena uses a fused CUDA kernel to perform FFTConv (Dao et al., 2022c). We set batch size to 64 and measure runtime (in milliseconds). Results are provided in Figure 4.3. Hyena speedups reach 100× at sequence length 64K. Crossover points for Hyena and attention is at length 2048, and for Hyena and FlashAttention is between 4096 and 8196. Despite the absolute reduction in FLOPs, speedups are achieved only on longer sequences when the gap grows sufficiently large. This occurs because hardware utilization of Hyena is lower than FlashAttention. We expect the gap between theoretical maximum speedup to shrink with improved implementations of FFTConv and specialized hardware.
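The Figure 4.3 data can be checked against the quoted text directly. The sketch below transcribes the three pgfplots series. The color-to-method mapping is inferred from the 129.07/11.32 calculation above and from standard attention being the series with no 64k point, so treat it as an assumption.

```python
# Runtimes (ms) transcribed from the Figure 4.3 pgfplots source.
# Assumed mapping: indianred = Hyena, cornflowerblue = attention,
# lightseagreen (dashed) = FlashAttention.
hyena = {1024: 0.9, 2048: 1.16, 4096: 1.47, 8192: 1.5, 16384: 2.84, 32768: 5.41, 65536: 11.32}
attention = {1024: 0.4, 2048: 1.25, 4096: 2.16, 8192: 6.17, 16384: 21.74, 32768: 90.71}
flash = {1024: 0.29, 2048: 0.3, 4096: 0.63, 8192: 2.1, 16384: 8.33, 32768: 32.85, 65536: 129.07}


def first_length_where_faster(baseline: dict) -> int:
    """Smallest plotted sequence length at which Hyena beats `baseline`."""
    return min(n for n in baseline if hyena[n] < baseline[n])


speedup_vs_flash = {n: flash[n] / hyena[n] for n in flash}
speedup_vs_attention = {n: attention[n] / hyena[n] for n in attention}

print(first_length_where_faster(attention), first_length_where_faster(flash))
print(round(speedup_vs_flash[65536], 1), round(max(speedup_vs_attention.values()), 1))
```

Under that mapping the crossovers match the text: 2048 against attention, and between 4096 and 8192 against FlashAttention. The largest plotted ratios, however, are about 11.4x against FlashAttention at 64k and about 16.8x against attention at 32k, neither close to 100x.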

Zymrael commented 1 year ago

Interesting finds, a few things here:

```
FlashAttention: {(1, 32768): 0.011, (4, 32768): 0.048, (8, 32768): 0.098, (16, 32768): 0.1977, (32, 32768): 0.4041, (64, 32768): 0.8306}
Hyena: {(1, 32768): 0.0012, (4, 32768): 0.0024, (8, 32768): 0.0043, (16, 32768): 0.0082, (32, 32768): 0.0160, (64, 32768): 0.0346}
```

If you plan to run models at DIM=768, SEQ_LEN=64k:

```
FlashAttention: {(1, 65536): 0.0745, (4, 65536): 0.2999, (8, 65536): 0.6023, (16, 65536): 1.2074}
Hyena: {(1, 65536): 0.0118, (4, 65536): 0.0406, (8, 65536): 0.0807, (16, 65536): 0.1806}
```

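Dividing the two sets of numbers shows how the FlashAttention/Hyena ratio depends on batch size and width. The surviving text does not state the dimension used for the first set or whether these are forward-only or forward-plus-backward times, so only the trend within each set is comparable. The helper `speedups` is hypothetical.

```python
# Seconds per batch size, copied from the comment above.
flash_32k = {1: 0.011, 4: 0.048, 8: 0.098, 16: 0.1977, 32: 0.4041, 64: 0.8306}
hyena_32k = {1: 0.0012, 4: 0.0024, 8: 0.0043, 16: 0.0082, 32: 0.0160, 64: 0.0346}
flash_64k_d768 = {1: 0.0745, 4: 0.2999, 8: 0.6023, 16: 1.2074}
hyena_64k_d768 = {1: 0.0118, 4: 0.0406, 8: 0.0807, 16: 0.1806}


def speedups(flash: dict, hyena: dict) -> dict:
    """FlashAttention time / Hyena time per batch size, rounded to 0.1."""
    return {b: round(flash[b] / hyena[b], 1) for b in flash}


print(speedups(flash_32k, hyena_32k))
print(speedups(flash_64k_d768, hyena_64k_d768))
```

At 32k the ratio climbs from about 9x at batch size 1 to about 24x at batch size 64. At DIM=768 and 64k it stays between roughly 6x and 7.5x for batch sizes 1 to 16.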
VivekPanyam commented 1 year ago

Thanks! That makes sense. I think it would be super useful to have a sweep over (dim, batch_size, seq_len) comparing FlashAttention and Hyena runtimes (for both forward and backward passes), but I don't think I'll be able to get to that anytime soon. Do you think you'll have time to run that sweep? It might even be worth committing to the repo or adding to a wiki so there's a place for people to quickly see potential speedups.
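A sweep like the one suggested here could be organized as follows. `sweep` and `speedup_table` are hypothetical helpers. `time_config` stands for whatever callable builds a layer at a given width and times its forward and backward passes at the given batch size and sequence length. Configurations that run out of memory are recorded as NaN so the grid still completes. On a GPU, the exception to catch would be the framework's out-of-memory error (for example `torch.cuda.OutOfMemoryError` in recent PyTorch versions) rather than Python's `MemoryError`.

```python
import itertools
import math
from typing import Callable, Dict, Iterable, Tuple, Type

Config = Tuple[int, int, int]  # (dim, batch_size, seq_len)


def sweep(time_config: Callable[[int, int, int], float],
          dims: Iterable[int], batch_sizes: Iterable[int], seq_lens: Iterable[int],
          oom_errors: Tuple[Type[BaseException], ...] = (MemoryError,)) -> Dict[Config, float]:
    """Time every (dim, batch_size, seq_len) combination; OOM configs become NaN."""
    results: Dict[Config, float] = {}
    for dim, batch, seqlen in itertools.product(dims, batch_sizes, seq_lens):
        try:
            results[(dim, batch, seqlen)] = time_config(dim, batch, seqlen)
        except oom_errors:
            results[(dim, batch, seqlen)] = math.nan
    return results


def speedup_table(baseline: Dict[Config, float], candidate: Dict[Config, float]) -> Dict[Config, float]:
    """Baseline time / candidate time for every config present in both sweeps."""
    return {cfg: baseline[cfg] / candidate[cfg] for cfg in baseline if cfg in candidate}
```

Running `sweep` once per layer type and passing both results to `speedup_table` gives one speedup per configuration, which could be published as a table.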

Thanks again!