justheuristic opened this issue 2 years ago
Awesome work! Any recommendations on how to set butterfly_size / block_size given in_features/out_features?
It's up to your parameter budget. I usually set it so that there are 2-4 blocks.
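In code, that rule of thumb might look something like the sketch below; this assumes "blocks" simply means how many block_size-sized chunks the feature dimension is split into (pixelfly's actual parametrization may differ), and the helper name is made up:

```python
def suggest_block_size(features: int, target_blocks: int = 4) -> int:
    """Pick a block_size so that `features` splits into `target_blocks` blocks."""
    assert features % target_blocks == 0, "pick a target that divides the dimension"
    return features // target_blocks

# e.g. for a 1536-wide layer with the 2-4 block rule of thumb above
print(suggest_block_size(1536, target_blocks=4))  # -> 384
```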
It seems that for smaller MLPs, pixelfly underperforms PyTorch nn.Linear. Here is my test on a T4 GPU (running backwards 100 times with [16, 512, D] inputs):
    exp for layers: [(384, 1536), (1536, 384)] and block_size 48 and butterfly size 64
    Testing pixelfly...  since start: 4.4123 seconds
    Testing baseline...  since start: 1.8749 seconds
For larger MLPs, pixelfly works well:

    exp for layers: [(2048, 8192), (8192, 4096)] and block_size 256 and butterfly size 64
    Testing pixelfly...  since start: 28.1280 seconds
    Testing baseline...  since start: 79.7593 seconds
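For completeness, the timing loop has roughly the following shape (a sketch, not the exact script; the `bench` helper and its arguments are illustrative), with explicit synchronization so kernel launch latency isn't what gets measured:

```python
import time
import torch

def bench(layer: torch.nn.Module, d_in: int, steps: int = 100, device: str = "cuda"):
    """Time `steps` forward+backward passes on [16, 512, d_in] inputs."""
    layer = layer.to(device)
    x = torch.randn(16, 512, d_in, device=device, requires_grad=True)
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(steps):
        layer(x).sum().backward()
    torch.cuda.synchronize()
    return time.perf_counter() - start
```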
Any thoughts on how to improve performance for smaller MLPs? Channel sizes of 768/384/192 are more common in Transformer MLPs.
The core computational primitive is a block-sparse matmul, for which we rely on either the Triton or the Hugging Face block-sparse libraries. Those generally give some speedup over dense matmul only if the fraction of nonzeros is below 30-50%.
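If you want to check where a given configuration lands relative to that threshold, the density of the block layout mask is enough; in the toy sketch below, `layout` is assumed to be the 0/1 mask with one entry per [block_size x block_size] block:

```python
import torch

def block_density(layout: torch.Tensor) -> float:
    """Fraction of nonzero blocks in a 0/1 block-layout mask."""
    return layout.float().mean().item()

# toy example: a block-diagonal layout over 8x8 blocks -> 12.5% density,
# well below the 30-50% threshold mentioned above
layout = torch.eye(8, dtype=torch.long)
print(f"density: {block_density(layout):.1%}")
```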
I'm not sure if it's appropriate to create issues like this, so feel free to close it without warning. Otherwise, I'd ask that it stay open for some time in case somebody is interested.
Why?: The two existing backends for pixelfly use either huggingface blocksparse or triton. However, these are not always available, for example when training on TPUs or when using custom parameters (triton offers only a couple of block sizes).
What?: Below is a (limited) re-implementation of pixelfly in pure PyTorch. Instead of block-sparse kernels, it exploits the fact that the butterfly layout has an equal number of nonzero blocks in each row, which enables a two-stage procedure (sketched in code after the list):
1. multiply the inputs by a dense weight matrix of shape [in_features, (block_size * blocks_per_input)];
2. aggregate the resulting chunks into output blocks with F.embedding_bag(..., mode='sum').
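To make the two stages concrete, here is a minimal pure-PyTorch sketch of the idea, not the gist itself: the function name, the per-input-block weight layout, and the `chunk_to_output` index tensor are all illustrative assumptions. The butterfly layout enters only through `chunk_to_output`, which must assign the same number of chunks to every output block (exactly the property mentioned above):

```python
import torch
import torch.nn.functional as F


def butterfly_linear_sketch(x, block_weight, chunk_to_output, block_size, out_features):
    """
    x:               [batch, in_features]
    block_weight:    [num_input_blocks, block_size, blocks_per_input * block_size],
                     i.e. the [in_features, block_size * blocks_per_input] weight split by input block
    chunk_to_output: LongTensor [num_output_blocks, chunks_per_output]; for each output block,
                     the indices of the intermediate chunks that sum into it
                     (precomputed once from the butterfly layout)
    """
    batch, in_features = x.shape
    num_input_blocks, _, width = block_weight.shape
    blocks_per_input = width // block_size

    # Stage 1: dense block-wise matmul -> one block_size-long chunk per (input block, target block)
    x_blocks = x.reshape(batch, num_input_blocks, block_size)
    chunks = torch.einsum("bni,nio->bno", x_blocks, block_weight)
    chunks = chunks.reshape(batch, num_input_blocks * blocks_per_input, block_size)

    # Stage 2: sum the chunks that belong to each output block. Flatten the batch into the
    # "embedding" dimension so that F.embedding_bag(mode='sum') performs the gather + sum.
    table = chunks.permute(1, 0, 2).reshape(num_input_blocks * blocks_per_input, batch * block_size)
    out = F.embedding_bag(chunk_to_output, table, mode="sum")  # [num_output_blocks, batch * block_size]
    out = out.view(out_features // block_size, batch, block_size).permute(1, 0, 2)
    return out.reshape(batch, out_features)
```

Both stages are ordinary PyTorch ops (a batched matmul and an embedding_bag lookup), which is what makes the approach portable beyond the triton/huggingface kernels.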
Here's the implementation: https://gist.github.com/justheuristic/9e4fb81381451a4bc8cbfee0a5100eba It's heavily inspired by the original code and re-uses parts of blocksparse_linear.py.
It's a single file, requires only PyTorch and einops, and is compatible with TPUs. The speed-ups are comparable (see example_and_tests), plus it supports custom block sizes, TF32, autocast, etc. You can also easily re-write this in TensorFlow using tfa.EmbeddingBag.
Feel free to use for whatever :)