flatironinstitute / jax-finufft

JAX bindings to the Flatiron Institute Non-uniform Fast Fourier Transform (FINUFFT) library

MLIR translation rule for primitive 'nufft2' not found for platform cuda #64

Closed · AaronParsons closed this issue 6 months ago

AaronParsons commented 6 months ago

Thank you for your work on jax-finufft. I am successfully using it targeting my CPU, but am having issues when using my GPU. Basically, I get a "NotImplementedError: MLIR translation rule for primitive 'nufft2' not found for platform cuda" as soon as JAX tries to do autodifferentiation.
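For concreteness, here is a minimal sketch of the failing pattern (hypothetical sizes and seed; it assumes the same nufft1 call signature used in the tests below). Differentiating through a type-1 transform pulls in the adjoint type-2 transform, which I believe is why the error names the 'nufft2' primitive even though my code only calls nufft1:

```python
# Minimal sketch (hypothetical shapes): forward works on CPU, but on the GPU
# the gradient fails with the NotImplementedError quoted above, since the VJP
# of a type-1 NUFFT is a type-2 NUFFT.
import jax
import jax.numpy as jnp
import numpy as np
from jax_finufft import nufft1

rng = np.random.default_rng(0)
x = jnp.asarray(rng.uniform(-np.pi, np.pi, (2, 50)).astype(np.float32))
c = jnp.asarray((rng.normal(size=50) + 1j * rng.normal(size=50)).astype(np.complex64))

def loss(c):
    # 2D type-1 transform onto a (37, 42) uniform grid
    f = nufft1((37, 42), c, x[0], x[1])
    return jnp.sum(jnp.abs(f) ** 2)  # real-valued loss of complex input

jax.grad(loss)(c)  # raises the NotImplementedError on the cuda backend
```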

As an example:

```
$ pytest ops_test.py::test_nufft1_forward[2-False-50-75--1]
========================== test session starts ===========================
platform linux -- Python 3.10.12, pytest-7.4.2, pluggy-1.3.0
rootdir: /jax-finufft/tests
plugins: anyio-3.6.2
collected 1 item

ops_test.py F [100%]

================================ FAILURES ================================
_____________________ test_nufft1_forward[2-False-50-75--1] ______________

@pytest.mark.parametrize(
    "ndim, x64, num_nonnuniform, num_uniform, iflag",
    product([1, 2, 3], [False, True], [50], [75], [-1, 1]),
)
def test_nufft1_forward(ndim, x64, num_nonnuniform, num_uniform, iflag):
    if ndim == 1 and jax.default_backend() != "cpu":
        pytest.skip("1D transforms not implemented on GPU")

    random = np.random.default_rng(657)

    eps = 1e-10 if x64 else 1e-7
    dtype = np.double if x64 else np.single
    cdtype = np.cdouble if x64 else np.csingle

    num_uniform = tuple(num_uniform // ndim + 5 * np.arange(ndim))
    ks = [np.arange(-np.floor(n / 2), np.floor((n - 1) / 2 + 1)) for n in num_uniform]

    x = random.uniform(-np.pi, np.pi, size=(ndim, num_nonnuniform)).astype(dtype)
    c = random.normal(size=num_nonnuniform) + 1j * random.normal(size=num_nonnuniform)
    c = c.astype(cdtype)
    f_expect = np.zeros(num_uniform, dtype=cdtype)
    for coords in product(*map(range, num_uniform)):
        k_vec = np.array([k[n] for (n, k) in zip(coords, ks)])
        f_expect[coords] = np.sum(c * np.exp(iflag * 1j * np.dot(k_vec, x)))

    with jax.experimental.enable_x64(x64):
      f_calc = nufft1(num_uniform, c, *x, eps=eps, iflag=iflag)

ops_test.py:45:


@partial(jit, static_argnums=(0,), static_argnames=("iflag", "eps"))
def nufft1(output_shape, source, *points, iflag=1, eps=1e-6):
    iflag = int(iflag)
    eps = float(eps)
    ndim = len(points)
    if not 1 <= ndim <= 3:
        raise ValueError("Only 1-, 2-, and 3-dimensions are supported")

    # Support passing a scalar output_shape
    output_shape = np.atleast_1d(output_shape).astype(np.int64)
    if output_shape.shape != (ndim,):
        raise ValueError(f"output_shape must have shape: ({ndim},)")

    # Handle broadcasting and reshaping of inputs
    index, source, *points = shapes.broadcast_and_flatten_inputs(
        output_shape, source, *points
    )

    # Execute the transform primitive
    result = nufft1_p.bind(
        source, *points, output_shape=output_shape, iflag=iflag, eps=eps
    )
E   jax._src.source_info_util.JaxStackTraceBeforeTransformation: NotImplementedError: MLIR translation rule for primitive 'nufft1' not found for platform cuda
E
E   The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.
E
E   --------------------

../../../../.local/lib/python3.10/site-packages/jax_finufft/ops.py:33: JaxStackTraceBeforeTransformation

The above exception was the direct cause of the following exception:

ndim = 2, x64 = False, num_nonnuniform = 50, num_uniform = (37, 42), iflag = -1

@pytest.mark.parametrize(
    "ndim, x64, num_nonnuniform, num_uniform, iflag",
    product([1, 2, 3], [False, True], [50], [75], [-1, 1]),
)
def test_nufft1_forward(ndim, x64, num_nonnuniform, num_uniform, iflag):
    if ndim == 1 and jax.default_backend() != "cpu":
        pytest.skip("1D transforms not implemented on GPU")

    random = np.random.default_rng(657)

    eps = 1e-10 if x64 else 1e-7
    dtype = np.double if x64 else np.single
    cdtype = np.cdouble if x64 else np.csingle

    num_uniform = tuple(num_uniform // ndim + 5 * np.arange(ndim))
    ks = [np.arange(-np.floor(n / 2), np.floor((n - 1) / 2 + 1)) for n in num_uniform]

    x = random.uniform(-np.pi, np.pi, size=(ndim, num_nonnuniform)).astype(dtype)
    c = random.normal(size=num_nonnuniform) + 1j * random.normal(size=num_nonnuniform)
    c = c.astype(cdtype)
    f_expect = np.zeros(num_uniform, dtype=cdtype)
    for coords in product(*map(range, num_uniform)):
        k_vec = np.array([k[n] for (n, k) in zip(coords, ks)])
        f_expect[coords] = np.sum(c * np.exp(iflag * 1j * np.dot(k_vec, x)))

    with jax.experimental.enable_x64(x64):
      f_calc = nufft1(num_uniform, c, *x, eps=eps, iflag=iflag)

ops_test.py:45:


../../../../.local/lib/python3.10/site-packages/jax/_src/pjit.py:238: in cache_miss
    outs, out_flat, out_tree, args_flat = _python_pjit_helper(
../../../../.local/lib/python3.10/site-packages/jax/_src/pjit.py:185: in _python_pjit_helper
    out_flat = pjit_p.bind(*args_flat, **params)
../../../../.local/lib/python3.10/site-packages/jax/_src/core.py:2592: in bind
    return self.bind_with_trace(top_trace, args, params)
../../../../.local/lib/python3.10/site-packages/jax/_src/core.py:363: in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
../../../../.local/lib/python3.10/site-packages/jax/_src/core.py:817: in process_primitive
    return primitive.impl(*tracers, **params)
../../../../.local/lib/python3.10/site-packages/jax/_src/pjit.py:1229: in _pjit_call_impl
    compiled = _pjit_lower(
../../../../.local/lib/python3.10/site-packages/jax/_src/pjit.py:1315: in _pjit_lower
    return _pjit_lower_cached(jaxpr, in_shardings, out_shardings, *args, **kwargs)
../../../../.local/lib/python3.10/site-packages/jax/_src/pjit.py:1374: in _pjit_lower_cached
    return pxla.lower_sharding_computation(
../../../../.local/lib/python3.10/site-packages/jax/_src/profiler.py:314: in wrapper
    return func(*args, **kwargs)
../../../../.local/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py:2551: in lower_sharding_computation
    lowering_result = mlir.lower_jaxpr_to_module(
../../../../.local/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py:742: in lower_jaxpr_to_module
    lower_jaxpr_to_fun(
../../../../.local/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py:1037: in lower_jaxpr_to_fun
    out_vals, tokens_out = jaxpr_subcomp(ctx.replace(name_stack=callee_name_stack),


ctx = ModuleContext(context=<jaxlib.mlir._mlir_libs._site_initialize.<locals>.Context object at 0x7f662850d170>, module=<jax...annel_iterator=count(1), host_callbacks=[], dim_vars=(), cached_primitive_lowerings={}, cached_call_jaxpr_lowerings={})
jaxpr = { lambda ; a:c64[50] b:f32[50] c:f32[50]. let
    d:c64[1,1,50] = reshape[dimensions=None new_sizes=(1, 1, 50)] a
    ...1e-07 iflag=-1 output_shape=[37 42]] d e f
    h:c64[37,42] = reshape[dimensions=None new_sizes=(37, 42)] g
  in (h,) }
tokens = <jax._src.interpreters.mlir.TokenSet object at 0x7f66284e6c50>
consts = [], dim_var_values = []
args = ([<jaxlib.mlir._mlir_libs._mlir.ir.BlockArgument object at 0x7f66285118f0>], [<jaxlib.mlir._mlir_libs._mlir.ir.BlockArgument object at 0x7f6628511930>], [<jaxlib.mlir._mlir_libs._mlir.ir.BlockArgument object at 0x7f66285119b0>])
aval = <function jaxpr_subcomp.<locals>.aval at 0x7f66284db880>
write = <function jaxpr_subcomp.<locals>.write at 0x7f66284db9a0>

def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr,
                  tokens: TokenSet,
                  consts: Sequence[Sequence[ir.Value]],
                  *args: Sequence[ir.Value],
                  dim_var_values: Sequence[ir.Value]
                  ) -> Tuple[Sequence[Sequence[ir.Value]], TokenSet]:
  """Lowers a jaxpr into MLIR, inlined into an existing function.

  Assumes that an MLIR context, location, and insertion point are set.

  dim_var_values: the list of dimension variables values in the current
    IR function, in the order of ctx.dim_vars.
  """
  assert ctx.platform != "gpu"
  def read(v: core.Atom) -> Sequence[ir.Value]:
    if type(v) is core.Literal:
      return ir_constants(v.val, canonicalize_types=True)
    else:
      assert isinstance(v, core.Var)
      return env[v]

  def aval(v: core.Atom) -> core.AbstractValue:
    if type(v) is core.Literal:
      return xla.abstractify(v.val)
    else:
      return v.aval

  def write(v: core.Var, node: Sequence[ir.Value]):
    assert node is not None
    env[v] = tuple(node)

  env: Dict[core.Var, Tuple[ir.Value, ...]] = {}

  assert len(args) == len(jaxpr.invars), (jaxpr, args)
  assert len(consts) == len(jaxpr.constvars), (jaxpr, consts)
  assert all(isinstance(v, ir.Value) for vs in consts for v in vs), consts
  assert len(ctx.dim_vars) == len(dim_var_values), (ctx.dim_vars, dim_var_values)
  map(write, jaxpr.constvars, consts)
  map(write, jaxpr.invars, args)
  for eqn in jaxpr.eqns:
    in_nodes = map(read, eqn.invars)
    assert isinstance(ctx.name_stack, source_info_util.NameStack), type(ctx.name_stack)
    source_info = eqn.source_info.replace(
        name_stack=ctx.name_stack + eqn.source_info.name_stack)
    loc = _source_info_to_location(eqn.primitive, eqn.params, source_info,
                                   ctx.name_stack)
    with source_info_util.user_context(eqn.source_info.traceback), loc:
      if eqn.primitive in _platform_specific_lowerings[ctx.platform]:
        rule = _platform_specific_lowerings[ctx.platform][eqn.primitive]
      elif eqn.primitive in xla._backend_specific_translations[ctx.platform]:
        rule = xla_fallback_lowering(eqn.primitive)
      elif eqn.primitive in _lowerings:
        rule = _lowerings[eqn.primitive]
      elif eqn.primitive in xla._translations:
        rule = xla_fallback_lowering(eqn.primitive)
      else:
        raise NotImplementedError(
            f"MLIR translation rule for primitive '{eqn.primitive.name}' not "
            f"found for platform {ctx.platform}")
E   jax._src.traceback_util.UnfilteredStackTrace: NotImplementedError: MLIR translation rule for primitive 'nufft1' not found for platform cuda
E
E   The stack trace below excludes JAX-internal frames.
E   The preceding is the original exception that occurred, unmodified.
E
E   --------------------

../../../../.local/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py:1155: UnfilteredStackTrace

The above exception was the direct cause of the following exception:

ndim = 2, x64 = False, num_nonnuniform = 50, num_uniform = (37, 42), iflag = -1

@pytest.mark.parametrize(
    "ndim, x64, num_nonnuniform, num_uniform, iflag",
    product([1, 2, 3], [False, True], [50], [75], [-1, 1]),
)
def test_nufft1_forward(ndim, x64, num_nonnuniform, num_uniform, iflag):
    if ndim == 1 and jax.default_backend() != "cpu":
        pytest.skip("1D transforms not implemented on GPU")

    random = np.random.default_rng(657)

    eps = 1e-10 if x64 else 1e-7
    dtype = np.double if x64 else np.single
    cdtype = np.cdouble if x64 else np.csingle

    num_uniform = tuple(num_uniform // ndim + 5 * np.arange(ndim))
    ks = [np.arange(-np.floor(n / 2), np.floor((n - 1) / 2 + 1)) for n in num_uniform]

    x = random.uniform(-np.pi, np.pi, size=(ndim, num_nonnuniform)).astype(dtype)
    c = random.normal(size=num_nonnuniform) + 1j * random.normal(size=num_nonnuniform)
    c = c.astype(cdtype)
    f_expect = np.zeros(num_uniform, dtype=cdtype)
    for coords in product(*map(range, num_uniform)):
        k_vec = np.array([k[n] for (n, k) in zip(coords, ks)])
        f_expect[coords] = np.sum(c * np.exp(iflag * 1j * np.dot(k_vec, x)))

    with jax.experimental.enable_x64(x64):
      f_calc = nufft1(num_uniform, c, *x, eps=eps, iflag=iflag)

E NotImplementedError: MLIR translation rule for primitive 'nufft1' not found for platform cuda

ops_test.py:45: NotImplementedError
======================== short test summary info =========================
FAILED ops_test.py::test_nufft1_forward[2-False-50-75--1] - NotImplementedError: MLIR translation rule for primitive 'nufft1' not...
=========================== 1 failed in 0.86s ============================

```

AaronParsons commented 6 months ago

For context, I am on a Thinkpad P1 Gen 2 running Ubuntu 22.04 with CUDA 12.3.

```
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 545.29.06              Driver Version: 545.29.06    CUDA Version: 12.3     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  Quadro T2000                   Off | 00000000:01:00.0 Off |                  N/A |
| N/A   44C    P8               4W /  30W |   3060MiB /  4096MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+

+---------------------------------------------------------------------------------------+
| Processes:                                                                             |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|    0   N/A  N/A      1499      G   /usr/lib/xorg/Xorg                            4MiB |
|    0   N/A  N/A     21281      C   /usr/bin/python3                           3052MiB |
+---------------------------------------------------------------------------------------+
```

I have installed finufft from source, built following [these] instructions (which were incomplete and did not include the installation step or path updates):

  1. mkdir build
  2. cd build
  3. cmake -D FINUFFT_USE_CUDA=ON -D CMAKE_CUDA_ARCHITECTURES=75 ..
  4. cmake --build .
  5. cmake --install . # this was missing from finufft instructions
  6. added /usr/local/lib to LD_LIBRARY_PATH, also missing from finufft instructions
  7. pip install python/cufinufft
  8. pytest --framework=pycuda python/cufinufft/tests

```
========================== test session starts ===========================
platform linux -- Python 3.10.12, pytest-7.4.2, pluggy-1.3.0
plugins: anyio-3.6.2
collected 666 items

python/cufinufft/tests/test_array_ordering.py .                    [  0%]
python/cufinufft/tests/test_basic.py .............................  [  4%]
..................................................................  [ 14%]
..................................................................  [ 24%]
........................................................            [ 32%]
python/cufinufft/tests/test_error_checks.py ........                [ 33%]
python/cufinufft/tests/test_examples.py .ss.s.s                     [ 34%]
python/cufinufft/tests/test_multi.py s                              [ 35%]
python/cufinufft/tests/test_simple.py ............................  [ 39%]
..................................................................  [ 49%]
..................................................................  [ 59%]
..................................................................  [ 69%]
..................................................................  [ 78%]
..................................................................  [ 88%]
..................................................................  [ 98%]
........                                                            [100%]

============================ warnings summary ============================
tests/test_array_ordering.py::test_type1ordering[pycuda]
  /home//.local/lib/python3.10/site-packages/pycuda/compyte/dtypes.py:120: DeprecationWarning: `np.bool8` is a deprecated alias for `np.bool_`. (Deprecated NumPy 1.24)
    reg.get_or_register_dtype("bool", np.bool8)

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
=============== 661 passed, 5 skipped, 1 warning in 3.90s ================
```

lgarrison commented 6 months ago

An error like NotImplementedError: MLIR translation rule for primitive 'nufft1' not found for platform cuda usually means that either jax-finufft wasn't compiled with CUDA support, or that it was compiled but the extension can't be loaded (e.g. #46).
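A quick way to distinguish those two cases (a sketch; it assumes the CUDA extension is built as a jax_finufft_gpu module inside the jax_finufft package, which matches how the wheel lays out the files):

```python
# Sketch: check whether the compiled CUDA extension can be imported at all.
# (Assumes the extension module is named jax_finufft_gpu and lives inside
# the jax_finufft package.)
try:
    from jax_finufft import jax_finufft_gpu  # noqa: F401
    print("GPU extension built and loadable")
except ImportError as err:
    # Either it wasn't compiled with CUDA support, or a shared library it
    # depends on (e.g. libcufinufft.so) isn't on LD_LIBRARY_PATH.
    print("GPU extension unavailable:", err)
```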

Can you share the exact commands you're running to build jax-finufft and their output (and make the pip install verbose with python -m pip install -v .)? Note that there's a section in the instructions specific to building with CUDA support. I would also double-check that LD_LIBRARY_PATH is set per the instructions.

finufft/cufinufft are bundled with jax-finufft, so you don't need to install them separately. I don't think jax-finufft has any way to see your separately installed copies, though, so it's probably harmless.

AaronParsons commented 6 months ago

Thanks for the reply, and sorry it's taken me this long to circle back.

Indeed, there was an environment-variable passing problem that stemmed from some ambiguities in the install documentation. Once we get this working, I can offer documentation of my process.

Build and installation are resolved, but now I get a straight segfault when running the tests. I'm still trying to track down where that segfault comes from, and any help would be appreciated. Here is where I am now:

```
$ cd jax-finufft
$ nvidia-smi --query-gpu=compute_cap --format=csv,noheader
7.5
$ export CMAKE_ARGS="-DCMAKE_CUDA_ARCHITECTURES=75 -DJAX_FINUFFT_USE_CUDA=ON"
$ echo $LD_LIBRARY_PATH
/usr/local/cuda-12.3/lib64:/usr/local/lib:
$ rm -rf build
$ pip install pybind11[global]
$ export CMAKE_PREFIX_PATH=/home/aparsons/.local/lib/python3.10/site-packages/pybind11/share/cmake/pybind11
$ pip install -v .
```

```
Using pip 24.0 from /home/aparsons/.local/lib/python3.10/site-packages/pip (python 3.10)
Defaulting to user installation because normal site-packages is not writeable
Processing /home/aparsons/projects/hera/jax-finufft
  Running command pip subprocess to install build dependencies
  Collecting pybind11>=2.6
    Using cached pybind11-2.11.1-py3-none-any.whl.metadata (9.5 kB)
  Collecting scikit-build-core>=0.5
    Using cached scikit_build_core-0.8.2-py3-none-any.whl.metadata (19 kB)
  Collecting exceptiongroup (from scikit-build-core>=0.5)
    Using cached exceptiongroup-1.2.0-py3-none-any.whl.metadata (6.6 kB)
  Collecting packaging>=20.9 (from scikit-build-core>=0.5)
    Using cached packaging-24.0-py3-none-any.whl.metadata (3.2 kB)
  Collecting tomli>=1.1 (from scikit-build-core>=0.5)
    Using cached tomli-2.0.1-py3-none-any.whl.metadata (8.9 kB)
  Using cached pybind11-2.11.1-py3-none-any.whl (227 kB)
  Using cached scikit_build_core-0.8.2-py3-none-any.whl (140 kB)
  Using cached packaging-24.0-py3-none-any.whl (53 kB)
  Using cached tomli-2.0.1-py3-none-any.whl (12 kB)
  Using cached exceptiongroup-1.2.0-py3-none-any.whl (16 kB)
  Installing collected packages: tomli, pybind11, packaging, exceptiongroup, scikit-build-core
  Successfully installed exceptiongroup-1.2.0 packaging-24.0 pybind11-2.11.1 scikit-build-core-0.8.2 tomli-2.0.1
  Installing build dependencies ... done
  Running command Getting requirements to build wheel
  Getting requirements to build wheel ... done
  Running command pip subprocess to install backend dependencies
  Collecting pathspec
    Using cached pathspec-0.12.1-py3-none-any.whl.metadata (21 kB)
  Collecting setuptools-scm
    Using cached setuptools_scm-8.0.4-py3-none-any.whl.metadata (6.4 kB)
  Collecting pyproject_metadata
    Using cached pyproject_metadata-0.7.1-py3-none-any.whl.metadata (3.0 kB)
  Collecting ninja>=1.5
    Using cached ninja-1.11.1.1-py2.py3-none-manylinux1_x86_64.manylinux_2_5_x86_64.whl.metadata (5.3 kB)
  Collecting packaging>=20 (from setuptools-scm)
    Using cached packaging-24.0-py3-none-any.whl.metadata (3.2 kB)
  Collecting setuptools (from setuptools-scm)
    Using cached setuptools-69.2.0-py3-none-any.whl.metadata (6.3 kB)
  Collecting typing-extensions (from setuptools-scm)
    Using cached typing_extensions-4.10.0-py3-none-any.whl.metadata (3.0 kB)
  Collecting tomli>=1 (from setuptools-scm)
    Using cached tomli-2.0.1-py3-none-any.whl.metadata (8.9 kB)
  Using cached pathspec-0.12.1-py3-none-any.whl (31 kB)
  Using cached setuptools_scm-8.0.4-py3-none-any.whl (42 kB)
  Using cached pyproject_metadata-0.7.1-py3-none-any.whl (7.4 kB)
  Using cached ninja-1.11.1.1-py2.py3-none-manylinux1_x86_64.manylinux_2_5_x86_64.whl (307 kB)
  Using cached packaging-24.0-py3-none-any.whl (53 kB)
  Using cached tomli-2.0.1-py3-none-any.whl (12 kB)
  Using cached setuptools-69.2.0-py3-none-any.whl (821 kB)
  Using cached typing_extensions-4.10.0-py3-none-any.whl (33 kB)
  Installing collected packages: ninja, typing-extensions, tomli, setuptools, pathspec, packaging, setuptools-scm, pyproject_metadata
  Successfully installed ninja-1.11.1.1 packaging-24.0 pathspec-0.12.1 pyproject_metadata-0.7.1 setuptools-69.2.0 setuptools-scm-8.0.4 tomli-2.0.1 typing-extensions-4.10.0
  Installing backend dependencies ... done
  Running command Preparing metadata (pyproject.toml)
  scikit-build-core 0.8.2 using CMake 3.22.1 (metadata_wheel)
  Preparing metadata (pyproject.toml) ... done
Requirement already satisfied: jax in /home/aparsons/.local/lib/python3.10/site-packages (from jax-finufft==0.0.4.dev53+g049f0b9) (0.4.8)
Requirement already satisfied: jaxlib in /home/aparsons/.local/lib/python3.10/site-packages (from jax-finufft==0.0.4.dev53+g049f0b9) (0.4.7+cuda12.cudnn88)
Requirement already satisfied: ml-dtypes>=0.0.3 in /home/aparsons/.local/lib/python3.10/site-packages (from jax->jax-finufft==0.0.4.dev53+g049f0b9) (0.0.4)
Requirement already satisfied: numpy>=1.21 in /home/aparsons/.local/lib/python3.10/site-packages (from jax->jax-finufft==0.0.4.dev53+g049f0b9) (1.25.0)
Requirement already satisfied: opt-einsum in /home/aparsons/.local/lib/python3.10/site-packages (from jax->jax-finufft==0.0.4.dev53+g049f0b9) (3.3.0)
Requirement already satisfied: scipy>=1.7 in /home/aparsons/.local/lib/python3.10/site-packages (from jax->jax-finufft==0.0.4.dev53+g049f0b9) (1.10.1)
Building wheels for collected packages: jax-finufft
  Running command Building wheel for jax-finufft (pyproject.toml)
  scikit-build-core 0.8.2 using CMake 3.22.1 (wheel)
  Configuring CMake...
  loading initial cache file build/cp310-cp310-manylinux_2_35_x86_64/CMakeInit.txt
  -- The C compiler identification is GNU 11.4.0
  -- The CXX compiler identification is GNU 11.4.0
  -- Detecting C compiler ABI info
  -- Detecting C compiler ABI info - done
  -- Check for working C compiler: /usr/bin/cc - skipped
  -- Detecting C compile features
  -- Detecting C compile features - done
  -- Detecting CXX compiler ABI info
  -- Detecting CXX compiler ABI info - done
  -- Check for working CXX compiler: /usr/bin/c++ - skipped
  -- Detecting CXX compile features
  -- Detecting CXX compile features - done
  -- Using CMake version: 3.22.1
  -- Looking for a CUDA compiler
  -- Looking for a CUDA compiler - /usr/local/cuda-12.3/bin/nvcc
  -- CUDA compiler found; compiling with GPU support
  -- The CUDA compiler identification is NVIDIA 12.3.107
  -- Detecting CUDA compiler ABI info
  -- Detecting CUDA compiler ABI info - done
  -- Check for working CUDA compiler: /usr/local/cuda-12.3/bin/nvcc - skipped
  -- Detecting CUDA compile features
  -- Detecting CUDA compile features - done
  -- Downloading CPM.cmake to /home/aparsons/projects/hera/jax-finufft/build/cp310-cp310-manylinux_2_35_x86_64/cmake/CPM_0.38.0.cmake
  -- CPM: Adding package findfftw@ (master)
  -- Found PkgConfig: /usr/bin/pkg-config (found version "0.29.2")
  -- Found FFTW: /usr/include
  -- Found OpenMP_C: -fopenmp (found version "4.5")
  -- Found OpenMP_CXX: -fopenmp (found version "4.5")
  -- Found OpenMP: TRUE (found version "4.5")
  -- Found CUDAToolkit: /usr/local/cuda-12.3/include (found version "12.3.107")
  -- Looking for pthread.h
  -- Looking for pthread.h - found
  -- Performing Test CMAKE_HAVE_LIBC_PTHREAD
  -- Performing Test CMAKE_HAVE_LIBC_PTHREAD - Success
  -- Found Threads: TRUE
  -- Found PythonInterp: /usr/bin/python3 (found suitable version "3.10.12", minimum required is "3.6")
  -- Found PythonLibs: /usr/lib/x86_64-linux-gnu/libpython3.10.so
  -- Found pybind11: /tmp/pip-build-env-9jy7iupp/overlay/local/lib/python3.10/dist-packages/pybind11/include (found version "2.11.1")
  -- Configuring done
  -- Generating done
  -- Build files have been written to: /home/aparsons/projects/hera/jax-finufft/build/cp310-cp310-manylinux_2_35_x86_64
  Building project with Ninja...
  [1/42] Building CXX object vendor/finufft/CMakeFiles/finufft_f32.dir/fortran/finufftfort.cpp.o
  [2/42] Building CXX object vendor/finufft/CMakeFiles/finufft_f64.dir/fortran/finufftfort.cpp.o
  [3/42] Building CXX object vendor/finufft/CMakeFiles/finufft_f64.dir/src/utils.cpp.o
  [4/42] Building CXX object vendor/finufft/CMakeFiles/finufft_f64.dir/src/simpleinterfaces.cpp.o
  [5/42] Building CXX object vendor/finufft/CMakeFiles/finufft_f32.dir/src/utils.cpp.o
  [6/42] Building CXX object vendor/finufft/CMakeFiles/finufft_f32.dir/src/simpleinterfaces.cpp.o
  [7/42] Building CXX object vendor/finufft/CMakeFiles/finufft.dir/src/utils_precindep.cpp.o
  [8/42] Building CXX object vendor/finufft/CMakeFiles/finufft_static.dir/contrib/legendre_rule_fast.cpp.o
  [9/42] Building CXX object vendor/finufft/CMakeFiles/finufft.dir/contrib/legendre_rule_fast.cpp.o
  [10/42] Building CXX object vendor/finufft/CMakeFiles/finufft_static.dir/src/utils_precindep.cpp.o
  [11/42] Building CXX object vendor/finufft/src/cuda/CMakeFiles/cufinufft_common_objects.dir/utils.cpp.o
  [12/42] Building CXX object vendor/finufft/src/cuda/CMakeFiles/cufinufft_common_objects.dir/__/__/contrib/legendre_rule_fast.cpp.o
  [13/42] Building CXX object vendor/finufft/src/cuda/CMakeFiles/cufinufft_objects.dir/spreadinterp.cpp.o
  [14/42] Building CXX object vendor/finufft/CMakeFiles/finufft_f32.dir/src/finufft.cpp.o
  [15/42] Building CXX object vendor/finufft/CMakeFiles/finufft_f64.dir/src/finufft.cpp.o
  [16/42] Building CUDA object CMakeFiles/jax_finufft_gpu.dir/lib/kernels.cc.cu.o
  [17/42] Building CUDA object vendor/finufft/src/cuda/CMakeFiles/cufinufft_objects.dir/1d/cufinufft1d.cu.o
  [18/42] Building CXX object vendor/finufft/CMakeFiles/finufft_f64.dir/src/spreadinterp.cpp.o
  [19/42] Building CXX object vendor/finufft/CMakeFiles/finufft_f32.dir/src/spreadinterp.cpp.o
  [20/42] Building CUDA object vendor/finufft/src/cuda/CMakeFiles/cufinufft_objects.dir/2d/cufinufft2d.cu.o
  [21/42] Linking CXX static library vendor/finufft/libfinufft_static.a
  [22/42] Linking CXX shared library vendor/finufft/libfinufft.so
  [23/42] Building CUDA object vendor/finufft/src/cuda/CMakeFiles/cufinufft_objects.dir/3d/cufinufft3d.cu.o
  [24/42] Building CXX object CMakeFiles/jax_finufft_gpu.dir/lib/jax_finufft_gpu.cc.o
  [25/42] Building CXX object CMakeFiles/jax_finufft_cpu.dir/lib/jax_finufft_cpu.cc.o
  [26/42] Linking CXX shared module jax_finufft_cpu.cpython-310-x86_64-linux-gnu.so
  [27/42] Building CUDA object vendor/finufft/src/cuda/CMakeFiles/cufinufft_objects.dir/memtransfer_wrapper.cu.o
  [28/42] Building CUDA object vendor/finufft/src/cuda/CMakeFiles/cufinufft_objects.dir/cufinufft.cu.o
  [29/42] Building CUDA object vendor/finufft/src/cuda/CMakeFiles/cufinufft_objects.dir/deconvolve_wrapper.cu.o
  [30/42] Building CUDA object vendor/finufft/src/cuda/CMakeFiles/cufinufft_objects.dir/common.cu.o
  [31/42] Building CUDA object vendor/finufft/src/cuda/CMakeFiles/cufinufft_common_objects.dir/precision_independent.cu.o
  [32/42] Building CUDA object vendor/finufft/src/cuda/CMakeFiles/cufinufft_objects.dir/1d/interp1d_wrapper.cu.o
  [33/42] Building CUDA object vendor/finufft/src/cuda/CMakeFiles/cufinufft_objects.dir/3d/interp3d_wrapper.cu.o
  [34/42] Building CUDA object vendor/finufft/src/cuda/CMakeFiles/cufinufft_objects.dir/2d/interp2d_wrapper.cu.o
  [35/42] Building CUDA object vendor/finufft/src/cuda/CMakeFiles/cufinufft_objects.dir/1d/spread1d_wrapper.cu.o
  [36/42] Building CUDA object vendor/finufft/src/cuda/CMakeFiles/cufinufft_objects.dir/2d/spread2d_wrapper.cu.o
  [37/42] Building CUDA object vendor/finufft/src/cuda/CMakeFiles/cufinufft_objects.dir/3d/spread3d_wrapper.cu.o
  [38/42] Linking CXX static library libcufinufft_static.a
  [39/42] Linking CXX shared library CMakeFiles/cufinufft.dir/cmake_device_link.o
  [40/42] Linking CXX shared library libcufinufft.so
  [41/42] Linking CXX shared module CMakeFiles/jax_finufft_gpu.dir/cmake_device_link.o
  [42/42] Linking CXX shared module jax_finufft_gpu.cpython-310-x86_64-linux-gnu.so
  Installing project into wheel...
  -- Install configuration: "Release"
  -- Installing: /tmp/tmp4qbez2ec/wheel/platlib/jax_finufft/lib/libfinufft.so
  -- Installing: /tmp/tmp4qbez2ec/wheel/platlib/jax_finufft/include/finufft.h
  -- Installing: /tmp/tmp4qbez2ec/wheel/platlib/jax_finufft/include/finufft_eitherprec.h
  -- Installing: /tmp/tmp4qbez2ec/wheel/platlib/jax_finufft/include/finufft_errors.h
  -- Installing: /tmp/tmp4qbez2ec/wheel/platlib/jax_finufft/include/finufft_opts.h
  -- Installing: /tmp/tmp4qbez2ec/wheel/platlib/jax_finufft/include/finufft_spread_opts.h
  -- Installing: /tmp/tmp4qbez2ec/wheel/platlib/jax_finufft/lib/libfinufft_static.a
  -- Installing: /tmp/tmp4qbez2ec/wheel/platlib/jax_finufft/lib/libcufinufft.so
  -- Set runtime path of "/tmp/tmp4qbez2ec/wheel/platlib/jax_finufft/lib/libcufinufft.so" to ""
  -- Installing: /tmp/tmp4qbez2ec/wheel/platlib/jax_finufft/lib/libcufinufft_static.a
  -- Installing: /tmp/tmp4qbez2ec/wheel/platlib/jax_finufft/share/licenses/finufft/LICENSE
  -- Installing: /tmp/tmp4qbez2ec/wheel/platlib/jax_finufft/share/finufft/examples
  -- Installing: /tmp/tmp4qbez2ec/wheel/platlib/jax_finufft/share/finufft/examples/guru1d1c.c
  -- Installing: /tmp/tmp4qbez2ec/wheel/platlib/jax_finufft/share/finufft/examples/simple1d1.cpp
  -- Installing: /tmp/tmp4qbez2ec/wheel/platlib/jax_finufft/share/finufft/examples/threadsafe1d1.cpp
  -- Installing: /tmp/tmp4qbez2ec/wheel/platlib/jax_finufft/share/finufft/examples/guru1d1f.cpp
  -- Installing: /tmp/tmp4qbez2ec/wheel/platlib/jax_finufft/share/finufft/examples/simple1d1c.c
  -- Installing: /tmp/tmp4qbez2ec/wheel/platlib/jax_finufft/share/finufft/examples/simple1d1cf.c
  -- Installing: /tmp/tmp4qbez2ec/wheel/platlib/jax_finufft/share/finufft/examples/gurumany1d1.cpp
  -- Installing: /tmp/tmp4qbez2ec/wheel/platlib/jax_finufft/share/finufft/examples/many1d1.cpp
  -- Installing: /tmp/tmp4qbez2ec/wheel/platlib/jax_finufft/share/finufft/examples/simulplans1d1.cpp
  -- Installing: /tmp/tmp4qbez2ec/wheel/platlib/jax_finufft/share/finufft/examples/guru2d1.cpp
  -- Installing: /tmp/tmp4qbez2ec/wheel/platlib/jax_finufft/share/finufft/examples/guru1d1.cpp
  -- Installing: /tmp/tmp4qbez2ec/wheel/platlib/jax_finufft/share/finufft/examples/simple2d1.cpp
  -- Installing: /tmp/tmp4qbez2ec/wheel/platlib/jax_finufft/share/finufft/examples/threadsafe2d2f.cpp
  -- Installing: /tmp/tmp4qbez2ec/wheel/platlib/jax_finufft/share/finufft/examples/simple1d1f.cpp
  -- Installing: /tmp/tmp4qbez2ec/wheel/platlib/jax_finufft/share/finufft/examples/cuda
  -- Installing: /tmp/tmp4qbez2ec/wheel/platlib/jax_finufft/share/finufft/examples/cuda/example2d1many.cpp
  -- Installing: /tmp/tmp4qbez2ec/wheel/platlib/jax_finufft/share/finufft/examples/cuda/example2d2many.cpp
  -- Installing: /tmp/tmp4qbez2ec/wheel/platlib/jax_finufft/share/finufft/examples/cuda/getting_started.cpp
  -- Installing: /tmp/tmp4qbez2ec/wheel/platlib/jax_finufft/./jax_finufft_cpu.cpython-310-x86_64-linux-gnu.so
  -- Installing: /tmp/tmp4qbez2ec/wheel/platlib/jax_finufft/./jax_finufft_gpu.cpython-310-x86_64-linux-gnu.so
  -- Set runtime path of "/tmp/tmp4qbez2ec/wheel/platlib/jax_finufft/./jax_finufft_gpu.cpython-310-x86_64-linux-gnu.so" to ""
  Making wheel...
  *** Created jax_finufft-0.0.4.dev53+g049f0b9-cp310-cp310-manylinux_2_35_x86_64.whl...
  Building wheel for jax-finufft (pyproject.toml) ... done
  Created wheel for jax-finufft: filename=jax_finufft-0.0.4.dev53+g049f0b9-cp310-cp310-manylinux_2_35_x86_64.whl size=104259500 sha256=05e3045b3c2c13f07f4fffcb9bd03c2907284eae3ba6783a98cf38e55675b813
  Stored in directory: /home/aparsons/.cache/pip/wheels/b3/09/ff/f4516de6307b667b70e4cc44ff10332554e970ad70dbe14a36
Successfully built jax-finufft
Installing collected packages: jax-finufft
  Attempting uninstall: jax-finufft
    Found existing installation: jax-finufft 0.0.4.dev53+g049f0b9
    Uninstalling jax-finufft-0.0.4.dev53+g049f0b9:
      Removing file or directory /home/aparsons/.local/lib/python3.10/site-packages/jax_finufft-0.0.4.dev53+g049f0b9.dist-info/
      Removing file or directory /home/aparsons/.local/lib/python3.10/site-packages/jax_finufft/
      Successfully uninstalled jax-finufft-0.0.4.dev53+g049f0b9
Successfully installed jax-finufft-0.0.4.dev53+g049f0b9
```

```
$ pytest ops_test.py::test_nufft1_forward[2-False-50-75--1]
============================= test session starts ==============================
platform linux -- Python 3.10.12, pytest-7.4.2, pluggy-1.3.0
rootdir: /home/aparsons/projects/hera/jax-finufft/tests
plugins: anyio-3.6.2
collected 1 item

ops_test.py Fatal Python error: Segmentation fault

Current thread 0x00007fa2bd4b9740 (most recent call first):
  File "/home/aparsons/.local/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py", line 1916 in __call__
  File "/home/aparsons/.local/lib/python3.10/site-packages/jax/_src/profiler.py", line 314 in wrapper
  File "/home/aparsons/.local/lib/python3.10/site-packages/jax/_src/pjit.py", line 1252 in _pjit_call_impl
  File "/home/aparsons/.local/lib/python3.10/site-packages/jax/_src/core.py", line 817 in process_primitive
  File "/home/aparsons/.local/lib/python3.10/site-packages/jax/_src/core.py", line 363 in bind_with_trace
  File "/home/aparsons/.local/lib/python3.10/site-packages/jax/_src/core.py", line 2592 in bind
  File "/home/aparsons/.local/lib/python3.10/site-packages/jax/_src/pjit.py", line 185 in _python_pjit_helper
  File "/home/aparsons/.local/lib/python3.10/site-packages/jax/_src/pjit.py", line 238 in cache_miss
  File "/home/aparsons/.local/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 166 in reraise_with_filtered_traceback
  File "/home/aparsons/projects/hera/jax-finufft/tests/ops_test.py", line 45 in test_nufft1_forward
  File "/home/aparsons/.local/lib/python3.10/site-packages/_pytest/python.py", line 194 in pytest_pyfunc_call
  File "/home/aparsons/.local/lib/python3.10/site-packages/pluggy/_callers.py", line 77 in _multicall
  File "/home/aparsons/.local/lib/python3.10/site-packages/pluggy/_manager.py", line 115 in _hookexec
  File "/home/aparsons/.local/lib/python3.10/site-packages/pluggy/_hooks.py", line 493 in __call__
  File "/home/aparsons/.local/lib/python3.10/site-packages/_pytest/python.py", line 1792 in runtest
  File "/home/aparsons/.local/lib/python3.10/site-packages/_pytest/runner.py", line 169 in pytest_runtest_call
  File "/home/aparsons/.local/lib/python3.10/site-packages/pluggy/_callers.py", line 77 in _multicall
  File "/home/aparsons/.local/lib/python3.10/site-packages/pluggy/_manager.py", line 115 in _hookexec
  File "/home/aparsons/.local/lib/python3.10/site-packages/pluggy/_hooks.py", line 493 in __call__
  File "/home/aparsons/.local/lib/python3.10/site-packages/_pytest/runner.py", line 262 in <lambda>
  File "/home/aparsons/.local/lib/python3.10/site-packages/_pytest/runner.py", line 341 in from_call
  File "/home/aparsons/.local/lib/python3.10/site-packages/_pytest/runner.py", line 261 in call_runtest_hook
  File "/home/aparsons/.local/lib/python3.10/site-packages/_pytest/runner.py", line 222 in call_and_report
  File "/home/aparsons/.local/lib/python3.10/site-packages/_pytest/runner.py", line 133 in runtestprotocol
  File "/home/aparsons/.local/lib/python3.10/site-packages/_pytest/runner.py", line 114 in pytest_runtest_protocol
  File "/home/aparsons/.local/lib/python3.10/site-packages/pluggy/_callers.py", line 77 in _multicall
  File "/home/aparsons/.local/lib/python3.10/site-packages/pluggy/_manager.py", line 115 in _hookexec
  File "/home/aparsons/.local/lib/python3.10/site-packages/pluggy/_hooks.py", line 493 in __call__
  File "/home/aparsons/.local/lib/python3.10/site-packages/_pytest/main.py", line 350 in pytest_runtestloop
  File "/home/aparsons/.local/lib/python3.10/site-packages/pluggy/_callers.py", line 77 in _multicall
  File "/home/aparsons/.local/lib/python3.10/site-packages/pluggy/_manager.py", line 115 in _hookexec
  File "/home/aparsons/.local/lib/python3.10/site-packages/pluggy/_hooks.py", line 493 in __call__
  File "/home/aparsons/.local/lib/python3.10/site-packages/_pytest/main.py", line 325 in _main
  File "/home/aparsons/.local/lib/python3.10/site-packages/_pytest/main.py", line 271 in wrap_session
  File "/home/aparsons/.local/lib/python3.10/site-packages/_pytest/main.py", line 318 in pytest_cmdline_main
  File "/home/aparsons/.local/lib/python3.10/site-packages/pluggy/_callers.py", line 77 in _multicall
  File "/home/aparsons/.local/lib/python3.10/site-packages/pluggy/_manager.py", line 115 in _hookexec
  File "/home/aparsons/.local/lib/python3.10/site-packages/pluggy/_hooks.py", line 493 in __call__
  File "/home/aparsons/.local/lib/python3.10/site-packages/_pytest/config/__init__.py", line 169 in main
  File "/home/aparsons/.local/lib/python3.10/site-packages/_pytest/config/__init__.py", line 192 in console_main
  File "/home/aparsons/.local/bin/pytest", line 8 in <module>

Extension modules: jaxlib.cpu_feature_guard, numpy.core._multiarray_umath, numpy.core._multiarray_tests, numpy.linalg._umath_linalg, numpy.fft._pocketfft_internal, numpy.random._common, numpy.random.bit_generator, numpy.random._bounded_integers, numpy.random._mt19937, numpy.random.mtrand, numpy.random._philox, numpy.random._pcg64, numpy.random._sfc64, numpy.random._generator (total: 14)
Segmentation fault (core dumped)
```
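One sanity check that may help narrow this down (a sketch; it exercises only stock jaxlib, not jax-finufft): if plain JAX ops run cleanly on the GPU, the segfault is more likely in the jax-finufft extension or the libraries it links than in jaxlib/CUDA itself.

```python
# Sketch: confirm the JAX CUDA backend itself is healthy before
# involving the jax-finufft extension.
import jax
import jax.numpy as jnp

print(jax.default_backend())  # expect "gpu"
print(jax.devices())          # expect a cuda device
# A built-in CUDA kernel; block so any crash surfaces here.
print(jnp.fft.fft(jnp.arange(8.0)).block_until_ready())
```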

AaronParsons commented 6 months ago

Actually, that segfault may have come from a stale GPU context held by a notebook while I was running the tests at the command line. The segfaults have cleared, so I think we can call this issue resolved.

Here is the final list of instructions:

```
$ cd jax-finufft
$ nvidia-smi --query-gpu=compute_cap --format=csv,noheader
7.5
$ export CMAKE_ARGS="-DCMAKE_CUDA_ARCHITECTURES=75 -DJAX_FINUFFT_USE_CUDA=ON"
$ echo $LD_LIBRARY_PATH
/usr/local/cuda-12.3/lib64:/usr/local/lib:
$ rm -rf build
$ python -m pip install "jax[cuda12_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
$ pip install pycuda
$ pip install pybind11[global]
$ export CMAKE_PREFIX_PATH=/home/aparsons/.local/lib/python3.10/site-packages/pybind11/share/cmake/pybind11
$ pip install .
```

A major thing I missed is that CMake is run on all submodules within the pip install process. I think a source of my original problems was the linked "cmake" instructions, which implied that I needed to run part of the submodule CMake build manually.

AaronParsons commented 6 months ago

Thanks!

lgarrison commented 6 months ago

Thanks for reporting back, glad it's working!

I think a source of my original problems was related to the "cmake" instructions that were linked, which implied that I needed to run some of the cmake build of submodules manually.

This makes sense; I can see how that would be confusing. I'll update the docs. Really, there's not much reason for us to link to that page; we can just echo the instructions about querying nvidia-smi.

These three lines should definitely not be necessary, though:

```
$ pip install pycuda
$ pip install pybind11[global]
$ export CMAKE_PREFIX_PATH=/home/aparsons/.local/lib/python3.10/site-packages/pybind11/share/cmake/pybind11
```

If it wasn't working without these, it might be worth trying a Python venv next time to see if that helps. Probably not worth messing with your working installation right now, but just wanted to mention it in case it comes up again.