HazyResearch / H3

Language Modeling with the H3 State Space Model
Apache License 2.0

The motivation for not fusing fft(k) into the kernel #19

Closed: Doraemonzzz closed this issue 1 year ago

Doraemonzzz commented 1 year ago

Thanks for your great work. I want to ask why fft(k) is not fused into the kernel. Is it a performance issue? That is, why is it implemented as follows:

def fftconv_fast(u, k, D, dropout_mask):
    """Fuse padding + rfft + pointwise mult + ifft + multiply with D + gelu + dropout
    """
    seqlen = u.shape[-1]
    fft_size = 2 * seqlen
    # rfft of k computed outside the fused kernel; for k of shape (H, L), k_f has shape (H, L + 1)
    k_f = torch.fft.rfft(k, n=fft_size)
    out = fftconv_fwd(u, k_f, D, dropout_mask, fft_size)
    return out

instead of:

def fftconv_fast(u, k, D, dropout_mask):
    """Fuse padding + rfft + pointwise mult + ifft + multiply with D + gelu + dropout
    """
    seqlen = u.shape[-1]
    fft_size = 2 * seqlen
    # Proposed variant: pass k in the time domain and let the kernel compute rfft(k) itself
    out = fftconv_fwd(u, k, D, dropout_mask, fft_size)
    return out
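
For reference, the pipeline named in the docstring can be written out unfused. This is a NumPy sketch, not the H3 CUDA kernel or its own reference code; the name fftconv_sketch, the shapes (u: (B, H, L), k: (H, L), D: (H,), dropout_mask: (B, H)) and the use of the exact erf-based GELU are assumptions. It shows that rfft(k) depends only on the channel h, so it can be computed once and broadcast over the batch.

```python
import math
import numpy as np

_erf = np.vectorize(math.erf)

def gelu(x):
    # Exact (erf-based) GELU; whether the CUDA kernel uses this or the tanh form is an assumption
    return 0.5 * x * (1.0 + _erf(x / math.sqrt(2.0)))

def fftconv_sketch(u, k, D, dropout_mask):
    """Hypothetical unfused version: pad + rfft + pointwise mult + irfft + D skip + gelu + dropout.

    Assumed shapes: u (B, H, L), k (H, L), D (H,), dropout_mask (B, H).
    """
    seqlen = u.shape[-1]
    fft_size = 2 * seqlen  # zero-pad to 2L so the circular convolution equals the linear (causal) one
    k_f = np.fft.rfft(k, n=fft_size)  # (H, L + 1): depends only on h, computed once per channel
    u_f = np.fft.rfft(u, n=fft_size)  # (B, H, L + 1): k_f broadcasts over the batch dimension
    y = np.fft.irfft(u_f * k_f, n=fft_size)[..., :seqlen]  # keep the first L (causal) outputs
    out = y + u * D[:, None]  # skip connection scaled per channel by D
    return gelu(out) * dropout_mask[..., None]  # mask applied per (b, h), broadcast over L
```

Each output y[b, h, t] equals the causal sum over s of k[h, s] * u[b, h, t - s], which can be checked against np.convolve(u[b, h], k[h])[:L].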

DanFu09 commented 1 year ago

The shape of k is (H, L), so computing its FFT outside the kernel adds little IO. The shape of u is (B, H, L), so u dominates the IO cost. We also parallelize over B and H, so fusing fft(k) would naively recompute it for every batch element, which requires extra computation.
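
To make the size argument concrete, here is a back-of-the-envelope count of array elements and rfft(k) invocations. The function name fftconv_io_sketch and the sizes B=8, H=768, L=1024 are hypothetical, and the count assumes the naive fused design launches one unit of work per (b, h) pair that recomputes rfft(k[h]) itself. It does not model the real kernel's memory traffic.

```python
def fftconv_io_sketch(B: int, H: int, L: int) -> dict:
    """Hypothetical element counts for u, k and k_f, plus rfft(k) call counts."""
    fft_size = 2 * L
    n_freq = fft_size // 2 + 1          # rfft output length, i.e. L + 1
    return {
        "u_elems": B * H * L,           # u: (B, H, L), read by every (b, h) unit
        "k_elems": H * L,               # k: (H, L), shared across the batch
        "k_f_real_elems": 2 * H * n_freq,  # k_f: (H, L + 1) complex = 2 reals each
        "rfft_k_calls_outside": H,      # computed once per channel before the kernel
        "rfft_k_calls_fused_naive": B * H,  # recomputed in every (b, h) unit
    }

sizes = fftconv_io_sketch(B=8, H=768, L=1024)  # hypothetical sizes
print(sizes["u_elems"])                   # 6291456
print(sizes["k_elems"])                   # 786432
print(sizes["k_f_real_elems"])            # 1574400
print(sizes["rfft_k_calls_outside"])      # 768
print(sizes["rfft_k_calls_fused_naive"])  # 6144
```

With these sizes u has B = 8 times as many elements as k, and the naive fused design would run rfft(k[h]) 8 times per channel instead of once. Precomputing k_f does add a read of a complex (H, L + 1) array, but that cost does not grow with B.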

Doraemonzzz commented 1 year ago

Makes sense. Thanks.