HazyResearch / H3

Language Modeling with the H3 State Space Model
Apache License 2.0

FFT Conv on Seq > 8192? #14

Open darius-lam opened 1 year ago

darius-lam commented 1 year ago

In the paper, FFTConv is used on sequence lengths > 8192; however, this line in the C++ code has:

```cpp
TORCH_CHECK(fft_size >= 16 && fft_size <= 16384 && (fft_size == 1 << int(log2(float(fft_size)))));
```

Since fft_size = 2 * seq_len, this effectively limits the sequence length to 8192.
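
For reference, here is a minimal PyTorch sketch of the zero-padded FFT convolution I understand this check to be guarding (not the repo's fused CUDA kernel). Padding to twice the sequence length turns the FFT's circular convolution into a linear, causal one, which is why fft_size = 2 * seq_len:

```python
import torch

def fft_conv(u, k):
    # Causal convolution of input u with kernel k via FFT.
    # Zero-padding to 2 * seqlen makes the circular convolution computed
    # by the FFT equivalent to a linear (causal) convolution.
    seqlen = u.shape[-1]
    fft_size = 2 * seqlen
    u_f = torch.fft.rfft(u.float(), n=fft_size)
    k_f = torch.fft.rfft(k.float(), n=fft_size)
    return torch.fft.irfft(u_f * k_f, n=fft_size)[..., :seqlen]
```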

How did you guys overcome this?

DanFu09 commented 1 year ago

Hi, great question! The FFTConv in this repo is a fused CUDA kernel for running general convolutions on sequences of length 8K or shorter; longer sequences do not fit in SRAM, so they require more reads and writes to GPU HBM. In our paper, we go to longer sequences for SSMs in particular by using a state-passing algorithm (Section 4.2 in the paper).
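
To give a rough idea of the approach, here is a minimal sketch of state passing for a toy diagonal SSM. This is an illustrative reimplementation under simplified assumptions (single channel, diagonal real state matrix, made-up parameter names), not the fused kernel or the code we plan to release: the sequence is processed in chunks, each chunk uses an FFT convolution, and the recurrent state is carried across chunk boundaries.

```python
import torch

def ssm_state_passing(u, a, b, c, chunk_len):
    # Toy diagonal SSM:  x_t = a * x_{t-1} + b * u_t,   y_t = sum(c * x_t)
    # Processed chunk by chunk: FFT convolution inside each chunk,
    # plus a correction for the state carried in from earlier chunks.
    # u: (seqlen,) input;  a, b, c: (N,) diagonal SSM parameters (hypothetical toy setup).
    N = a.shape[-1]
    t = torch.arange(chunk_len).unsqueeze(-1)        # (L, 1)
    a_pow = a.unsqueeze(0) ** t                      # a^0 .. a^{L-1}, shape (L, N)
    K = (a_pow * b * c).sum(-1)                      # kernel K_i = sum(c * a^i * b), shape (L,)

    fft_size = 2 * chunk_len                         # same 2x zero-padding as above
    K_f = torch.fft.rfft(K, n=fft_size)

    x = torch.zeros(N, dtype=u.dtype)                # recurrent state passed between chunks
    ys = []
    for u_c in u.split(chunk_len):
        L = u_c.shape[-1]
        # Intra-chunk contribution: causal convolution of this chunk with K.
        u_f = torch.fft.rfft(u_c, n=fft_size)
        y = torch.fft.irfft(u_f * K_f, n=fft_size)[:L]
        # Contribution of the incoming state:  y_t += sum(c * a^(t+1) * x_in)
        y = y + (a_pow[:L] * a * x * c).sum(-1)
        # State update for the next chunk:  x_out = a^L * x_in + sum_i a^(L-1-i) * b * u_i
        x = (a ** L) * x + (a_pow[:L].flip(0) * b * u_c.unsqueeze(-1)).sum(0)
        ys.append(y)
    return torch.cat(ys)

# Example: a 16K-long sequence processed in 8K chunks (toy sizes, random parameters).
u = torch.randn(16384)
a = torch.rand(64) * 0.99                            # stable diagonal
b, c = torch.randn(64), torch.randn(64)
y = ssm_state_passing(u, a, b, c, chunk_len=8192)
```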

We are working on a separate release for that code soon, since we wanted this repo to focus on the H3 architecture/weights. I will post an update to this issue when we release that code!