flatironinstitute / jax-finufft

JAX bindings to the Flatiron Institute Non-uniform Fast Fourier Transform (FINUFFT) library
Apache License 2.0

Computing gradient of a nufft2 costs 13x more than the nufft2 alone on GPU #67

Closed. joseph-long closed this issue 6 months ago.

joseph-long commented 6 months ago

I'm using jax-finufft to specify the exact spatial frequencies I want to sample in a 2D Fourier transform. FINUFFT appears to be faster than a matrix Fourier transform for my problem size, but computing the gradient is much slower than expected.

Output:

```
$ python tiny_jax_finufft_comparison.py 
nufft_time=0.005215979181230068 nufft_with_grad_time=0.06915004085749388 nufft_with_grad_time/nufft_time=13.257346023606337
nufft_time=0.004781747702509165 nufft_with_grad_time=0.06479286774992943 nufft_with_grad_time/nufft_time=13.550038977575113
nufft_time=0.004827893804758787 nufft_with_grad_time=0.06621506763622165 nufft_with_grad_time/nufft_time=13.715104414880539
nufft_time=0.004811902996152639 nufft_with_grad_time=0.06634514313191175 nufft_with_grad_time/nufft_time=13.787714171494741
nufft_time=0.004738332703709602 nufft_with_grad_time=0.06613345304504037 nufft_with_grad_time/nufft_time=13.957114702660078
nufft_time=0.005355316214263439 nufft_with_grad_time=0.06647736439481378 nufft_with_grad_time/nufft_time=12.413340638552182
nufft_time=0.004664996173232794 nufft_with_grad_time=0.06625903397798538 nufft_with_grad_time/nufft_time=14.203448731249132
nufft_time=0.004705913830548525 nufft_with_grad_time=0.06623794976621866 nufft_with_grad_time/nufft_time=14.075470174620243
nufft_time=0.004704104270786047 nufft_with_grad_time=0.06453194282948971 nufft_with_grad_time/nufft_time=13.718221177675243
nufft_time=0.004726390819996595 nufft_with_grad_time=0.06554532889276743 nufft_with_grad_time/nufft_time=13.867945201538506
```

MRE:

```python
import time
import jax
import numpy as np
import jax.numpy as jnp
from jax_finufft import nufft2

jax.config.update("jax_enable_x64", True)

pupil_npix = 513
n_warmup = 2
n_trials = 10
n_pupils = 50
pupil_npix = 128
fov_pix = 128

# construct pupil
xs = np.linspace(-0.5, 0.5, num=pupil_npix)
xx, yy = np.meshgrid(xs, xs)
rr = np.hypot(xx, yy)
pupil = (rr < 0.5).astype(np.complex64)
pupil_cube = jnp.repeat(pupil[jnp.newaxis, :, :], n_pupils, axis=0).astype(jnp.complex128)

# construct spatial frequency evaluation points
delta_lamd_per_pix = 0.5
idx_pix = delta_lamd_per_pix * (jnp.arange(fov_pix) - fov_pix / 2 + 0.5)
UVs = 2 * jnp.pi * (idx_pix / pupil_npix)
UU, VV = jnp.meshgrid(UVs, UVs)

@jax.jit
def map_nufft_over_pupil_cube(pupil):
    psfs = jax.vmap(nufft2, in_axes=[0, None, None])(
        pupil,
        UU.flatten(),
        VV.flatten()
    ).reshape(n_pupils, fov_pix, fov_pix)
    return jnp.max((psfs**2).real)

for i in range(n_warmup):
    v, g = jax.value_and_grad(map_nufft_over_pupil_cube)(pupil_cube)
    v.block_until_ready()
    v = map_nufft_over_pupil_cube(pupil_cube)
    v.block_until_ready()

for i in range(n_trials):
    start = time.perf_counter()
    v = map_nufft_over_pupil_cube(pupil_cube)
    v.block_until_ready()
    nufft_time = time.perf_counter() - start

    start = time.perf_counter()
    v, g = jax.value_and_grad(map_nufft_over_pupil_cube)(pupil_cube)
    v.block_until_ready()
    nufft_with_grad_time = time.perf_counter() - start

    print(f"{nufft_time=} {nufft_with_grad_time=} {nufft_with_grad_time/nufft_time=}")
```

dfm commented 6 months ago

I don't have a GPU to test this on, but I can't reproduce this behavior on my CPU.

But first, one issue here is that you'll want to replace the calls to:

jax.value_and_grad(map_nufft_over_pupil_cube)

with a pre-compiled version:

value_and_grad = jax.jit(jax.value_and_grad(map_nufft_over_pupil_cube))
Updated code:

```python
import time
import jax
import numpy as np
import jax.numpy as jnp
from jax_finufft import nufft2

jax.config.update("jax_enable_x64", True)

pupil_npix = 513
n_warmup = 2
n_trials = 10
n_pupils = 50
pupil_npix = 128
fov_pix = 128

# construct pupil
xs = np.linspace(-0.5, 0.5, num=pupil_npix)
xx, yy = np.meshgrid(xs, xs)
rr = np.hypot(xx, yy)
pupil = (rr < 0.5).astype(np.complex64)
pupil_cube = jnp.repeat(pupil[jnp.newaxis, :, :], n_pupils, axis=0).astype(jnp.complex128)

# construct spatial frequency evaluation points
delta_lamd_per_pix = 0.5
idx_pix = delta_lamd_per_pix * (jnp.arange(fov_pix) - fov_pix / 2 + 0.5)
UVs = 2 * jnp.pi * (idx_pix / pupil_npix)
UU, VV = jnp.meshgrid(UVs, UVs)

@jax.jit
def map_nufft_over_pupil_cube(pupil):
    psfs = jax.vmap(nufft2, in_axes=[0, None, None])(
        pupil,
        UU.flatten(),
        VV.flatten()
    ).reshape(n_pupils, fov_pix, fov_pix)
    return jnp.max((psfs**2).real)

value_and_grad = jax.jit(jax.value_and_grad(map_nufft_over_pupil_cube))

for i in range(n_warmup):
    v, g = value_and_grad(pupil_cube)
    v.block_until_ready()
    v = map_nufft_over_pupil_cube(pupil_cube)
    v.block_until_ready()

for i in range(n_trials):
    start = time.perf_counter()
    v = map_nufft_over_pupil_cube(pupil_cube)
    v.block_until_ready()
    nufft_time = time.perf_counter() - start

    start = time.perf_counter()
    v, g = value_and_grad(pupil_cube)
    v.block_until_ready()
    nufft_with_grad_time = time.perf_counter() - start

    print(f"{nufft_time=} {nufft_with_grad_time=} {nufft_with_grad_time/nufft_time=}")
```

Otherwise, I expect the timing will be dominated by tracing/compilation. See what happens if you do that!
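
Because JAX compiles on the first call and dispatches work asynchronously, it can help to time the first call separately from the steady state, and to block on the result before stopping the clock. A minimal sketch in plain Python; `time_first_and_steady` is a hypothetical helper (not part of JAX), and its `sync` argument stands in for something like `jax.block_until_ready`.

```python
import statistics
import time


def time_first_and_steady(fn, *args, n_trials=10, sync=lambda x: x):
    """Time the first call of fn separately from the median of later calls.

    For a jitted JAX function the first call includes tracing and
    compilation; `sync` should wait until the result is actually computed.
    """
    start = time.perf_counter()
    sync(fn(*args))
    first = time.perf_counter() - start

    steady = []
    for _ in range(n_trials):
        start = time.perf_counter()
        sync(fn(*args))  # wait for asynchronous work before reading the clock
        steady.append(time.perf_counter() - start)
    return first, statistics.median(steady)
```

With the script above one might call `time_first_and_steady(value_and_grad, pupil_cube, sync=jax.block_until_ready)`. If the first number is much larger than the second, the gap is mostly tracing/compilation rather than the transforms themselves.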

For reference, the timing on my machine gives:

```
398469537 nufft_with_grad_time=0.24668104195734486 nufft_with_grad_time/nufft_time=1.78022185899258
nufft_time=0.13608274998841807 nufft_with_grad_time=0.2466957090073265 nufft_with_grad_time/nufft_time=1.8128360062412219
nufft_time=0.13956649997271597 nufft_with_grad_time=0.24427195801399648 nufft_with_grad_time/nufft_time=1.750219128958235
nufft_time=0.13765895902179182 nufft_with_grad_time=0.24193362501682714 nufft_with_grad_time/nufft_time=1.757485504292738
nufft_time=0.13943391700740904 nufft_with_grad_time=0.24902837502304465 nufft_with_grad_time/nufft_time=1.785995691491706
nufft_time=0.13630358397495002 nufft_with_grad_time=0.2867940000141971 nufft_with_grad_time/nufft_time=2.104082604804466
nufft_time=0.14015749999089167 nufft_with_grad_time=0.2769365839776583 nufft_with_grad_time/nufft_time=1.9758955745904103
nufft_time=0.13771983300102875 nufft_with_grad_time=0.2544636250240728 nufft_with_grad_time/nufft_time=1.8476904849439661
nufft_time=0.15944166702684015 nufft_with_grad_time=0.30837995803449303 nufft_with_grad_time/nufft_time=1.9341240203075702
nufft_time=0.16091104096267372 nufft_with_grad_time=0.2543251250172034 nufft_with_grad_time/nufft_time=1.5805324699639398
```

So computing the value and gradient together costs about 2x as much as the value alone, which is to be expected, since the backward pass needs one additional transform.
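The ~2x figure follows from linearity: backpropagating through a type-2 transform with respect to its source applies the transpose of that linear map, which is a type-1 transform with the same sign (the `nufft1[... iflag=-1 ...]` in the jaxpr posted further down this thread agrees). A minimal sketch that checks this identity with direct O(NM) sums in NumPy, in 1D for brevity; `nudft1` and `nudft2` are hypothetical helpers, not jax-finufft functions, and the check uses the unconjugated pairing that JAX's transpose rules use.

```python
import numpy as np


def nudft2(f, x, iflag=-1):
    """Direct type-2 sum (uniform modes -> nonuniform points):
    c_j = sum_k f_k * exp(iflag * 1j * k * x_j), with centered integer modes k."""
    k = np.arange(-(len(f) // 2), (len(f) + 1) // 2)
    return np.exp(iflag * 1j * np.outer(x, k)) @ f


def nudft1(n_modes, c, x, iflag=-1):
    """Direct type-1 sum (nonuniform points -> uniform modes):
    f_k = sum_j c_j * exp(iflag * 1j * k * x_j)."""
    k = np.arange(-(n_modes // 2), (n_modes + 1) // 2)
    return np.exp(iflag * 1j * np.outer(k, x)) @ c


rng = np.random.default_rng(0)
n_modes, n_points = 16, 40
x = rng.uniform(-np.pi, np.pi, n_points)
f = rng.normal(size=n_modes) + 1j * rng.normal(size=n_modes)
g = rng.normal(size=n_points) + 1j * rng.normal(size=n_points)  # a cotangent

# Bilinear pairing: <type2(f), g> == <f, type1(g)> with the same iflag,
# i.e. the VJP of a type-2 transform is a type-1 transform.
lhs = np.sum(nudft2(f, x) * g)
rhs = np.sum(f * nudft1(n_modes, g, x))
print(np.allclose(lhs, rhs))  # True
```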

Perhaps @lgarrison can run some benchmarks too if my suggestion doesn't solve your problem!

joseph-long commented 6 months ago

Hmm. Well, I restructured the code to move the @jax.jit outside of the value_and_grad call, but I don't see any difference in timings. This is on the GPU in my FI workstation. I can reproduce the ratio you see on the CPU, though.

Here's the modified code:

```python
import time
import jax
import numpy as np
import jax.numpy as jnp
from jax_finufft import nufft2

jax.config.update("jax_enable_x64", True)

pupil_npix = 513
n_warmup = 2
n_trials = 10
n_pupils = 50
pupil_npix = 128
fov_pix = 128

# construct pupil
xs = np.linspace(-0.5, 0.5, num=pupil_npix)
xx, yy = np.meshgrid(xs, xs)
rr = np.hypot(xx, yy)
pupil = (rr < 0.5).astype(np.complex64)
pupil_cube = jnp.repeat(pupil[jnp.newaxis, :, :], n_pupils, axis=0).astype(jnp.complex128)

# construct spatial frequency evaluation points
delta_lamd_per_pix = 0.5
idx_pix = delta_lamd_per_pix * (jnp.arange(fov_pix) - fov_pix / 2 + 0.5)
UVs = 2 * jnp.pi * (idx_pix / pupil_npix)
UU, VV = jnp.meshgrid(UVs, UVs)

def map_nufft_over_pupil_cube(pupil):
    psfs = jax.vmap(nufft2, in_axes=[0, None, None])(
        pupil,
        UU.flatten(),
        VV.flatten()
    ).reshape(n_pupils, fov_pix, fov_pix)
    return jnp.max((psfs**2).real)

@jax.jit
def without_grad(pupil_cube):
    v = map_nufft_over_pupil_cube(pupil_cube)
    return v

@jax.jit
def with_grad(pupil_cube):
    v, g = jax.value_and_grad(map_nufft_over_pupil_cube)(pupil_cube)
    return v, g

for i in range(n_warmup):
    v, g = with_grad(pupil_cube)
    v.block_until_ready()
    v = without_grad(pupil_cube)
    v.block_until_ready()

for i in range(n_trials):
    start = time.perf_counter()
    v = without_grad(pupil_cube)
    v.block_until_ready()
    nufft_time = time.perf_counter() - start

    start = time.perf_counter()
    v, g = with_grad(pupil_cube)
    v.block_until_ready()
    nufft_with_grad_time = time.perf_counter() - start

    print(f"{nufft_time=} {nufft_with_grad_time=} {nufft_with_grad_time/nufft_time=}")
```

dfm commented 6 months ago

Interesting! Maybe there's some memory transfer or re-ordering issues that are specific to GPU. I'll try to dig a little.

dfm commented 6 months ago

On my workstation with a crappy GPU I get slower runtimes for everything, but the extra factor for value and grad is 4x.

To help diagnose the issues, I started looking at the jaxpr for these operations:

without grad:

```
{ lambda ; a:c128[50,128,128]. let
    b:f64[] = pjit[
      name=without_grad
      jaxpr={ lambda c:f64[16384] d:f64[16384]; e:c128[50,128,128]. let
          f:c128[50,16384] = pjit[
            name=nufft2
            jaxpr={ lambda ; g:c128[50,128,128] h:f64[16384] i:f64[16384]. let
                j:f64[1,16384] = broadcast_in_dim[
                  broadcast_dimensions=(1,)
                  shape=(1, 16384)
                ] h
                k:f64[1,16384] = broadcast_in_dim[
                  broadcast_dimensions=(1,)
                  shape=(1, 16384)
                ] i
                l:c128[1,50,128,128] = reshape[
                  dimensions=None
                  new_sizes=(1, 50, 128, 128)
                ] g
                m:c128[1,50,16384] = nufft2[
                  eps=1e-06
                  iflag=-1
                  output_shape=None
                ] l j k
                n:c128[50,16384] = reshape[
                  dimensions=None
                  new_sizes=(50, 16384)
                ] m
              in (n,) }
          ] e c d
          o:c128[50,16384] = integer_pow[y=2] f
          p:f64[50,16384] = real o
          q:f64[] = reduce_max[axes=(0, 1)] p
        in (q,) }
    ] a
  in (b,) }
```

with grad:

```
{ lambda ; a:c128[50,128,128]. let
    b:f64[] c:c128[50,128,128] = pjit[
      name=with_grad
      jaxpr={ lambda d:f64[16384] e:f64[16384]; f:c128[50,128,128]. let
          g:c128[50,16384] h:f64[1,16384] i:f64[1,16384] = pjit[
            name=nufft2
            jaxpr={ lambda ; j:c128[50,128,128] k:f64[16384] l:f64[16384]. let
                m:f64[1,16384] = broadcast_in_dim[
                  broadcast_dimensions=(1,)
                  shape=(1, 16384)
                ] k
                n:f64[1,16384] = broadcast_in_dim[
                  broadcast_dimensions=(1,)
                  shape=(1, 16384)
                ] l
                o:c128[1,50,128,128] = reshape[
                  dimensions=None
                  new_sizes=(1, 50, 128, 128)
                ] j
                p:c128[1,50,16384] = nufft2[
                  eps=1e-06
                  iflag=-1
                  output_shape=None
                ] o m n
                q:c128[50,16384] = reshape[
                  dimensions=None
                  new_sizes=(50, 16384)
                ] p
              in (q, m, n) }
          ] f d e
          r:c128[50,16384] = integer_pow[y=2] g
          s:c128[50,16384] = integer_pow[y=1] g
          t:c128[50,16384] = mul (2+0j) s
          u:f64[50,16384] = real r
          v:f64[] = reduce_max[axes=(0, 1)] u
          w:f64[1,1] = reshape[dimensions=None new_sizes=(1, 1)] v
          x:bool[50,16384] = eq u w
          y:f64[50,16384] = convert_element_type[
            new_dtype=float64
            weak_type=False
          ] x
          z:f64[] = reduce_sum[axes=(0, 1)] y
          ba:f64[] = div 1.0 z
          bb:f64[50,16384] = broadcast_in_dim[
            broadcast_dimensions=()
            shape=(50, 16384)
          ] ba
          bc:f64[50,16384] = mul bb y
          bd:c128[50,16384] = complex bc 0.0
          be:c128[50,16384] = mul bd t
          bf:c128[50,128,128] = pjit[
            name=nufft2
            jaxpr={ lambda ; bg:f64[1,16384] bh:f64[1,16384] bi:c128[50,16384]. let
                bj:c128[1,50,16384] = reshape[
                  dimensions=None
                  new_sizes=(1, 50, 16384)
                ] bi
                bk:c128[1,50,128,128] = pjit[
                  name=nufft1
                  jaxpr={ lambda ; bl:c128[1,50,16384] bm:f64[1,16384] bn:f64[1,16384]. let
                      bo:f64[1,1,16384] = broadcast_in_dim[
                        broadcast_dimensions=(0, 2)
                        shape=(1, 1, 16384)
                      ] bm
                      bp:f64[1,1,16384] = broadcast_in_dim[
                        broadcast_dimensions=(0, 2)
                        shape=(1, 1, 16384)
                      ] bn
                      bq:f64[1,16384] = reshape[
                        dimensions=None
                        new_sizes=(1, 16384)
                      ] bo
                      br:f64[1,16384] = reshape[
                        dimensions=None
                        new_sizes=(1, 16384)
                      ] bp
                      bs:c128[1,50,128,128] = nufft1[
                        eps=1e-06
                        iflag=-1
                        output_shape=[128 128]
                      ] bl bq br
                    in (bs,) }
                ] bj bg bh
                bt:c128[50,128,128] = reshape[
                  dimensions=None
                  new_sizes=(50, 128, 128)
                ] bk
              in (bt,) }
          ] h i be
        in (v, bf) }
    ] a
  in (b, c) }
```

Some notes about this. As expected, computing the value requires one nufft2:

```
m:c128[1,50,16384] = nufft2[
  eps=1e-06
  iflag=-1
  output_shape=None
] l j k
```

and computing the value and grad requires the same nufft2, and a nufft1 with the following signature:

```
bs:c128[1,50,128,128] = nufft1[
  eps=1e-06
  iflag=-1
  output_shape=[128 128]
] bl bq br
```

which seems about right.

That being said, there's a heck of a lot of reshaping, transposing, etc. that might be leading to issues!

Edited with slightly simpler jaxprs.
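
To check counts like these without reading a whole jaxpr, one option is to print it with `jax.make_jaxpr` and count primitive applications in the text. A small sketch; `count_primitives` is a hypothetical helper, and it only matches primitives that print with a parameter list (`name[...]`), such as `nufft1`, `nufft2`, `reshape` and `broadcast_in_dim` here. References like `name=nufft2` inside a `pjit` are deliberately not counted.

```python
import re
from collections import Counter


def count_primitives(jaxpr_text, names=("nufft1", "nufft2", "reshape", "broadcast_in_dim")):
    """Count equations applying each named primitive in a printed jaxpr.

    A primitive with parameters prints as `<name>[`; the lookbehind keeps
    the match from starting in the middle of a longer identifier.
    """
    pattern = re.compile(r"(?<!\w)(" + "|".join(map(re.escape, names)) + r")\[")
    return Counter(pattern.findall(jaxpr_text))
```

For example, `count_primitives(str(jax.make_jaxpr(with_grad)(pupil_cube)))` would tally the transforms and reshapes in the value-and-grad computation from the script above.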

dfm commented 6 months ago

I was able to further isolate the issue by adding the following benchmark:

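```python
from jax_finufft import nufft2, nufft1

# Time n_trials type-2 transforms (the forward pass); points are flattened
# to 1D, as in the vmapped version above.
@jax.jit
def type2(x):
    return nufft2(x, UU.flatten(), VV.flatten())

result = jax.block_until_ready(type2(pupil_cube))  # warm-up and compile
start = time.perf_counter()
for i in range(n_trials):
    jax.block_until_ready(type2(pupil_cube))
print(time.perf_counter() - start)

# Time n_trials type-1 transforms (what backprop through the type 2 needs).
@jax.jit
def type1(x):
    return nufft1((pupil_npix, pupil_npix), x, UU.flatten(), VV.flatten())

jax.block_until_ready(type1(result))  # warm-up and compile
start = time.perf_counter()
for i in range(n_trials):
    jax.block_until_ready(type1(result))
print(time.perf_counter() - start)
```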

The first number printed is the total time for n_trials nufft2 calls, and the second is the total time for n_trials calls to the nufft1 that is required to backprop through the initial nufft2. On some GPUs the nufft1 seems to be about 20x slower than the nufft2. We're not sure whether that should be expected!

lgarrison commented 6 months ago

I've been looking into this, and what it seems to come down to is that the performance parameters just need to be tuned for the nufft1. Setting gpu_method=1 (NU-points driven instead of shared-memory driven) makes almost the whole discrepancy go away (from 20x to 1.3x). It's possible that other tuning parameters should be changed, too, but I think this means that we should implement #66 sooner rather than later!

Another question is whether cufinufft should be able to choose better tuning parameters automatically. I don't know the answer to this; we could ask on the finufft GitHub.
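
Whether or not cufinufft picks these automatically, an application can do a crude empirical search itself: run each candidate option set a few times and keep the fastest. A minimal sketch in plain Python; `pick_fastest` is a hypothetical helper, and `run` is assumed to block until its GPU work has finished (for cufinufft, for example, by calling `cp.cuda.Device().synchronize()` after `plan.execute`).

```python
import time


def pick_fastest(run, candidates, n_trials=3, timer=time.perf_counter):
    """Call run(**opts) for each candidate options dict and return
    (fastest_opts, timings), where each timing is the best of n_trials runs.
    `run` must not return until its work has actually completed."""
    timings = {}
    for opts in candidates:
        run(**opts)  # warm-up: plan creation, first-call overhead, etc.
        best = float("inf")
        for _ in range(n_trials):
            start = timer()
            run(**opts)
            best = min(best, timer() - start)
        timings[tuple(sorted(opts.items()))] = best
    fastest = min(timings, key=timings.get)
    return dict(fastest), timings
```

For the type-1 case in the script below, the candidates might be `[{"gpu_method": 1}, {"gpu_method": 2}]`, with `run` executing a plan built for that method. Building the plans once, outside `run`, keeps plan creation out of the comparison.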

Here's the script I was using to play around with this, using the native cufinufft interface and not jax-finufft:

Script:

```python
import time

import cufinufft
import cupy as cp

pupil_npix = 128
fov_npix = 128
eps = 1e-6  # 1e-5 is 5x faster with method 2, and 3x slower with method 1
plan_kwargs = {}  # dict(gpu_method=1)  # method 1 for type 1 is 15-20x faster

pupil = cp.ones((pupil_npix, pupil_npix), dtype=cp.complex128)

# this affects the filling of the UV plane, also affecting performance
delta_lamd_per_pix = 0.5
idx_pix = delta_lamd_per_pix * (cp.arange(fov_npix) - fov_npix / 2 + 0.5)
UVs = 2 * cp.pi * (idx_pix / pupil_npix)
UU, VV = cp.meshgrid(UVs, UVs)
UU = UU.flatten()
VV = VV.flatten()

plan2 = cufinufft.Plan(2, pupil.shape, n_trans=1, eps=eps, dtype='complex128', **plan_kwargs)
plan2.setpts(UU, VV)
result2 = plan2.execute(pupil)

t = -time.perf_counter()
for i in range(nrep:=1000):
    plan2.execute(pupil)
cp.cuda.Device().synchronize()  # wait for queued GPU work before reading the clock
t += time.perf_counter()
print(f"nufft2d2: {t:.4g} s")

plan1 = cufinufft.Plan(1, pupil.shape, n_trans=1, eps=eps, dtype='complex128', **plan_kwargs)
plan1.setpts(UU, VV)
result1 = plan1.execute(result2)

t = -time.perf_counter()
for i in range(nrep):
    plan1.execute(result2)
cp.cuda.Device().synchronize()  # wait for queued GPU work before reading the clock
t += time.perf_counter()
print(f"nufft2d1: {t:.4g} s")
```

dfm commented 6 months ago

Interesting - thanks for looking into this, @lgarrison! I'll try to take a first stab at #66 (and think about how that'll play with differentiation, a point that didn't come up over there yet!) later this week.

lgarrison commented 6 months ago

> and think about how that'll play with differentiation

Oh, right; if nufft2 prefers gpu_method=2 (which it seems to) and nufft1 prefers gpu_method=1, then we need a place to set the parameters for each transform type separately. Tricky!
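
One possible shape for that: key the options by transform type, so the forward nufft2 and the nufft1 in its backward pass each look up their own settings. A sketch only, assuming a hypothetical `PerTypeOpts` container that is not part of jax-finufft or cufinufft; options are stored as tuples of pairs so instances are hashable, which JAX generally requires of static parameters.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class PerTypeOpts:
    """Hypothetical per-transform-type options (not a jax-finufft API).

    Stored as tuples of (name, value) pairs rather than dicts so that
    instances are hashable."""
    type1: tuple = ()
    type2: tuple = ()

    def for_type(self, nufft_type):
        """Return the options for a type-1 or type-2 transform as a dict."""
        if nufft_type == 1:
            return dict(self.type1)
        if nufft_type == 2:
            return dict(self.type2)
        raise ValueError(f"unsupported NUFFT type: {nufft_type}")


# The source-gradient of a type-2 transform is a type-1 transform and vice
# versa, so a backward pass would look up the options of the other type.
TRANSPOSE_TYPE = {1: 2, 2: 1}

opts = PerTypeOpts(type1=(("gpu_method", 1),), type2=(("gpu_method", 2),))
print(opts.for_type(2))                  # forward nufft2: {'gpu_method': 2}
print(opts.for_type(TRANSPOSE_TYPE[2]))  # its backward nufft1: {'gpu_method': 1}
```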