tchatow opened this issue 2 months ago
Thanks for the report. This is absolutely expected to work! I'm currently writing up a tutorial about this (see https://github.com/google/jax/pull/22095), and I haven't had any issues running roughly equivalent code on CPU with all the same versions of JAX that you're using.
I'll see if I can dig into this a bit, but how about you take a look at the draft of my tutorial and see if you notice anything useful there. Everything you've included as sample code here looks right to me on a first pass!
Also, I confirmed that the untyped FFI works correctly using the same registration code and only changing the API version. Do I need `XLA_FFI_REGISTER_HANDLER` as well?
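The reporter's hypothesis (spelled out at the end of the issue description) is that the two API versions land in different registries and the CPU JIT only consults one of them. Below is a toy Python model of that idea, purely for illustration: `register_target`, `jit_lookup`, `legacy_registry` and `ffi_registry` are all hypothetical names, not XLA or JAX APIs. The maintainer's replies suggest the real lookup path does handle typed handlers, so treat this as a picture of the suspected failure, not of how XLA actually works.

```python
from typing import Callable, Dict, Tuple

# Toy model of the hypothesis in the issue description. NOT XLA's real code:
# every name here is hypothetical and chosen for illustration only.

Key = Tuple[str, str]  # (target name, platform)

# Stand-in for xla::CustomCallTargetRegistry (legacy, api_version == 0).
legacy_registry: Dict[Key, Callable[..., None]] = {}
# Stand-in for a separate registry of typed FFI handlers (api_version == 1).
ffi_registry: Dict[Key, Callable[..., None]] = {}


def register_target(name: str, handler: Callable[..., None],
                    platform: str = "cpu", api_version: int = 0) -> None:
    """Route a registration by api_version, as the reporter suspects happens."""
    if api_version == 0:
        legacy_registry[(name, platform)] = handler
    elif api_version == 1:
        ffi_registry[(name, platform)] = handler
    else:
        raise ValueError(f"unsupported api_version: {api_version}")


def jit_lookup(name: str, platform: str = "cpu") -> Callable[..., None]:
    """Model a JIT symbol resolver that only consults the legacy registry."""
    try:
        return legacy_registry[(name, platform)]
    except KeyError:
        raise LookupError(f"symbol not found: {name}") from None


if __name__ == "__main__":
    # Identical registration code; only api_version differs.
    register_target("untyped_op", lambda: None, api_version=0)
    register_target("typed_op", lambda: None, api_version=1)
    print(jit_lookup("untyped_op") is not None)  # True
    try:
        jit_lookup("typed_op")
    except LookupError as err:
        print(err)  # symbol not found: typed_op
```

In this model the typed handler is stored successfully but never consulted by the resolver, which would match the "symbol cannot be found" warning in the report.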
I took a look at the tutorial and it uses `jax.extend.ffi.ffi_call`. I can't find this in my distribution; is it new?
You shouldn't need to call `XLA_FFI_REGISTER_HANDLER`. The references to that macro in `jaxlib` are unused; they're only for users who use `jaxlib` without the Python frontend. The only way to register user FFI calls is via Python.
> I took a look at the tutorial and it uses `jax.extend.ffi.ffi_call`. I can't find this in my distribution; is it new?
I thought it got into 0.4.30, but it looks like I actually missed that cutoff (here's the PR: https://github.com/google/jax/pull/21925). It will be included in the next release!
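If you're unsure whether your installed JAX already includes `ffi_call` (it missed the 0.4.30 cutoff, per the reply above), a quick feature check like the one below can save some guesswork. This is a minimal sketch: `ffi_call_available` is a hypothetical helper name, and it only looks in `jax.extend.ffi`, the location named in this thread, so it may report `False` on releases that expose the function somewhere else.

```python
import importlib


def ffi_call_available() -> bool:
    """Return True if the installed JAX exposes jax.extend.ffi.ffi_call."""
    try:
        ffi = importlib.import_module("jax.extend.ffi")
    except ImportError:
        # JAX is not installed, or this version has no jax.extend.ffi module.
        return False
    return hasattr(ffi, "ffi_call")


if __name__ == "__main__":
    print("ffi_call available:", ffi_call_available())
```

Checking for the attribute rather than comparing version strings keeps the check correct for dev builds installed from source, which may already contain the linked PR.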
I wanted to check in to confirm that you were able to get the code working, @tchatow. Is that right? I'm still not totally sure what the problem was before, but if you have something that works for now, let's close the issue and you can open a new one if you run into more problems. These APIs might still have some rough edges and I'm keen to get them sorted out!
No, I wasn't able to get it working. I switched to the untyped API, which works fine. Once the next release is out, I can test it again with `ffi_call`.
Oh, I see. I don't expect that `ffi_call` will solve this issue, though. Can you put together a complete minimal example (i.e. the full C++ and Python files, plus how you're compiling them) so that I can reproduce the issue locally?
Description
I am attempting to add a custom operation using the typed (rather than untyped) XLA FFI API. However, when I try to use it, I get a warning/error that the symbol cannot be found:
Unlike in the examples, my custom operation uses the typed interface, like:
In another C++ file:
In Python (some boilerplate hidden):
Reading the XLA source suggests that `SimpleOrcJIT` looks up custom symbols in `xla::CustomCallTargetRegistry`, which is only populated for `api_version == 0` in `PyRegisterCustomCallTarget`.

System info (python version, jaxlib version, accelerator, etc.)