HomebrewNLP / Olmax

HomebrewNLP in JAX flavour for maintainable TPU training
BSD 2-Clause "Simplified" License

Self-Convolution #58

Closed ClashLuke closed 2 years ago

ClashLuke commented 2 years ago

A couple of months ago, we tested a global convolution "self_conv": https://github.com/HomebrewNLP/HomebrewNLP-Jax/blob/97b7e8def9a676b0e5d44a3fa4aaaa826c54fc10/src/model.py#L258-L270. However, it was never usable because it was far too slow. While the idea of every position generating weights for every other position (as in self-attention) is appealing, it doesn't scale well with increasing sequence length and incurs significant computational overhead. Compared to our usual convolution with a kernel size of 5, this convolution can be up to 50,000x slower (at a sequence length of 256ki).

This issue exists to discuss and gather ideas, and to potentially re-benchmark a faster variant of self-convolution.
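For context, a minimal sketch of the kind of computation meant here (this is an illustration of the idea, not the original `self_conv` code; the projection `w` is a hypothetical stand-in for however the per-position kernel would be generated):

```python
import jax.numpy as jnp


def naive_self_conv(x: jnp.ndarray, w: jnp.ndarray) -> jnp.ndarray:
    # x: [sequence, features], w: [features, features] (hypothetical projection)
    seq = x.shape[0]
    # Every position generates a weight for every other position, as in
    # self-attention, giving a [sequence, sequence] mixing matrix.
    kernel = x @ w @ x.T
    # Causal mask so position i only mixes in positions <= i.
    kernel = kernel * jnp.tril(jnp.ones((seq, seq), dtype=x.dtype))
    # Both compute and memory are O(n^2) in sequence length.
    return kernel @ x


x = jnp.ones((1024, 256))
w = jnp.ones((256, 256)) * 1e-3
y = naive_self_conv(x, w)  # [1024, 256]
```

The [sequence, sequence] mixing matrix is what makes this blow up at long sequence lengths, which is where the up-to-50,000x slowdown versus a kernel-size-5 convolution comes from.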

ClashLuke commented 2 years ago

Others have demonstrated that it's possible to use the convolution theorem to reduce the runtime drastically. One notable paper of this sort is CKConv, which introduces a "continuous kernel convolution" that computes a function similar to MLP-Mixer's token mixing without the overhead of a massive square matrix multiplication. Instead, they propose computing ifft(fft(x) * fft(kernel)), where fft and ifft each run in O(n * log(n)) operations. With this simple change of implementation, the runtime complexity of a global depthwise convolution drops from O(n^2) (as in attention) to O(n * log(n)).

Unfortunately, this might still not be enough, as other papers such as FNet have shown that FFT is slower than a square matrix multiplication up to a sequence length of 8192. However, our sequences have up to 2 million tokens, so this might not be an issue.
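A minimal sketch of that FFT path, assuming a plain learned global filter per channel rather than CKConv's MLP-parameterized continuous kernel; zero-padding to 2n keeps the convolution linear (causal) instead of circular:

```python
import jax.numpy as jnp


def fft_global_conv(x: jnp.ndarray, kernel: jnp.ndarray) -> jnp.ndarray:
    # x:      [sequence, features]
    # kernel: [sequence, features], one global filter per feature channel
    n = x.shape[0]
    fft_len = 2 * n  # pad so the circular FFT convolution matches a linear one
    x_f = jnp.fft.rfft(x, n=fft_len, axis=0)
    k_f = jnp.fft.rfft(kernel, n=fft_len, axis=0)
    # Pointwise product in frequency space == convolution in sequence space,
    # so the whole thing runs in O(n log n) per channel instead of O(n^2).
    y = jnp.fft.irfft(x_f * k_f, n=fft_len, axis=0)
    return y[:n]  # keep the causal part of the linear convolution
```

Whether this wins in practice depends on how well the FFT maps to the hardware, which is exactly the FNet caveat above.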

ClashLuke commented 2 years ago

FFT is not sensible on TPU. For now, the underlying problem is solved by #75. We'll have to reopen this issue if things change.