jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

FFI example not working as given in documentation #24131

Closed · bridgesign closed this issue 1 month ago

bridgesign commented 1 month ago

Description

I am trying to create a custom C++ extension for JAX, following the example given in the documentation. When I run the example, I get the following error:

jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/test/test.py", line 54, in <module>
    np.testing.assert_allclose(rms_norm(x), rms_norm_ref(x), rtol=1e-5)
  File "/home/test/test.py", line 35, in rms_norm
    return jex.ffi.ffi_call(
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/extend/ffi.py", line 240, in ffi_call
    results = ffi_call_p.bind(
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/core.py", line 439, in bind
    return self.bind_with_trace(find_top_trace(args), args, params)
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/core.py", line 443, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/core.py", line 949, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/dispatch.py", line 88, in apply_primitive
    outs = fun(*args)
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Wrong number of attributes: expected 1 but got 2

System info (python version, jaxlib version, accelerator, etc.)

An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
jax:    0.4.33
jaxlib: 0.4.33
numpy:  2.1.1
python: 3.10.12 (main, Sep 11 2024, 15:47:36) [GCC 11.4.0]
jax.devices (1 total, 1 local): [CpuDevice(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='027b41159048', release='6.8.0-40-generic', version='#40~22.04.3-Ubuntu SMP PREEMPT_DYNAMIC Tue Jul 30 17:30:19 UTC 2', machine='x86_64')

$ nvidia-smi
Fri Oct  4 18:19:24 2024
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.183.01             Driver Version: 535.183.01   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA TITAN X (Pascal)        Off | 00000000:01:00.0  On |                  N/A |
| 33%   57C    P0              62W / 250W |   3731MiB / 12288MiB |      7%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+

+---------------------------------------------------------------------------------------+
| Processes:                                                                             |
|  GPU   GI   CI        PID   Type   Process name                             GPU Memory |
|        ID   ID                                                              Usage      |
|=======================================================================================|
+---------------------------------------------------------------------------------------+

dfm commented 1 month ago

Thanks for this report! That example is being run as part of CI so I expect there's something different in your setup. Can you share the exact C++ and Python code that you're running?

Edited to add: You might also consider checking out this version of that example, which includes all the packaging details, etc.: https://github.com/jax-ml/jax/tree/main/examples/ffi

bridgesign commented 1 month ago

Thanks, the example helps! I copied the parts from the documentation in order; here is the exact code I'm running.

// rms_norm.cc
#include <cmath>
#include <cstdint>
#include <functional>
#include <numeric>
#include <utility>

#include "xla/ffi/api/c_api.h"
#include "xla/ffi/api/ffi.h"

namespace ffi = xla::ffi;

float ComputeRmsNorm(float eps, int64_t size, const float *x, float *y) {
  float sm = 0.0f;
  for (int64_t n = 0; n < size; ++n) {
    sm += x[n] * x[n];
  }
  float scale = 1.0f / std::sqrt(sm / float(size) + eps);
  for (int64_t n = 0; n < size; ++n) {
    y[n] = x[n] * scale;
  }
  return scale;
}

// A helper function for extracting the relevant dimensions from `ffi::Buffer`s.
// In this example, we treat all leading dimensions as batch dimensions, so this
// function returns the total number of elements in the buffer, and the size of
// the last dimension.
template <ffi::DataType T>
std::pair<int64_t, int64_t> GetDims(const ffi::Buffer<T> &buffer) {
  auto dims = buffer.dimensions();
  if (dims.size() == 0) {
    return std::make_pair(0, 0);
  }
  return std::make_pair(buffer.element_count(), dims.back());
}

// A wrapper function providing the interface between the XLA FFI call and our
// library function `ComputeRmsNorm` above. This function handles the batch
// dimensions by calling `ComputeRmsNorm` within a loop.
ffi::Error RmsNormImpl(float eps, ffi::Buffer<ffi::DataType::F32> x,
                       ffi::Result<ffi::Buffer<ffi::DataType::F32>> y) {
  auto [totalSize, lastDim] = GetDims(x);
  if (lastDim == 0) {
    return ffi::Error(ffi::ErrorCode::kInvalidArgument,
                      "RmsNorm input must be an array");
  }
  for (int64_t n = 0; n < totalSize; n += lastDim) {
    ComputeRmsNorm(eps, lastDim, &(x.typed_data()[n]), &(y->typed_data()[n]));
  }
  return ffi::Error::Success();
}

// Wrap `RmsNormImpl` and specify the interface to XLA. If you need to declare
// this handler in a header, you can use the `XLA_FFI_DECLARE_HANDLER_SYMBOL`
// macro: `XLA_FFI_DECLARE_HANDLER_SYMBOL(RmsNorm)`.
XLA_FFI_DEFINE_HANDLER_SYMBOL(
    RmsNorm, RmsNormImpl,
    ffi::Ffi::Bind()
        .Attr<float>("eps")
        .Arg<ffi::Buffer<ffi::DataType::F32>>()  // x
        .Ret<ffi::Buffer<ffi::DataType::F32>>()  // y
);

I was able to build librms_norm.so from this and install it.
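For reference, here is roughly how the build can be driven (just a sketch assuming a plain c++ toolchain; the packaged example linked above uses CMake instead, and include_dir() availability depends on the JAX version):

import subprocess
from jax.extend import ffi  # include_dir() is available in recent JAX releases

# Compile rms_norm.cc into a shared library. C++17 is needed for the
# structured bindings in RmsNormImpl, and the output goes into an ffi/
# directory so that the loading code below can find it.
subprocess.run(
    ["c++", "-shared", "-fPIC", "-O3", "-std=c++17",
     f"-I{ffi.include_dir()}",  # location of xla/ffi/api/c_api.h and ffi.h
     "rms_norm.cc", "-o", "ffi/librms_norm.so"],
    check=True,
)

Then the Python side, copied from the documentation: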

import ctypes
from pathlib import Path

import jax
import jax.extend as jex
import jax.numpy as jnp
import numpy as np

def rms_norm_ref(x, eps=1e-5):
  scale = jnp.sqrt(jnp.mean(jnp.square(x), axis=-1, keepdims=True) + eps)
  return x / scale

path = next(Path("ffi").glob("librms_norm*"))
rms_norm_lib = ctypes.cdll.LoadLibrary(path)
jex.ffi.register_ffi_target(
    "rms_norm", jex.ffi.pycapsule(rms_norm_lib.RmsNorm), platform="cpu")

def rms_norm(x, eps=1e-5):
  # We only implemented the `float32` version of this function, so we start by
  # checking the dtype. This check isn't strictly necessary because type
  # checking is also performed by the FFI when decoding input and output
  # buffers, but it can be useful to check types in Python to raise more
  # informative errors.
  if x.dtype != jnp.float32:
    raise ValueError("Only the float32 dtype is implemented by rms_norm")

  # In this case, the output of our FFI function is just a single array with the
  # same shape and dtype as the input. We discuss a case with a more interesting
  # output type below.
  out_type = jax.ShapeDtypeStruct(x.shape, x.dtype)

  return jex.ffi.ffi_call(
    # The target name must be the same string we used when registering the
    # target above with `register_ffi_target`
    "rms_norm",
    out_type,
    x,
    # Note that here we use `numpy` (not `jax.numpy`) to specify a dtype for
    # the attribute `eps`. Our FFI function expects this to have the C++ `float`
    # type (which corresponds to numpy's `float32` type), and it must be a
    # static parameter (i.e. not a JAX array).
    eps=np.float32(eps),
    # The `vmap_method` parameter controls this function's behavior under `vmap`
    # as discussed below.
    vmap_method="broadcast_fullrank",
  )

# Test that this gives the same result as our reference implementation
x = jnp.linspace(-0.5, 0.5, 15).reshape((3, 5))
np.testing.assert_allclose(rms_norm(x), rms_norm_ref(x), rtol=1e-5)

Hope this helps answer the issue. There are decorators on top of the handlers in the example you linked, so maybe the documentation has not been updated?

dfm commented 1 month ago

Oh I see what's happening here! The decorators are actually a red herring.

The issue is actually the vmap_method parameter. It was only added in JAX v0.4.34, so in the earlier version you're using, the vmap_method keyword is interpreted as an extra attribute to pass to the FFI handler. Since the handler is bound with exactly one attribute (eps), the runtime complains: "Wrong number of attributes: expected 1 but got 2". In JAX v0.4.33 you should use vectorized=True instead of vmap_method, although that behavior is deprecated going forward.
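Concretely, the end of your rms_norm function under v0.4.33 would look something like this (just a sketch of the changed call):

return jex.ffi.ffi_call(
    "rms_norm",
    out_type,
    x,
    eps=np.float32(eps),
    # vmap_method only exists from v0.4.34 on; on v0.4.33 use the older
    # vectorized flag instead (deprecated in later releases).
    vectorized=True,
)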

Hope this helps!

bridgesign commented 1 month ago

That was the issue. Thanks!