flatironinstitute / jax-finufft

JAX bindings to the Flatiron Institute Non-uniform Fast Fourier Transform (FINUFFT) library
Apache License 2.0

Is there a way to tweak finufft options with jax-finufft? #66

Closed · AntPaa closed this issue 6 months ago

AntPaa commented 6 months ago

Hey,

Can I access the finufft options parameters (https://finufft.readthedocs.io/en/latest/opts.html#opts) when using jax-finufft?

I'm using jax-finufft for iterative image reconstruction and would therefore like to squeeze every last bit of computational speed out of it. Based on my previous testing with the finufft package, the upsampfac and FFTW plan options were really useful in specific situations.
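
For reference, this is roughly what I mean: the plain finufft Python interface forwards opts fields as keyword arguments (the concrete values here are just illustrative; 64 is FFTW_ESTIMATE from fftw3.h):

```python
import numpy as np
import finufft

rng = np.random.default_rng(0)
x = 2 * np.pi * rng.random(1000)
c = rng.standard_normal(1000) + 1j * rng.standard_normal(1000)

# upsampfac and fftw are finufft_opts fields, passed straight through as kwargs
f = finufft.nufft1d1(x, c, 128, eps=1e-6, upsampfac=1.25, fftw=64)
```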

lgarrison commented 6 months ago

There isn't currently. It would be easy to at least add support for upsampfac, since that's supported by both CPU and GPU finufft. We might want to think a little more carefully about options that are specific to one backend, like fftw for the CPU or gpu_method for the GPU; not sure if there's a standard design pattern in JAX for exposing options on a per-backend basis?

dfm commented 6 months ago

I was on my way over to say the same thing! It would be really nice to have this feature, but we don't currently support it. Some thoughts about how it might work:

We currently set up the opts for CPU here:

https://github.com/flatironinstitute/jax-finufft/blob/7d67d84afd7c4ecb89942d99742fb5a072de16af/lib/jax_finufft_cpu.cc#L20-L21

and for GPU here:

https://github.com/flatironinstitute/jax-finufft/blob/7d67d84afd7c4ecb89942d99742fb5a072de16af/lib/kernels.cc.cu#L19-L23

For static options, it shouldn't be too hard to add pointers to a finufft_opts and/or cufinufft_opts struct to our NufftDescriptor struct. Then we'd need to update build_descriptor to accept user inputs, which would be passed here:

https://github.com/flatironinstitute/jax-finufft/blob/7d67d84afd7c4ecb89942d99742fb5a072de16af/src/jax_finufft/lowering.py#L119-L121
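
To sketch the idea concretely (all names here are hypothetical, not actual jax-finufft code): the Python side would serialize any user-supplied static opts into the descriptor bytes, so the C++ side can unpack them into a finufft_opts / cufinufft_opts struct before planning:

```python
import struct

def pack_opts_descriptor(upsampfac=0.0, fftw=0, gpu_method=0):
    # Sentinel values mean "keep the finufft default"; the byte layout
    # would have to match whatever the C++ NufftDescriptor expects.
    return struct.pack("<dii", upsampfac, fftw, gpu_method)
```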

But implementing support for something like the FFTW plan seems less trivial to me, because I believe that for the best performance we'd want to cache the plan for reuse between calls. One option would be to expose the plan to Python so that we could generate it at tracing/compile time, but that's a pretty serious refactor! I'm definitely interested in thinking/brainstorming about it!

lgarrison commented 6 months ago

We may not need to implement proper re-use of FFTW plans; a new plan within the same process will automatically re-use any wisdom accumulated from previous planning, and plan creation with wisdom is pretty fast. However, if we wanted to re-use wisdom between processes, that would require exposing a wisdom import/export mechanism.
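
As a point of reference, pyfftw already exposes exactly this kind of mechanism (this is an analogy with a separate library; neither finufft nor jax-finufft currently provides it):

```python
import pickle
import pyfftw

# After planning some transforms in one process, persist the wisdom...
with open("wisdom.pkl", "wb") as f:
    pickle.dump(pyfftw.export_wisdom(), f)

# ...then restore it in a later process so new plans can reuse it.
with open("wisdom.pkl", "rb") as f:
    pyfftw.import_wisdom(pickle.load(f))
```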

I was also wondering what the Python API should look like: if the user calls nufft1(..., fftw=FFTW_MEASURE) but the platform is CUDA, should that be silently ignored, or raise a warning/error?

Or do we add finufft_opts and cufinufft_opts as keyword args so that one can write a single function call that will use the options relevant to the resulting backend? E.g. nufft1(..., finufft_opts=dict(fftw=FFTW_MEASURE), cufinufft_opts=dict(gpu_method=1))
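
A sketch of how that second design might be resolved internally during lowering (the function and argument names are hypothetical, not part of jax-finufft): each call would keep only the dict for the backend it actually targets.

```python
def select_opts(platform, finufft_opts=None, cufinufft_opts=None):
    # Keep only the options relevant to the backend being lowered to.
    if platform == "cuda":
        return cufinufft_opts or {}
    return finufft_opts or {}

# select_opts("cpu", finufft_opts={"fftw": 0}, cufinufft_opts={"gpu_method": 1})
# -> {"fftw": 0}
```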

dfm commented 6 months ago

Great question! I think that it would be useful to explicitly set different options for each platform and just ignore the ones that aren't relevant to the current platform. I guess we could also warn, but I'd probably be fine just ignoring.

lgarrison commented 6 months ago

@AntPaa I think we'd be happy to accept a contribution here; do you have any interest in working on this? It would involve modifying both the Python and C++ parts of the code, but it's probably not a very big change overall. No worries if not!

AntPaa commented 6 months ago

I'm a total beginner with git and can't really handle C++ at all, so I'll have to politely decline the contribution part.

dfm commented 6 months ago

@AntPaa — No worries at all. Thanks for opening this feature request regardless, and hopefully we or another community member can implement it!