mmuckley / torchkbnufft

A high-level, easy-to-deploy non-uniform Fast Fourier Transform in PyTorch.
https://torchkbnufft.readthedocs.io/
MIT License

torch.func (pytorch 2.0) compatibility #98

Open fzimmermann89 opened 1 month ago

fzimmermann89 commented 1 month ago

Hi,

to be able to use the torch.func.* transforms (for example, torch.func.grad) with torchkbnufft, the torch.autograd.Function subclasses in torchkbnufft/_autograd/interp.py would need to be updated to the newer calling signature. See https://pytorch.org/docs/stable/notes/extending.func.html.

In particular, forward should no longer take a ctx argument. Instead, it should return the output (plus any intermediate tensors that need to be saved for the backward pass), and a third static method, setup_context, should save them to ctx. It would also be nice to make vmap possible; the linked notes cover this via generate_vmap_rule or a custom vmap static method.
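A minimal sketch of the new-style pattern, assuming a toy dense-matrix operation y = A @ x as a stand-in for an interpolation step. The class ToyInterp, the helper apply_interp and the dense matrix are hypothetical illustrations, not the actual code in interp.py. It shows forward without ctx, setup_context saving what backward needs, and generate_vmap_rule = True letting torch.func.vmap derive a batching rule.

```python
import torch


class ToyInterp(torch.autograd.Function):
    """Hypothetical stand-in for an interpolation op, y = A @ x (not torchkbnufft code)."""

    # Let PyTorch derive a vmap rule; this is valid because every
    # staticmethod below uses only PyTorch operations.
    generate_vmap_rule = True

    @staticmethod
    def forward(x, interp_mat):
        # PyTorch >= 2.0 style: no ctx argument, just return the output.
        return interp_mat @ x

    @staticmethod
    def setup_context(ctx, inputs, output):
        # inputs is the tuple of arguments that were passed to forward.
        _, interp_mat = inputs
        ctx.save_for_backward(interp_mat)

    @staticmethod
    def backward(ctx, grad_output):
        (interp_mat,) = ctx.saved_tensors
        # Adjoint of a real matrix is its transpose (complex data would also
        # need a conjugate). No gradient is returned for the matrix itself.
        return interp_mat.transpose(-2, -1) @ grad_output, None


def apply_interp(x, interp_mat):
    # Hypothetical wrapper so the op can be passed to torch.func transforms.
    return ToyInterp.apply(x, interp_mat)


torch.manual_seed(0)
A = torch.randn(3, 4)
x = torch.randn(4)

# torch.func.grad only accepts autograd.Functions written in this style.
g = torch.func.grad(lambda v: apply_interp(v, A).sum())(x)

# Batch over the first dimension of xs while sharing A across the batch.
xs = torch.randn(5, 4)
ys = torch.func.vmap(apply_interp, in_dims=(0, None))(xs, A)
```

Since d/dx sum(A @ x) is the column sum of A, g should match A.sum(dim=0), and ys should match xs @ A.T.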

As this would break compatibility with PyTorch versions older than 2.0 (which was released in March 2023), it would require either a new torchkbnufft version with a bump in the minimum required torch version, or an import-time switch depending on the PyTorch version.
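For comparison, the import-time switch mentioned above could look roughly like the sketch below, which picks one of two class definitions when the module is imported. It assumes feature detection via hasattr(torch.autograd.Function, "setup_context") rather than parsing torch.__version__; Scale is a hypothetical toy op, not part of torchkbnufft.

```python
import torch

# Feature detection: autograd.Function gained the setup_context
# staticmethod in PyTorch 2.0, so its presence selects the new style.
HAS_SETUP_CONTEXT = hasattr(torch.autograd.Function, "setup_context")

if HAS_SETUP_CONTEXT:

    class Scale(torch.autograd.Function):
        """Hypothetical op y = x * factor, PyTorch >= 2.0 style."""

        @staticmethod
        def forward(x, factor):
            return x * factor

        @staticmethod
        def setup_context(ctx, inputs, output):
            # Non-tensor values can be stored directly on ctx.
            ctx.factor = inputs[1]

        @staticmethod
        def backward(ctx, grad_output):
            return grad_output * ctx.factor, None

else:

    class Scale(torch.autograd.Function):
        """Same hypothetical op in the pre-2.0 style (forward takes ctx)."""

        @staticmethod
        def forward(ctx, x, factor):
            ctx.factor = factor
            return x * factor

        @staticmethod
        def backward(ctx, grad_output):
            return grad_output * ctx.factor, None


x = torch.ones(3, requires_grad=True)
Scale.apply(x, 2.0).sum().backward()
```

The cost of this approach is keeping two copies of every Function in sync, which is the trade-off against simply requiring the newer PyTorch.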

I would be willing to prepare a PR for either option.

mmuckley commented 1 month ago

Hello @fzimmermann89, this would be a very welcome contribution! Feel free to file a PR for this.

I think we can make a new version of torchkbnufft and require the newer PyTorch. I'm not a big fan of import switching.