bridgesign closed this issue 1 month ago
Thanks for this report! That example is being run as part of CI so I expect there's something different in your setup. Can you share the exact C++ and Python code that you're running?
Edited to add: You might also consider checking out this version of that example which includes all the packaging details, etc.: https://github.com/jax-ml/jax/tree/main/examples/ffi
Thanks, the example helps! I tried to copy the parts from the documentation in order.
// rms_norm.cc
#include <cmath>
#include <cstdint>
#include <functional>
#include <numeric>
#include <utility>

#include "xla/ffi/api/c_api.h"
#include "xla/ffi/api/ffi.h"

namespace ffi = xla::ffi;

float ComputeRmsNorm(float eps, int64_t size, const float *x, float *y) {
  float sm = 0.0f;
  for (int64_t n = 0; n < size; ++n) {
    sm += x[n] * x[n];
  }
  float scale = 1.0f / std::sqrt(sm / float(size) + eps);
  for (int64_t n = 0; n < size; ++n) {
    y[n] = x[n] * scale;
  }
  return scale;
}

// A helper function for extracting the relevant dimensions from `ffi::Buffer`s.
// In this example, we treat all leading dimensions as batch dimensions, so this
// function returns the total number of elements in the buffer, and the size of
// the last dimension.
template <ffi::DataType T>
std::pair<int64_t, int64_t> GetDims(const ffi::Buffer<T> &buffer) {
  auto dims = buffer.dimensions();
  if (dims.size() == 0) {
    return std::make_pair(0, 0);
  }
  return std::make_pair(buffer.element_count(), dims.back());
}

// A wrapper function providing the interface between the XLA FFI call and our
// library function `ComputeRmsNorm` above. This function handles the batch
// dimensions by calling `ComputeRmsNorm` within a loop.
ffi::Error RmsNormImpl(float eps, ffi::Buffer<ffi::DataType::F32> x,
                       ffi::Result<ffi::Buffer<ffi::DataType::F32>> y) {
  auto [totalSize, lastDim] = GetDims(x);
  if (lastDim == 0) {
    return ffi::Error(ffi::ErrorCode::kInvalidArgument,
                      "RmsNorm input must be an array");
  }
  for (int64_t n = 0; n < totalSize; n += lastDim) {
    ComputeRmsNorm(eps, lastDim, &(x.typed_data()[n]), &(y->typed_data()[n]));
  }
  return ffi::Error::Success();
}

// Wrap `RmsNormImpl` and specify the interface to XLA. If you need to declare
// this handler in a header, you can use the `XLA_FFI_DECLARE_HANDLER_SYMBOL`
// macro: `XLA_FFI_DECLARE_HANDLER_SYMBOL(RmsNorm)`.
XLA_FFI_DEFINE_HANDLER_SYMBOL(
    RmsNorm, RmsNormImpl,
    ffi::Ffi::Bind()
        .Attr<float>("eps")
        .Arg<ffi::Buffer<ffi::DataType::F32>>()  // x
        .Ret<ffi::Buffer<ffi::DataType::F32>>()  // y
);
Was able to generate librms_norm.so and install it.
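(For reference, the XLA headers needed for that build ship with jaxlib and can be located from Python; this is just a sketch and assumes this JAX version exposes include_dir() under jax.extend.ffi.)

# Sketch: print the directory containing the XLA FFI headers so it can be
# passed to the compiler (e.g. as -I<dir>) when building librms_norm.so.
# Assumes include_dir() is available under jax.extend.ffi in this version.
from jax.extend import ffi
print(ffi.include_dir())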
import jax
import jax.numpy as jnp


def rms_norm_ref(x, eps=1e-5):
    scale = jnp.sqrt(jnp.mean(jnp.square(x), axis=-1, keepdims=True) + eps)
    return x / scale


import ctypes
from pathlib import Path

import jax.extend as jex

path = next(Path("ffi").glob("librms_norm*"))
rms_norm_lib = ctypes.cdll.LoadLibrary(path)

jex.ffi.register_ffi_target(
    "rms_norm", jex.ffi.pycapsule(rms_norm_lib.RmsNorm), platform="cpu")

import numpy as np


def rms_norm(x, eps=1e-5):
    # We only implemented the `float32` version of this function, so we start by
    # checking the dtype. This check isn't strictly necessary because type
    # checking is also performed by the FFI when decoding input and output
    # buffers, but it can be useful to check types in Python to raise more
    # informative errors.
    if x.dtype != jnp.float32:
        raise ValueError("Only the float32 dtype is implemented by rms_norm")

    # In this case, the output of our FFI function is just a single array with the
    # same shape and dtype as the input. We discuss a case with a more interesting
    # output type below.
    out_type = jax.ShapeDtypeStruct(x.shape, x.dtype)

    return jex.ffi.ffi_call(
        # The target name must be the same string as we used to register the target
        # above in `register_ffi_target`.
        "rms_norm",
        out_type,
        x,
        # Note that here we're using `numpy` (not `jax.numpy`) to specify a dtype for
        # the attribute `eps`. Our FFI function expects this to have the C++ `float`
        # type (which corresponds to numpy's `float32` type), and it must be a
        # static parameter (i.e. not a JAX array).
        eps=np.float32(eps),
        # The `vmap_method` parameter controls this function's behavior under `vmap`
        # as discussed below.
        vmap_method="broadcast_fullrank",
    )


# Test that this gives the same result as our reference implementation
x = jnp.linspace(-0.5, 0.5, 15).reshape((3, 5))
np.testing.assert_allclose(rms_norm(x), rms_norm_ref(x), rtol=1e-5)
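As an aside, the batching behavior that `vmap_method` controls can be checked the same way (a small sketch; the tolerance is arbitrary):

# Optional check of the batched path: vmapping the FFI wrapper over the leading
# axis should match vmapping the pure-JAX reference.
np.testing.assert_allclose(
    jax.vmap(rms_norm)(x), jax.vmap(rms_norm_ref)(x), rtol=1e-5)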
Hope this helps answer the issue. There are decorators on top of the example you provided, so maybe the documentation has not been updated?
Oh I see what's happening here! The decorators are a red herring.

The issue is actually the `vmap_method` parameter. That was only added in JAX v0.4.34, so in the earlier version that you're using, the `vmap_method` input is being interpreted as an attribute that you want to pass to the FFI handler. So, in JAX v0.4.33, you should use `vectorized=True` instead of `vmap_method`, although that behavior is deprecated going forward.
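Concretely, the wrapper above would end up looking something like this on v0.4.33 (just a sketch that reuses the imports from your snippet; the function name here is made up and everything else stays the same):

# Sketch of the same wrapper for JAX v0.4.33, where `vmap_method` does not
# exist yet and the older (now-deprecated) `vectorized` flag is used instead.
def rms_norm_v0433(x, eps=1e-5):
    if x.dtype != jnp.float32:
        raise ValueError("Only the float32 dtype is implemented by rms_norm")
    out_type = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jex.ffi.ffi_call(
        "rms_norm",
        out_type,
        x,
        eps=np.float32(eps),
        vectorized=True,  # replaces vmap_method="broadcast_fullrank" on v0.4.33
    )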
Hope this helps!
That was the issue. Thanks!
Description
I am trying to create a custom C extension for JAX and was trying out the example given in the documentation. But when I try to run the example, I get the following error:
System info (python version, jaxlib version, accelerator, etc.)
An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
jax:    0.4.33
jaxlib: 0.4.33
numpy:  2.1.1
python: 3.10.12 (main, Sep 11 2024, 15:47:36) [GCC 11.4.0]
jax.devices (1 total, 1 local): [CpuDevice(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='027b41159048', release='6.8.0-40-generic', version='#40~22.04.3-Ubuntu SMP PREEMPT_DYNAMIC Tue Jul 30 17:30:19 UTC 2', machine='x86_64')
$ nvidia-smi
Fri Oct  4 18:19:24 2024
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.183.01             Driver Version: 535.183.01   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA TITAN X (Pascal)        Off | 00000000:01:00.0  On |                  N/A |
| 33%   57C    P0             62W / 250W  |   3731MiB / 12288MiB |      7%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+

+---------------------------------------------------------------------------------------+
| Processes:                                                                             |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory  |
|        ID   ID                                                             Usage       |
|=========================================================================================|
+---------------------------------------------------------------------------------------+