HazyResearch / butterfly

Butterfly matrix multiplication in PyTorch
Apache License 2.0

Code to accompany the papers Learning Fast Algorithms for Linear Transforms Using Butterfly Factorizations and Kaleidoscope: An Efficient, Learnable Representation For All Structured Linear Maps.

Requirements

python>=3.6
pytorch>=1.8
numpy
scipy

Installing the fast CUDA implementation of butterfly multiply:

From the repository root (the directory containing setup.py):

python setup.py install

An example of creating a conda environment and then installing the CUDA butterfly multiply (h/t Nir Ailon):

conda create --name butterfly python=3.8 scipy pytorch=1.8.1 cudatoolkit=11.0 -c pytorch
conda activate butterfly
python setup.py install

Usage

2020-08-03: The new interface to the butterfly C++/CUDA code is in csrc and torch_butterfly. It is tested in tests/test_butterfly.py, which also shows example usage.

The file torch_butterfly/special.py shows how to construct butterfly matrices that perform the FFT, inverse FFT, circulant matrix multiplication, the Hadamard transform, and torch.nn.Conv1d with circular padding. The tests in tests/test_special.py verify that these butterfly matrices perform those operations exactly.
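As background for why these transforms have exact butterfly representations: the classic radix-2 FFT is itself a bit-reversal permutation followed by log2(n) butterfly stages, each mixing pairs of entries with 2x2 blocks. The sketch below is illustrative pure PyTorch, not the library's implementation:

```python
import math
import torch

def fft_via_butterfly(x):
    """FFT of a length-2^m vector, computed as a bit-reversal
    permutation followed by log2(n) butterfly stages.
    Illustrative sketch only -- not torch_butterfly's code."""
    n = x.shape[-1]
    bits = n.bit_length() - 1
    # bit-reversal permutation (decimation-in-time ordering)
    rev = torch.tensor([int(format(i, f"0{bits}b")[::-1], 2) for i in range(n)])
    y = x[rev].to(torch.complex64)
    size = 2
    while size <= n:
        half = size // 2
        k = torch.arange(half, dtype=torch.float32)
        tw = torch.exp(-2j * math.pi * k / size)  # twiddle factors
        y = y.reshape(n // size, size)
        even, odd = y[:, :half], y[:, half:] * tw
        # each butterfly stage: (even, odd) -> (even + odd, even - odd)
        y = torch.cat([even + odd, even - odd], dim=1).reshape(n)
        size *= 2
    return y

x = torch.randn(8)
assert torch.allclose(fft_via_butterfly(x), torch.fft.fft(x), atol=1e-4)
```

Each stage is a block-diagonal matrix of 2x2 blocks, which is exactly the factor structure the papers parameterize and learn.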

Old interface

Note: this interface is being rewritten. Use it only if you need a feature that is not yet supported in the new interface.

The butterfly multiplication is written in C++ and CUDA as a PyTorch extension. To install it:

cd butterfly/factor_multiply
python setup.py install
cd ../factor_multiply_fast
python setup.py install

Without the C++/CUDA version, butterfly multiplication is still usable but quite slow. The variable use_extension in butterfly/butterfly_multiply.py controls whether the C++/CUDA version or the pure PyTorch version is used.
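As a rough illustration of what the pure-PyTorch path computes, here is a minimal butterfly multiply: log2(n) factors, each a block-diagonal collection of 2x2 blocks acting on coordinate pairs at doubling strides. This is a hedged sketch; the repo's actual tensor layout and kernels differ:

```python
import torch

def butterfly_multiply(twiddles, x):
    """Multiply x (batch, n) by a butterfly matrix given as log2(n)
    factors. twiddles[j] has shape (n//2, 2, 2): one 2x2 block per
    coordinate pair at stride 2**j. Sketch only -- the repo's
    layout and kernels differ."""
    batch, n = x.shape
    y = x
    for j in range(n.bit_length() - 1):
        stride = 1 << j
        blocks = n // (2 * stride)
        y = y.reshape(batch, blocks, 2, stride)
        a, b = y[:, :, 0, :], y[:, :, 1, :]      # entries i and i + stride
        tw = twiddles[j].reshape(blocks, stride, 2, 2)
        new_a = tw[..., 0, 0] * a + tw[..., 0, 1] * b
        new_b = tw[..., 1, 0] * a + tw[..., 1, 1] * b
        y = torch.stack([new_a, new_b], dim=2).reshape(batch, n)
    return y

# Identity blocks leave the input unchanged:
x = torch.randn(3, 8)
eye = [torch.eye(2).expand(4, 2, 2) for _ in range(3)]
assert torch.allclose(butterfly_multiply(eye, x), x)
```

Setting every 2x2 block to [[1, 1], [1, -1]] makes the stages compute exactly the (unnormalized) Walsh-Hadamard transform, one of the special cases mentioned earlier.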

For training, we've had better results with the Adam optimizer than with SGD.
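To make the Adam recommendation concrete, here is a hypothetical minimal loop (not the repo's training code; all names are illustrative) that fits one butterfly-style factor of learnable 2x2 blocks to a block-diagonal target map with torch.optim.Adam:

```python
import torch

torch.manual_seed(0)
n = 8
# one butterfly factor: n//2 learnable 2x2 blocks on adjacent coordinate pairs
twiddle = torch.nn.Parameter(0.1 * torch.randn(n // 2, 2, 2))

def apply_factor(x):
    y = x.reshape(-1, n // 2, 2)
    # block k multiplies the pair (x[2k], x[2k+1]) by twiddle[k]
    return torch.einsum("kij,bkj->bki", twiddle, y).reshape(-1, n)

# a target the factor can represent exactly: a block-diagonal 2x2 map
target = torch.block_diag(*[torch.randn(2, 2) for _ in range(n // 2)])
opt = torch.optim.Adam([twiddle], lr=5e-2)
for _ in range(400):
    x = torch.randn(64, n)
    loss = torch.nn.functional.mse_loss(apply_factor(x), x @ target.T)
    opt.zero_grad()
    loss.backward()
    opt.step()
```

Swapping in torch.optim.SGD with a comparable learning rate is the one-line change to reproduce the comparison on your own task.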