HazyResearch / fly

Apache License 2.0

Pure torch implementation of pixelfly (for TPUs, CPUs and custom blocks) #5

Open justheuristic opened 2 years ago

justheuristic commented 2 years ago

I'm not sure if it's appropriate to create issues like this, so feel free to close it without warning. Otherwise, I'd ask that it stay open for some time in case somebody is interested.

Why?: The two existing backends for pixelfly use either Huggingface blocksparse or Triton. However, these are not always available, for example when training on TPUs or when using custom parameters (e.g. Triton offers only a couple of block sizes).

What?: Below you can find a (limited) re-implementation of pixelfly in pure PyTorch. Instead of block-sparse kernels, this implementation exploits the fact that the butterfly layout has an equal number of nonzero blocks in each row, via a two-stage procedure (a rough sketch follows the list):

  1. compute all blocks using regular (dense) matmul with [in_features, (block_size * blocks_per_input)] weights
  2. aggregate blocks according to butterfly layout using F.embedding_bag(..., mode='sum')
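
Here's a minimal, self-checking sketch of these two stages in plain PyTorch. The toy shapes and the block layout below are stand-ins for illustration (not the actual butterfly pattern, and not the gist's exact code):

```python
import torch
import torch.nn.functional as F

# Toy sizes (made up for illustration).
batch, blk = 3, 2                       # blk = block_size
in_blocks, out_blocks, slots = 4, 4, 2  # slots = nonzero blocks per input block
in_features, out_features = in_blocks * blk, out_blocks * blk

x = torch.randn(batch, in_features)
W = torch.randn(in_blocks, slots, blk, blk)   # one [blk, blk] tile per (input block, slot)
# Destination output block of each (input block, slot) chunk -- a stand-in layout,
# chosen so that every output block receives exactly `slots` chunks.
out_dest = torch.tensor([[(j + s) % out_blocks for s in range(slots)]
                         for j in range(in_blocks)])

# Stage 1: dense (batched) matmul -- every input block produces `slots` candidate chunks.
x_blocks = x.view(batch, in_blocks, blk)
chunks = torch.einsum('bji,jsio->bjso', x_blocks, W)     # [batch, in_blocks, slots, blk]

# Stage 2: aggregate chunks into output blocks with embedding_bag(mode='sum').
flat = chunks.reshape(batch, in_blocks * slots, blk)     # chunk index = j * slots + s
table = flat.permute(1, 0, 2).reshape(in_blocks * slots, batch * blk)
# bags[i] lists the chunk indices that land in output block i.
bags = torch.stack([(out_dest.flatten() == i).nonzero().flatten()
                    for i in range(out_blocks)])
out = F.embedding_bag(bags, table, mode='sum')           # [out_blocks, batch * blk]
out = out.view(out_blocks, batch, blk).permute(1, 0, 2).reshape(batch, out_features)

# Sanity check against the equivalent dense block-sparse matmul.
W_dense = torch.zeros(in_features, out_features)
for j in range(in_blocks):
    for s in range(slots):
        i = out_dest[j, s].item()
        W_dense[j * blk:(j + 1) * blk, i * blk:(i + 1) * blk] = W[j, s]
assert torch.allclose(out, x @ W_dense, atol=1e-4)
```

Stage 1 is a single dense/batched matmul, so it runs anywhere PyTorch (or XLA) does; stage 2 is just a gather-and-sum over blocks, which is exactly what F.embedding_bag(..., mode='sum') provides.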

Here's the implementation: https://gist.github.com/justheuristic/9e4fb81381451a4bc8cbfee0a5100eba (it's heavily inspired by the original code and re-uses parts of blocksparse_linear.py).

It's a single file, requires only pytorch and einops, and is compatible with TPUs. The speed-ups are comparable (see example_and_tests), plus it supports custom block sizes, tf32, autocast, etc. You can also easily rewrite this in TensorFlow using tfa.EmbeddingBag.

Feel free to use for whatever :)

JunweiLiang commented 2 years ago

Awesome work! Any recommendations on how to set butterfly_size / block_size given in_features/out_features?

tridao commented 2 years ago

It's up to your parameter budget. I usually set it so that there are 2-4 blocks.

JunweiLiang commented 2 years ago

It seems that for smaller MLPs, pixelfly underperforms PyTorch nn.Linear. Here is my test on a T4 GPU (running backwards 100 times with [16, 512, D] inputs):

exp for layers: [(384, 1536), (1536, 384)] with block_size 48 and butterfly size 64:
  Testing pixelfly... since start: 4.4123 seconds
  Testing baseline... since start: 1.8749 seconds

For larger MLPs pixelfly works well:

exp for layers: [(2048, 8192), (8192, 4096)] with block_size 256 and butterfly size 64:
  Testing pixelfly... since start: 28.1280 seconds
  Testing baseline... since start: 79.7593 seconds

Any thoughts on how to improve performance for smaller MLPs? Channel sizes of 768/384/192 are more common in Transformer MLPs.
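
For reference, a rough sketch of a timing loop for this kind of comparison (my simplification, not the exact test script; only the dense nn.Linear baseline is shown, and a pixelfly layer would be swapped in for the comparison):

```python
import time
import torch
import torch.nn as nn

def bench(layer, x, steps=100):
    """Time `steps` forward+backward passes of `layer` on the GPU."""
    layer, x = layer.cuda(), x.cuda().requires_grad_()
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(steps):
        layer(x).sum().backward()
    torch.cuda.synchronize()
    return time.time() - start

d_in, d_out = 384, 1536
x = torch.randn(16, 512, d_in)   # same [16, 512, D] input shape as above
print(f"baseline nn.Linear: {bench(nn.Linear(d_in, d_out), x):.4f} s")
```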

tridao commented 2 years ago

The core computational primitive is a blocksparse matmul, for which we rely on either the Triton or Huggingface blocksparse libraries. Those generally give some speedup over dense matmul only if the fraction of nonzeros is below roughly 30-50%.
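
As a quick sanity check on that rule of thumb, a small (hypothetical) helper can report the fraction of nonzero blocks in a given block mask:

```python
import torch

def block_density(block_mask: torch.Tensor) -> float:
    """Fraction of nonzero blocks in an [out_blocks, in_blocks] boolean mask."""
    return block_mask.float().mean().item()

# Toy example: 4 output blocks x 8 input blocks with 2 nonzero blocks per row
# gives 25% density, i.e. below the 30-50% threshold mentioned above.
mask = torch.zeros(4, 8, dtype=torch.bool)
for i in range(4):
    mask[i, [i, i + 4]] = True
print(block_density(mask))  # 0.25
```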