flatironinstitute / jax-finufft

JAX bindings to the Flatiron Institute Non-uniform Fast Fourier Transform (FINUFFT) library
Apache License 2.0

CUDA or type error when using nufft autoregressively #71

Closed OmerRochman closed 6 months ago

OmerRochman commented 6 months ago

Hey!

Weird error I can't figure out, but I have a small repro. I played with it, changing different things, and I can't pinpoint the cause. The only thing that works is uncommenting the xy = jax.random line (i.e. creating a new array that has not been "touched" by f). Not even explicitly going through numpy works, as can be seen; I've tried different combinations like jitting test_fn, changing the dtype, going to numpy, deleting the buffers, and so on. If you remove the .real you get a different type error: it triggers the assertion that f should be single or double precision (even though it is explicitly converted to a complex type). If you hide CUDA from JAX it still fails but gives no error (the Jupyter kernel just dies).

Addendum: the stop_gradient line makes it run 6 times before crashing with the same error; if you remove the zeros array and just use stop_gradient it fails on the 2nd iteration as always.

Important: both errors happen after the first loop iteration, that is, the code runs fine once but fails after xy is updated using out.

I'm happy to provide any more info, cheers!

#import os
#os.environ['CUDA_VISIBLE_DEVICES'] = ''

import jax
import jax.numpy as jnp
from jax_finufft import nufft2, nufft1
import numpy as np

def test_fn(f, x, y):
    size = (64, 64)
    f = jax.vmap(nufft1, in_axes=(None, -1, None, None), out_axes=-1)(size, f.astype(jnp.complex64), x, y)
    f = jax.vmap(nufft2, in_axes=(-1, None, None), out_axes=-1)(f, x, y).real

    return f

f = jax.random.normal(jax.random.key(42), (1000, 2))
xy = jax.random.normal(jax.random.key(43), (1000, 2))

for i in range(10):
    print(i)

    out = test_fn(f, *xy.T)
    xy = np.asarray(out) + np.asarray(xy)
    # xy = jax.random.normal(jax.random.key(43), (1000, 2))
    # xy = jnp.zeros_like(xy) * jax.lax.stop_gradient(xy) # with + it fails on the 2nd iteration, with * it fails on the 6th

    f = out

The error

CUDA error at /home/omer/projects/NeuralMPM/jax-finufft/vendor/finufft/src/cuda/memtransfer_wrapper.cu:365 code=700(cudaErrorIllegalAddress) "d_plan->fw" 
CUDA error at /home/omer/projects/NeuralMPM/jax-finufft/vendor/finufft/src/cuda/memtransfer_wrapper.cu:366 code=700(cudaErrorIllegalAddress) "d_plan->fwkerhalf1" 
CUDA error at /home/omer/projects/NeuralMPM/jax-finufft/vendor/finufft/src/cuda/memtransfer_wrapper.cu:367 code=700(cudaErrorIllegalAddress) "d_plan->fwkerhalf2" 
CUDA error at /home/omer/projects/NeuralMPM/jax-finufft/vendor/finufft/src/cuda/memtransfer_wrapper.cu:370 code=700(cudaErrorIllegalAddress) "d_plan->idxnupts" 
CUDA error at /home/omer/projects/NeuralMPM/jax-finufft/vendor/finufft/src/cuda/memtransfer_wrapper.cu:371 code=700(cudaErrorIllegalAddress) "d_plan->sortidx" 
CUDA error at /home/omer/projects/NeuralMPM/jax-finufft/vendor/finufft/src/cuda/memtransfer_wrapper.cu:373 code=700(cudaErrorIllegalAddress) "d_plan->binsize" 
CUDA error at /home/omer/projects/NeuralMPM/jax-finufft/vendor/finufft/src/cuda/memtransfer_wrapper.cu:374 code=700(cudaErrorIllegalAddress) "d_plan->binstartpts" 

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[2], line 19
     14 xy = jax.random.normal(jax.random.key(43), (1000, 2))
     16 for i in range(10):
---> 19     out = test_fn(f, *xy.T)
     20     xy = np.asarray(out) + np.asarray(xy)
     21     # xy = jax.random.normal(jax.random.key(43), (1000, 2))

Cell In[2], line 9, in test_fn(f, x, y)
      7 size = (64, 64)
      8 f = jax.vmap(nufft1, in_axes=(None, -1, None, None), out_axes=-1)(size, f.astype(jnp.complex64), x, y)
----> 9 f = jax.vmap(nufft2, in_axes=(-1, None, None), out_axes=-1)(f, x, y).real
     11 return f

    [... skipping hidden 17 frame]

File ~/micromamba/envs/mpm/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py:1209, in ExecuteReplicated.__call__(self, *args)
   1207   self._handle_token_bufs(result_token_bufs, sharded_runtime_token)
   1208 else:
-> 1209   results = self.xla_executable.execute_sharded(input_bufs)
   1210 if dispatch.needs_check_special():
   1211   out_arrays = results.disassemble_into_single_device_arrays()

RuntimeError: an illegal memory access was encountered
lgarrison commented 6 months ago

Thanks for making a minimal reproducer! In this case, the issue is that finufft is being passed non-uniform points outside of the valid input domain, which is [-3pi, 3pi]. The CPU finufft library by default gives a warning about this; the GPU library doesn't have this option, unfortunately.

You can apply % (2 * np.pi) to the nonuniform points to fix this. The reason finufft doesn't do this itself is performance: range reduction can be expensive, so it's left to the user to decide how to deal with this. It's certainly a bad user experience, though.
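For concreteness, a minimal tweak to the reproducer above would be to fold the points before each call (just a sketch; wrap_points is not part of jax-finufft):

import jax.numpy as jnp

def wrap_points(xy):
    # Fold the nonuniform points into [0, 2*pi), which lies inside
    # finufft's valid input domain of [-3pi, 3pi].
    return xy % (2 * jnp.pi)

# in the loop from the reproducer:
# out = test_fn(f, *wrap_points(xy).T)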

@dfm Not sure if we should try to implement a domain check in jax-finufft, or encourage upstream (cu)finufft to add this option?

OmerRochman commented 6 months ago

Oh thanks! It was that indeed. In the actual code I had "soft" normalization, which kept the points in the right range only sometimes, making the problem harder to catch; that's probably a fairly common situation. It was pretty maddening to deal with because the CUDA error gives no information, which of course is not your fault. IMHO a temporary warning (until it is fixed upstream) would be nice :P Thanks again!

Edit: I assume it's a typo in your comment and the range is [-2pi, 2pi], right?

dfm commented 6 months ago

Not sure if we should try to implement a domain check in jax-finufft, or encourage upstream (cu)finufft to add this option?

AFAIK, JAX doesn't really have a good way to do runtime checks on data values (the checkify transform exists, but doesn't provide a very user-friendly API), so our best option would be to do the range reduction ourselves, which, as you said, could have significant performance implications. My suggestion would be to just document this requirement more clearly.
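For reference, a checkify-based domain check could look roughly like this (a sketch only; checked_nufft2 is a hypothetical wrapper, not an existing jax-finufft API), which also shows why the ergonomics aren't great for users:

import jax.numpy as jnp
from jax.experimental import checkify
from jax_finufft import nufft2

def checked_nufft2(f, x, y):
    # Functional assertions; they only surface when err.throw() is called.
    checkify.check(jnp.all(jnp.abs(x) < 3 * jnp.pi), "x has points outside [-3pi, 3pi]")
    checkify.check(jnp.all(jnp.abs(y) < 3 * jnp.pi), "y has points outside [-3pi, 3pi]")
    return nufft2(f, x, y)

err, out = checkify.checkify(checked_nufft2)(f_hat, x, y)
err.throw()  # raises with the message above if any point is out of range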

(I believe the correct range is (-3pi, 3pi) rather than (-2pi, 2pi). See the first note on this page, for example.)