Closed: Doraemonzzz closed this issue 1 year ago
The shape of k is (H, L), so the IO cost of computing fft(k) outside the kernel isn't too bad. The shape of u is (B, H, L), so u dominates the IO cost. We also parallelize over B and H, so fusing fft(k) into the kernel would naively recompute it for every batch element, i.e., extra computation.
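For intuition, here is a minimal pure-PyTorch sketch of the unfused convolution. The function name fftconv_ref, the D shape, and the omission of gelu/dropout are assumptions for illustration; this is not the actual fused CUDA kernel.

```python
import torch

def fftconv_ref(u, k, D):
    """Unfused reference (illustrative): shows why k_f is cheap to precompute.

    u: (B, H, L) inputs, k: (H, L) filters, D: (H,) skip weights.
    """
    seqlen = u.shape[-1]
    fft_size = 2 * seqlen                  # pad to 2L so the circular FFT conv becomes a linear conv
    k_f = torch.fft.rfft(k, n=fft_size)    # H length-2L FFTs, done once and shared across the batch
    u_f = torch.fft.rfft(u, n=fft_size)    # B * H FFTs: this (plus reading/writing u) dominates the cost
    y = torch.fft.irfft(u_f * k_f, n=fft_size)[..., :seqlen]
    return y + u * D.unsqueeze(-1)         # "multiply with D" skip term (gelu/dropout omitted here)
```

Since the kernel parallelizes over B and H, fusing rfft(k) into it would naively turn the H filter FFTs into B*H of them, whereas passing in the precomputed k_f adds only a small amount of IO compared with reading u.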
On Wed, Mar 8, 2023 at 11:54 AM Doraemonzzz @.***> wrote:
Thanks for your great work. I want to ask why fft(k) is not fused into the kernel; is it a performance issue? I mean, why is it implemented as follows:
```python
def fftconv_fast(u, k, D, dropout_mask):
    """Fuse padding + rfft + pointwise mult + ifft + multiply with D + gelu + dropout"""
    seqlen = u.shape[-1]
    fft_size = 2 * seqlen
    k_f = torch.fft.rfft(k, n=fft_size)
    out = fftconv_fwd(u, k_f, D, dropout_mask, fft_size)
    return out
```
instead of:
```python
def fftconv_fast(u, k, D, dropout_mask):
    """Fuse padding + rfft + pointwise mult + ifft + multiply with D + gelu + dropout"""
    seqlen = u.shape[-1]
    fft_size = 2 * seqlen
    out = fftconv_fwd(u, k, D, dropout_mask, fft_size)
    return out
```
Makes sense. Thanks.