Closed by joseph-long 6 months ago
I don't have a GPU to test this on, but I can't reproduce this behavior on my CPU.
But first, one issue here is that you'll want to replace the calls to:
jax.value_and_grad(map_nufft_over_pupil_cube)
with a pre-compiled version:
value_and_grad = jax.jit(jax.value_and_grad(map_nufft_over_pupil_cube))
Otherwise, I expect you'll be dominated by tracing/compilation time. See what happens if you do that!
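A minimal timing pattern along those lines (a sketch only; it assumes the map_nufft_over_pupil_cube and pupil_cube defined in the MRE):
import time
import jax

# Sketch: compile once, outside any timing loop, so tracing/compilation
# is not included in the measurement.
value_and_grad = jax.jit(jax.value_and_grad(map_nufft_over_pupil_cube))

# Warm-up call triggers tracing and compilation.
jax.block_until_ready(value_and_grad(pupil_cube))

start = time.perf_counter()
v, g = value_and_grad(pupil_cube)
jax.block_until_ready((v, g))
print("compiled value+grad:", time.perf_counter() - start)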
For reference, the timing on my machine gives:
398469537 nufft_with_grad_time=0.24668104195734486 nufft_with_grad_time/nufft_time=1.78022185899258
nufft_time=0.13608274998841807 nufft_with_grad_time=0.2466957090073265 nufft_with_grad_time/nufft_time=1.8128360062412219
nufft_time=0.13956649997271597 nufft_with_grad_time=0.24427195801399648 nufft_with_grad_time/nufft_time=1.750219128958235
nufft_time=0.13765895902179182 nufft_with_grad_time=0.24193362501682714 nufft_with_grad_time/nufft_time=1.757485504292738
nufft_time=0.13943391700740904 nufft_with_grad_time=0.24902837502304465 nufft_with_grad_time/nufft_time=1.785995691491706
nufft_time=0.13630358397495002 nufft_with_grad_time=0.2867940000141971 nufft_with_grad_time/nufft_time=2.104082604804466
nufft_time=0.14015749999089167 nufft_with_grad_time=0.2769365839776583 nufft_with_grad_time/nufft_time=1.9758955745904103
nufft_time=0.13771983300102875 nufft_with_grad_time=0.2544636250240728 nufft_with_grad_time/nufft_time=1.8476904849439661
nufft_time=0.15944166702684015 nufft_with_grad_time=0.30837995803449303 nufft_with_grad_time/nufft_time=1.9341240203075702
nufft_time=0.16091104096267372 nufft_with_grad_time=0.2543251250172034 nufft_with_grad_time/nufft_time=1.5805324699639398
So value-and-grad costs about 2x the value alone, which is to be expected.
Perhaps @lgarrison can run some benchmarks too if my suggestion doesn't solve your problem!
Hmm. Well, I restructured the code to move the @jax.jit outside of the value_and_grad call, but I don't see any difference in timings. This is on the GPU in my FI workstation. I can reproduce the ratio you see on the CPU, though.
Here's the modified code:
import time
import jax
import numpy as np
import jax.numpy as jnp
from jax_finufft import nufft2
jax.config.update("jax_enable_x64", True)
n_warmup = 2
n_trials = 10
n_pupils = 50
pupil_npix = 128
fov_pix = 128
# construct pupil
xs = np.linspace(-0.5, 0.5, num=pupil_npix)
xx, yy = np.meshgrid(xs, xs)
rr = np.hypot(xx, yy)
pupil = (rr < 0.5).astype(np.complex64)
pupil_cube = jnp.repeat(pupil[jnp.newaxis, :, :], n_pupils, axis=0).astype(jnp.complex128)
# construct spatial frequency evaluation points
delta_lamd_per_pix = 0.5
idx_pix = delta_lamd_per_pix * (jnp.arange(fov_pix) - fov_pix / 2 + 0.5)
UVs = 2 * jnp.pi * (idx_pix / pupil_npix)
UU, VV = jnp.meshgrid(UVs, UVs)
def map_nufft_over_pupil_cube(pupil):
    psfs = jax.vmap(nufft2, in_axes=[0, None, None])(
        pupil,
        UU.flatten(),
        VV.flatten(),
    ).reshape(n_pupils, fov_pix, fov_pix)
    return jnp.max((psfs**2).real)

@jax.jit
def without_grad(pupil_cube):
    v = map_nufft_over_pupil_cube(pupil_cube)
    return v

@jax.jit
def with_grad(pupil_cube):
    v, g = jax.value_and_grad(map_nufft_over_pupil_cube)(pupil_cube)
    return v, g

for i in range(n_warmup):
    v, g = with_grad(pupil_cube)
    v.block_until_ready()
    v = without_grad(pupil_cube)
    v.block_until_ready()

for i in range(n_trials):
    start = time.perf_counter()
    v = without_grad(pupil_cube)
    v.block_until_ready()
    nufft_time = time.perf_counter() - start
    start = time.perf_counter()
    v, g = with_grad(pupil_cube)
    v.block_until_ready()
    nufft_with_grad_time = time.perf_counter() - start
    print(f"{nufft_time=} {nufft_with_grad_time=} {nufft_with_grad_time/nufft_time=}")
Interesting! Maybe there are memory-transfer or re-ordering issues that are specific to the GPU. I'll try to dig a little.
On my workstation with a crappy GPU I get slower runtimes for everything, but the extra factor for value and grad is 4x.
To help diagnose the issues, I started looking at the jaxpr for these operations:
Some notes about this. As expected, computing the value requires one nufft2:
bc:c128[1,50,16384] = nufft2[
  eps=1e-06
  iflag=-1
  output_shape=None
] z ba bb
and computing the value and grad requires the same nufft2, and a nufft1 with the following signature:
cu:c128[1,50,128,128] = nufft1[
  eps=1e-06
  iflag=-1
  output_shape=[128 128]
] cn cs ct
which seems about right.
That being said, there's a heck of a lot of reshaping, transposing, etc. that might be leading to issues!
Edited with slightly simpler jaxprs.
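For anyone wanting to reproduce jaxprs like these, a sketch (assuming the definitions from the benchmark script above):
# Print the operations JAX traces for the forward pass and for value+grad;
# the second jaxpr should contain both a nufft2 and a nufft1 call.
print(jax.make_jaxpr(map_nufft_over_pupil_cube)(pupil_cube))
print(jax.make_jaxpr(jax.value_and_grad(map_nufft_over_pupil_cube))(pupil_cube))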
I was able to further isolate the issue by adding the following benchmark:
from jax_finufft import nufft2, nufft1
@jax.jit
def func(x):
    return nufft2(x, UU, VV)

result = jax.block_until_ready(func(pupil_cube))
start = time.perf_counter()
for i in range(n_trials):
    jax.block_until_ready(func(pupil_cube))
print(time.perf_counter() - start)

@jax.jit
def func(x):
    return nufft1((pupil_npix, pupil_npix), x, UU, VV)

jax.block_until_ready(func(result))
start = time.perf_counter()
for i in range(n_trials):
    jax.block_until_ready(func(result))
print(time.perf_counter() - start)
The first number printed gives the cost of doing a single nufft2, and the second gives the cost of computing the nufft1 which is required to backprop through the initial nufft2. On some GPUs the nufft1 seems to be a factor of ~20 slower than the nufft2. We're not sure if that should be expected!
I've been looking into this, and what it seems to come down to is that the performance parameters just need to be tuned for the nufft1. Setting gpu_method=1 (NU-points driven instead of shared-memory driven) makes almost the whole discrepancy go away (from 20x to 1.3x). It's possible that other tuning parameters should be changed, too, but I think this means that we should implement #66 sooner rather than later!
Another question is whether cufinufft should be able to choose better tuning parameters automatically. I don't know the answer to this, we could ask in the finufft GH.
Here's the script I was using to play around with this, using the native cufinufft interface and not jax-finufft:
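(The script itself isn't reproduced here. Purely as a rough sketch of the idea, not the original script, and with the cufinufft Plan arguments below being assumptions worth checking against the cufinufft docs, the point is that gpu_method can be passed as a plan option:)
import cupy as cp
import cufinufft

# Sketch with an assumed cufinufft Plan API: time a type-1 transform with
# gpu_method=1 (NU-points driven) instead of the default shared-memory method.
x = cp.asarray(UU).ravel()                     # nonuniform points from the benchmark above
y = cp.asarray(VV).ravel()
c = cp.asarray(result).reshape(n_pupils, -1)   # strengths at those points

plan = cufinufft.Plan(
    1,                            # type 1: nonuniform -> uniform
    (pupil_npix, pupil_npix),
    n_trans=n_pupils,
    eps=1e-6,
    isign=-1,
    dtype="complex128",
    gpu_method=1,                 # 1: NU-points driven, 2: shared-memory driven
)
plan.setpts(x, y)
f = plan.execute(c)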
Interesting - thanks for looking into this, @lgarrison! I'll try to take a first stab at #66 (and think about how that'll play with differentiation, a point that didn't come up over there yet!) later this week.
and think about how that'll play with differentiation
Oh, right; if nufft2 prefers gpu_method=2 (which it seems to), and nufft1 prefers gpu_method=1, then we need a place to set the parameters for each method separately. Tricky!
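Purely as a hypothetical illustration of that wrinkle (this is not an existing jax-finufft argument): whatever interface #66 ends up with would need to let the forward and transpose transforms carry different options, something like:
# Hypothetical shape only -- no such option structure exists in jax-finufft today.
opts_by_transform = {
    "nufft2": {"gpu_method": 2},  # forward (type 2) prefers shared memory
    "nufft1": {"gpu_method": 1},  # transpose (type 1) prefers NU-points driven
}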
I'm using jax-finufft to specify the exact spatial frequencies I want to sample in a 2D Fourier transform. It appears that FINUFFT is faster than a matrix Fourier transform for my problem size, but computing the gradient is much slower than expected.
Output:
MRE: