google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Cannot execute custom operation with XLA typed FFI #22499

Open tchatow opened 1 month ago

tchatow commented 1 month ago

Description

I am attempting to add a custom operation using the typed (rather than untyped) XLA FFI API. However, when I try to use it I get a warning/error saying that the symbol cannot be found:

simple_orc_jit.cc:433] Unable to resolve runtime symbol: 'my_func'. Hint: if the symbol a custom call target, make sure you've registered it with the JIT using XLA_CPU_REGISTER_CUSTOM_CALL_TARGET.
JIT session error: Symbols not found: [ my_func ]

Unlike in the examples, my custom operation uses the typed interface, like so:

xla::ffi::Error my_func(
   xla::ffi::Buffer<xla::ffi::DataType::F32> buf0,
   xla::ffi::Buffer<xla::ffi::DataType::S32> buf1, 
   xla::ffi::Result<xla::ffi::Buffer<xla::ffi::DataType::F32>> out) { ... }

XLA_FFI_DEFINE_HANDLER_SYMBOL(
    my_func_handler, my_func,
    xla::ffi::Ffi::Bind()
        .Arg<xla::ffi::Buffer<xla::ffi::DataType::F32>>()
        .Arg<xla::ffi::Buffer<xla::ffi::DataType::S32>>()
        .Ret<xla::ffi::Buffer<xla::ffi::DataType::F32>>());

In another C++ file:

PYBIND11_MODULE(my_func_mod, m) {
  m.def("my_func_registrations", []() {
    pybind11::dict dict;
    dict["my_func"] = pybind11::capsule(reinterpret_cast<void*>(my_func_handler));
    return dict;
  });
}

In Python (some boilerplate hidden):

for _name, _value in my_func_mod.my_func_registrations().items():
    xla_client.register_custom_call_target(_name, {"execute": _value}, platform="cpu", api_version=1)

def _my_func_lowering(ctx, a, b):
    a_type = ir.RankedTensorType(a.type)
    b_type = ir.RankedTensorType(b.type)

    return custom_call(
        'my_func',
        result_types=[a_type],
        operands=[a, b],
        operand_layouts=default_layouts(a_type.shape, b_type.shape),
        result_layouts=default_layouts(a_type.shape),
        api_version=xla_client.ops.CustomCallApiVersion.API_VERSION_TYPED_FFI
    ).results

Reading the XLA source suggests that SimpleOrcJIT looks up custom symbols in xla::CustomCallTargetRegistry, which is only populated for api_version == 0 in PyRegisterCustomCallTarget.
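
To make the hypothesis above concrete, here is a toy Python model of it. This is only a sketch: `legacy_registry`, `ffi_registry`, and `resolve_legacy_symbol` are hypothetical stand-ins invented for illustration, not XLA or jaxlib code, and the thread does not confirm that this is what actually happens. It assumes that targets registered with a non-zero `api_version` land in a separate registry that the legacy symbol lookup never consults.

```python
# Toy model only: these registries and functions are hypothetical stand-ins,
# not XLA's or jaxlib's actual data structures.
legacy_registry = {}  # stands in for xla::CustomCallTargetRegistry
ffi_registry = {}     # assumed separate home for typed FFI handlers


def register_custom_call_target(name, fn, api_version=0):
    """Store the target in the registry matching its API version."""
    if api_version == 0:
        legacy_registry[name] = fn
    else:
        ffi_registry[name] = fn


def resolve_legacy_symbol(name):
    """Mimic a lookup that only consults the legacy registry."""
    try:
        return legacy_registry[name]
    except KeyError:
        raise LookupError(f"Symbols not found: [ {name} ]") from None


# Register the handler the way the issue does: api_version=1.
register_custom_call_target("my_func", lambda: "typed handler", api_version=1)

try:
    resolve_legacy_symbol("my_func")
    outcome = "resolved"
except LookupError as err:
    outcome = str(err)

print(outcome)  # prints "Symbols not found: [ my_func ]"
```

In this model the handler is registered successfully, but the lookup still fails because it goes through the legacy path. That would match the error shown above if the compiled call were being treated as an untyped custom call.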

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.30
jaxlib: 0.4.30
numpy:  2.0.0

dfm commented 1 month ago

Thanks for the report. This is absolutely expected to work! I'm currently writing up a tutorial about this (see https://github.com/google/jax/pull/22095) and I haven't had any issues running ~equivalent code on CPU with all the same versions of JAX that you're using.

I'll see if I can dig into this a bit, but how about you take a look at the draft of my tutorial and see if you notice anything useful there. Everything you've included as sample code here looks right to me on a first pass!

tchatow commented 1 month ago

Also, I confirmed that the untyped FFI works correctly with the same registration code after changing the API version. Do I need XLA_FFI_REGISTER_HANDLER as well?

I took a look at the tutorial and it uses jax.extend.ffi.ffi_call. I can't find this in my distribution - is it new?

dfm commented 1 month ago

You shouldn't need to call XLA_FFI_REGISTER_HANDLER. The references to that macro in jaxlib are unused - they're only for users who use jaxlib without the Python frontend. The only way to register user FFI calls is via Python.

> I took a look at the tutorial and it uses jax.extend.ffi.ffi_call. I can't find this in my distribution - is it new?

I thought it got into 0.4.30, but it looks like I actually missed that cutoff (here's the PR: https://github.com/google/jax/pull/21925). It will be included in the next release!

dfm commented 1 month ago

I wanted to check in to confirm that you were able to get the code working, @tchatow. Is that right? I'm still not totally sure what the problem was before, but if you have something that works for now, let's close the issue and you can open a new one if you run into more problems. These APIs might still have some rough edges and I'm keen to get them sorted out!

tchatow commented 1 month ago

No, I wasn't able to get it working. I switched to the untyped API, which works fine. Once the next release is out, I can test it again with ffi_call.

dfm commented 1 month ago

Oh, I see. I don't expect that ffi_call will solve this issue, though. Can you put together a complete minimal example (i.e. the full C++ and Python files, plus how you're compiling them) so that I can reproduce the issue locally?