flatironinstitute / jax-finufft

JAX bindings to the Flatiron Institute Non-uniform Fast Fourier Transform (FINUFFT) library
Apache License 2.0

Adding advanced interface #68

Closed: dfm closed this 6 months ago

dfm commented 6 months ago

This is still very much a work in progress, but it should help with #66 and #67. I'll request feedback ASAP!

dfm commented 6 months ago

@lgarrison: here's my first pass at all this. This is how we would provide the configuration relevant to #67:

```python
from jax_finufft import nufft2, options

opts = options.NestedOpts(
  type1=options.Opts(gpu_method=1),
  type2=options.Opts(gpu_method=2),
)

nufft2(..., opts=opts)
```

Or, equivalently in this case:

```python
from jax_finufft import nufft2, options

opts = options.NestedOpts(
  forward=options.Opts(gpu_method=2),
  backward=options.Opts(gpu_method=1),
)

nufft2(..., opts=opts)
```

It's not all that ergonomic, but I think it's a decent start!
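The two configurations are equivalent for `nufft2` because the forward pass of a type-2 transform is a type-2 NUFFT and its backward (adjoint) pass is a type-1 NUFFT. For `nufft1` the roles swap, so the same two objects would select different methods. Here is a minimal sketch of that mapping. `Opts`, `NestedOpts` and `resolve` below are hypothetical stand-ins, not the jax_finufft implementation, and the precedence rule (explicit `forward`/`backward` overriding `type1`/`type2`) is an assumption.

```python
from dataclasses import dataclass
from typing import Optional


# Hypothetical stand-ins for options.Opts / options.NestedOpts.
# Only the gpu_method field is modeled; the real classes have more options.
@dataclass(frozen=True)
class Opts:
    gpu_method: int = 0  # assumption: 0 means "let the library choose"


@dataclass(frozen=True)
class NestedOpts:
    forward: Optional[Opts] = None
    backward: Optional[Opts] = None
    type1: Optional[Opts] = None
    type2: Optional[Opts] = None


def resolve(opts: NestedOpts, nufft_type: int) -> dict:
    """Pick the Opts for the forward and backward passes of a type-1 or
    type-2 transform. The backward pass of a type-1 NUFFT is a type-2
    NUFFT, and vice versa."""
    adjoint_type = 2 if nufft_type == 1 else 1
    by_type = {1: opts.type1, 2: opts.type2}
    # Hypothetical precedence: explicit forward/backward wins over type1/type2,
    # falling back to default Opts() when nothing is specified.
    fwd = opts.forward or by_type[nufft_type] or Opts()
    bwd = opts.backward or by_type[adjoint_type] or Opts()
    return {"forward": fwd, "backward": bwd}


by_type = NestedOpts(type1=Opts(gpu_method=1), type2=Opts(gpu_method=2))
by_pass = NestedOpts(forward=Opts(gpu_method=2), backward=Opts(gpu_method=1))

# Same result for a type-2 transform...
print(resolve(by_type, 2) == resolve(by_pass, 2))
# ...but not for a type-1 transform.
print(resolve(by_type, 1) == resolve(by_pass, 1))
```

Specifying by `type1`/`type2` is stable across both transforms, while `forward`/`backward` is tied to whichever transform the options are passed to.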

dfm commented 6 months ago

I've also done something with the imports that breaks the CUDA compilation. I think it has something to do with jax_finufft_gpu.h being included twice. We probably need to move the descriptor definition to a separate header.

lgarrison commented 6 months ago

The immediate CUDA compilation error is just that there's no declaration of the default_opts<T> function visible to jax_finufft_gpu.cc. That declaration lives in lib/jax_finufft_gpu.h, but that file can't be included as a header in multiple compilation units because it contains function definitions as well as declarations. To fix this, I did the usual thing of splitting the declarations out into a header and putting the definitions in a source file. I called them cufinufft_wrapper.h and cufinufft_wrapper.cc, since most of that file is about giving the cufinufft functions C++ wrappers. But if we don't want to fix it this way for any reason, let me know!

Some CUDA tests fail locally with an illegal memory access error. I'm not yet sure whether it's a problem with the opts or with this header refactoring.

lgarrison commented 6 months ago

The problem was indeed with the header refactoring. The y_index and z_index functions have generic templated definitions in the header, as well as template specializations in the source file. But if the specializations aren't declared in the header, then the compiler won't know it needs to look for the specializations and will just use the generic version.

I don't like my solution, it feels fragile to me! It's too easy to write a specialization in the source file that gets silently ignored. Not sure if there's a better pattern we should be using here.

dfm commented 6 months ago

Thanks for tracking this down @lgarrison!! I'll take a look this afternoon.

lgarrison commented 6 months ago

I can confirm that the opts are working for me and that they fix the performance issue from #67.

dfm commented 6 months ago

Thanks @lgarrison! I think that the approach you came up with here is totally fine. I agree that it's not very elegant, but I think we should just roll with it and revisit only if we need to later. It's possible that the whole library could benefit from some refactoring, but let's not let that get in the way of merging this. With that in mind, I'm going to merge this now!

This fixes #67, but let's leave #66 open until we add info to the README.