mmuckley / torchkbnufft

A high-level, easy-to-deploy non-uniform Fast Fourier Transform in PyTorch.
https://torchkbnufft.readthedocs.io/
MIT License

C++ Interpolation Backend with PyTorch Extensions #28

Open mmuckley opened 3 years ago

mmuckley commented 3 years ago

I'm wondering if we might get better performance with the package by rewriting the backend to call into the PyTorch C++ API directly. This would bypass the Python interpreter, which adds considerable overhead to indexing, and indexing is usually our slowest or second-slowest operation during interpolation. The process is documented fairly well here.

The nice thing about using the PyTorch C++ API is that we've essentially already written the code: torch.Tensor.index_add_ becomes at::Tensor::index_add_, and so on, so we could keep the logic we have in Python and transcribe it to C++. As long as we stay within the PyTorch API, we shouldn't have to worry about C++ concerns like manual memory management, since at::Tensor is reference-counted and PyTorch frees the underlying storage for us.
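As a rough illustration of how direct the transcription could be, here is a minimal Python sketch of a scatter step like the one an adjoint interpolation (spreading samples back onto the grid) performs, with the corresponding ATen calls in comments. The helper name `adjoint_interp_scatter` and its argument layout are hypothetical and are not torchkbnufft's actual code.

```python
import torch


def adjoint_interp_scatter(values, flat_indices, weights, grid_size):
    """Accumulate weighted off-grid samples onto a flattened uniform grid.

    Hypothetical helper for illustration only.
    values:       (num_samples,) sample values
    flat_indices: (num_samples, num_neighbors) int64 grid indices
    weights:      (num_samples, num_neighbors) interpolation weights
    """
    # C++ (ATen): auto grid = at::zeros({grid_size}, values.options());
    grid = torch.zeros(grid_size, dtype=values.dtype, device=values.device)
    # Each sample contributes to several neighboring grid points.
    # C++ (ATen): auto contrib = (values.unsqueeze(-1) * weights).reshape({-1});
    contributions = (values.unsqueeze(-1) * weights).reshape(-1)
    # index_add_ sums contributions that land on the same grid index.
    # C++ (ATen): grid.index_add_(0, flat_indices.reshape({-1}), contrib);
    grid.index_add_(0, flat_indices.reshape(-1), contributions)
    return grid
```

For example, two samples with values `[1.0, 2.0]`, neighbor indices `[[0, 1], [1, 2]]` and weights `[[0.5, 0.5], [0.25, 0.75]]` on a 4-point grid give `[0.5, 1.0, 1.5, 0.0]`. The C++ version would do the same tensor operations, just without returning to the interpreter between them.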

One thing to figure out would be distribution. I think the best approach would be to modify setup.py so that it builds for each target system, upload the resulting wheels to PyPI, and keep the current pure-Python code as a fallback. That is to say, we should avoid compiling anything on the install system: users should only download a wheel and be able to use it immediately, which is what people expect from this package. We could accomplish this by modifying the distribution config to build and upload wheels for each target system. Conservatively, we could build for Mac and Linux targets and treat Windows as a stretch goal.
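The pure-Python fallback could be as simple as an import-time check. A minimal sketch, assuming a compiled extension module with the hypothetical name `torchkbnufft_cpp` that only exists when a prebuilt wheel was installed:

```python
import importlib


def load_interp_backend(module_name="torchkbnufft_cpp"):
    """Return (backend_name, module), preferring a compiled extension.

    `torchkbnufft_cpp` is a hypothetical module name for illustration;
    no such module exists in torchkbnufft today.
    """
    try:
        # Succeeds only if a prebuilt wheel shipped the compiled module.
        return "cpp", importlib.import_module(module_name)
    except ImportError:
        # Module not installed (or fails to import): use pure Python.
        return "python", None


BACKEND_NAME, _cpp_module = load_interp_backend()
```

Because the check only imports an already-built module, nothing is compiled on the user's machine; platforms without a wheel (for example, Windows until it is supported) would simply report the `"python"` backend.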

I don't have time to do this in the immediate future, but I could assist with and review diffs from anyone who is interested. Feel free to post here or submit PRs if you are.