HazyResearch / butterfly

Butterfly matrix multiplication in PyTorch
Apache License 2.0

Fast inference #5

Closed: mleszczy closed this 5 years ago

mleszczy commented 5 years ago

I added two inference optimizations for multiply_untied in a new function, multiply_untied_eval. Both use AVX vectorization, so they require AVX2 to be available, and both assume 32-bit float inputs.
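To make the kernel concrete, here is a minimal scalar sketch of what an untied butterfly multiply computes for one input vector. It assumes a single stack, increasing strides (1, 2, 4, ...), and a hypothetical twiddle layout in which twiddle[k][p] is the 2x2 matrix for pair p at step k. "Untied" means every pair at every step has its own 2x2 matrix. The real multiply_untied works on batched, stacked tensors and its layout may differ; the name butterfly_untied_ref is illustrative only.

```python
def butterfly_untied_ref(twiddle, x):
    """Scalar untied butterfly multiply of one vector (illustrative sketch).

    Assumed layout: twiddle[k][p] is the 2x2 matrix [[t00, t01], [t10, t11]]
    for pair p at step k; there are log2(n) steps with strides 1, 2, 4, ...
    and n // 2 pairs per step.
    """
    x = list(x)
    n = len(x)
    assert n > 0 and n & (n - 1) == 0, "n must be a power of two"
    log_n = n.bit_length() - 1
    for k in range(log_n):
        stride = 1 << k
        # Blocks of size 2 * stride; element i is paired with i + stride.
        for start in range(0, n, 2 * stride):
            for j in range(stride):
                i, i2 = start + j, start + j + stride
                (t00, t01), (t10, t11) = twiddle[k][start // 2 + j]
                a, b = x[i], x[i2]
                x[i] = t00 * a + t01 * b
                x[i2] = t10 * a + t11 * b
    return x


# With every twiddle equal to [[1, 1], [1, -1]] this is an unnormalized
# Walsh-Hadamard transform.
H = [[1.0, 1.0], [1.0, -1.0]]
tw = [[H] * 4 for _ in range(3)]  # n = 8: three steps, four pairs each
print(butterfly_untied_ref(tw, [1.0] * 8))  # [8.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
```

Both optimizations below speed up this same computation; they differ only in which loop is mapped onto the 8 float lanes of a 256-bit register.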

  1. Vectorization over the batch, for large batch sizes. The previous code was slower than linear layers at large batch sizes; with this change we are about 3x faster than a GEMM call for n=512, batch size=16. We use 256-bit registers, so 8 batch values are computed at once. Leftover batch elements (batch size mod 8) are handled individually. This turned out to be a little faster than calling padding functions in Python, but it also means there is no benefit when the batch size is less than 8. To vectorize this way with stride 1 along the batch, the input data had to be permuted so that the batch dimension is innermost.

  2. Vectorization over the twiddle factors. This takes the current code and vectorizes its inner loops, taking advantage of the fact that once the stride is >= 8, we can load 8 contiguous input values at a time (16 in total, counting their partners one stride away). This gave just under a 2x speedup over the current multiply_untied for n=512, batch size 1.

Right now both optimizations live in the same function, which calls (1) if the batch size is >= 8 and (2) otherwise. This split gave the best timing results; the benchmark script is included.
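A similar hypothetical NumPy sketch of the second optimization (vectorizing over the twiddle factors for a single sample): at steps whose stride is at least 8, the 8 consecutive values x[i:i+8] and their 8 partners x[i+stride:i+stride+8] are each contiguous, so 8 pairs can be updated at once; smaller strides keep a scalar loop. The function name, the lanes parameter, and the (log2 n, n // 2, 2, 2) twiddle layout are assumptions made for illustration, not the PR's actual code.

```python
import numpy as np

LANES = 8  # floats per 256-bit register


def butterfly_pair_lanes(twiddle, x, lanes=LANES):
    """Untied butterfly multiply of one (n,) vector, vectorized over pairs.

    Steps with stride >= lanes update `lanes` pairs per iteration; smaller
    strides fall back to a scalar loop. `lanes` must be a power of two.
    """
    x = np.array(x, dtype=np.float32)  # float32 copy
    n = x.shape[0]
    for k in range(twiddle.shape[0]):
        stride = 1 << k
        for start in range(0, n, 2 * stride):
            if stride >= lanes:
                # stride and lanes are powers of two, so chunks tile [0, stride).
                for j in range(0, stride, lanes):
                    i, i2 = start + j, start + j + stride
                    p = start // 2 + j
                    t = twiddle[k, p:p + lanes]  # (lanes, 2, 2): one matrix per pair
                    a = x[i:i + lanes].copy()    # `lanes` contiguous values
                    b = x[i2:i2 + lanes].copy()  # their partners, one stride away
                    x[i:i + lanes] = t[:, 0, 0] * a + t[:, 0, 1] * b
                    x[i2:i2 + lanes] = t[:, 1, 0] * a + t[:, 1, 1] * b
            else:
                for j in range(stride):  # small strides: scalar loop
                    i, i2 = start + j, start + j + stride
                    t = twiddle[k, start // 2 + j]
                    a, b = x[i], x[i2]
                    x[i] = t[0, 0] * a + t[0, 1] * b
                    x[i2] = t[1, 0] * a + t[1, 1] * b
    return x
```

Passing lanes larger than n / 2 forces every step onto the scalar loop, which gives a convenient reference for checking the vectorized path. In a real AVX kernel one would likely also store the twiddles so that each of the four entries for 8 consecutive pairs is contiguous; here t[:, 0, 0] is a strided slice, which NumPy accepts but a single contiguous vector load could not read directly.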