flatironinstitute / jax-finufft

JAX bindings to the Flatiron Institute Non-uniform Fast Fourier Transform (FINUFFT) library
Apache License 2.0

Unreasonably slow nufft1 and nufft2 evaluations for large grid sizes #42

Closed AntPaa closed 11 months ago

AntPaa commented 11 months ago

After updating to the newest version of (CPU-only) jax-finufft, large-scale NUFFT evaluations are far slower than they were before. This is easily reproducible, at least on my end, by installing the newest jax-finufft and, for comparison, jax-finufft 0.0.2, and running the code below with each:

import numpy as np
import jax
from jax_finufft import nufft1, nufft2
from datetime import datetime
jax.config.update('jax_platform_name', 'cpu')  # Default to CPU

M = 100000  # Number of non-uniform points
# N = (384, 384, 384)  # Fast, as it should be
N = (512, 512, 512)  # Way too slow

x = 2 * np.pi * np.random.uniform(size=M)
y = 2 * np.pi * np.random.uniform(size=M)
z = 2 * np.pi * np.random.uniform(size=M)
c = np.random.standard_normal(size=M) + 1j * np.random.standard_normal(size=M)

print("{} Performing NUFFT1".format(datetime.now().strftime("%H:%M:%S")))
f = nufft1(N, c, x, y, z, eps=1e-6, iflag=1)
print("{} Done".format(datetime.now().strftime("%H:%M:%S")))

print("{} Performing NUFFT2".format(datetime.now().strftime("%H:%M:%S")))
ff = nufft2(f, x, y, z, eps=1e-6, iflag=1)
print("{} Done".format(datetime.now().strftime("%H:%M:%S")))

With the newer jax-finufft version, the 512^3 calculation is unreasonably slow, taking around 5 minutes on my machine, while the 384^3 calculation behaves as expected and takes around 10 seconds. With the older jax-finufft version, both grid sizes take around 10 seconds, as expected.

I have no idea whether the problem is on my end, or whether it has to do with JAX (the JAX versions obviously differ between the new and old jax-finufft installations) or with jax-finufft itself, but I would appreciate it if someone could take a closer look.
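When timing JAX calls like the ones in the script, printed timestamps can be misleading if the result has not finished computing when the second timestamp is taken, since JAX may dispatch work asynchronously. Below is a minimal sketch of a timing helper. `timed` is a hypothetical name, not part of JAX or jax-finufft; it relies only on the `block_until_ready()` method that JAX arrays provide and does nothing extra for other return types.

```python
import time


def timed(fn, *args, **kwargs):
    """Call fn(*args, **kwargs) and return (result, elapsed_seconds).

    Hypothetical helper: if the result exposes block_until_ready()
    (as JAX arrays do), wait on it so the measurement covers the
    actual computation rather than only the dispatch.
    """
    start = time.perf_counter()
    result = fn(*args, **kwargs)
    block = getattr(result, "block_until_ready", None)
    if callable(block):
        block()
    elapsed = time.perf_counter() - start
    return result, elapsed


if __name__ == "__main__":
    # Stand-in workload so the sketch runs without JAX installed.
    total, seconds = timed(sum, range(1_000_000))
    print(f"sum took {seconds:.4f} s")
```

Applied to the script, this would read `f, t1 = timed(nufft1, N, c, x, y, z, eps=1e-6, iflag=1)`. A first call may also include tracing or compilation overhead, so timing a second call separately can help tell the two apart.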

dfm commented 11 months ago

I can't reproduce this issue. Using the GitHub main version of jax-finufft and jax v0.4.19, here is the runtime that I get for your 512^3 sample code on the machines that I have access to:

In both cases, I get approximately the same runtime for both nufft1 and nufft2.

Unfortunately I don't have much to suggest besides trying to reinstall everything with a fresh environment or something like that!
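Before and after rebuilding the environment, it can help to record which versions are actually installed so the fast and slow setups can be compared side by side. The sketch below uses only the standard library's `importlib.metadata`. The function name `installed_versions` and the default package list are assumptions chosen for this thread, not something jax-finufft provides.

```python
from importlib import metadata


def installed_versions(packages=("jax", "jaxlib", "jax-finufft", "finufft", "numpy")):
    """Map each distribution name to its installed version, or 'not installed'."""
    report = {}
    for name in packages:
        try:
            report[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            # The distribution is absent from this environment.
            report[name] = "not installed"
    return report


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name}: {version}")
```

Running it in both the old and the new environment and comparing the output shows which packages changed between them.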

AntPaa commented 11 months ago

Thanks for testing this. If nothing else, it is good to know that the problem is on my end.

I'll close this issue and continue troubleshooting.

dfm commented 11 months ago

Sounds good. Please do report back if you figure out what's going on or if you find any other clues!

lgarrison commented 11 months ago

Just reporting that I also can't reproduce this. @AntPaa, you might want to try using the finufft Python interface (i.e. without JAX) to see if your issue is with finufft or jax-finufft: https://finufft.readthedocs.io/en/latest/python.html
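A rough sketch of what that check could look like, using `finufft.nufft3d1` and `finufft.nufft3d2` from the Python interface documented at the link above. The helper names `make_problem` and `time_finufft` are hypothetical, the inputs mirror the script in the original report, and it is assumed here that `isign=1` corresponds to `iflag=1` in the jax-finufft calls.

```python
import time

import numpy as np

try:
    import finufft  # the plain Python interface, no JAX involved
except ImportError:  # lets the sketch load where finufft is absent
    finufft = None


def make_problem(M, seed=0):
    """Random points in [0, 2*pi) and complex strengths, as in the report."""
    rng = np.random.default_rng(seed)
    x, y, z = (2 * np.pi * rng.uniform(size=M) for _ in range(3))
    c = rng.standard_normal(M) + 1j * rng.standard_normal(M)
    return x, y, z, c


def time_finufft(N, M=100_000, eps=1e-6):
    """Return (type-1 seconds, type-2 seconds) for a 3D transform of size N."""
    if finufft is None:
        raise RuntimeError("finufft is not installed")
    x, y, z, c = make_problem(M)
    t0 = time.perf_counter()
    f = finufft.nufft3d1(x, y, z, c, N, eps=eps, isign=1)
    t1 = time.perf_counter()
    finufft.nufft3d2(x, y, z, f, eps=eps, isign=1)
    t2 = time.perf_counter()
    return t1 - t0, t2 - t1


if __name__ == "__main__":
    if finufft is None:
        print("finufft is not installed; skipping the timing run")
    else:
        # Small grid so this finishes quickly; the report used (512, 512, 512).
        print(time_finufft((64, 64, 64)))
```

If `time_finufft((512, 512, 512))` is also slow compared with `(384, 384, 384)`, that would point at finufft or its build rather than at the JAX bindings.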