Thanks for letting me know!
I'm particularly interested in whether this could give a performance boost if we can synchronize the CUDA stream less conservatively.
I'll keep this issue open until I have a chance to look into this.
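One direction this could take (a sketch of the idea, not what torch2jax currently does): if the FFI handler forwarded the raw CUDA stream handle to Python, the PyTorch op could run on JAX's own stream and only that stream would need synchronizing, instead of a device-wide `torch.cuda.synchronize()`. The `stream_ptr` argument and the helper below are hypothetical:

```python
import torch

def run_torch_op_on_jax_stream(stream_ptr, fn, *tensors):
    # Hypothetical helper: `stream_ptr` would be the raw cudaStream_t
    # handle that the XLA FFI handler forwards from the C++ side.
    jax_stream = torch.cuda.ExternalStream(stream_ptr)
    with torch.cuda.stream(jax_stream):
        out = fn(*tensors)
    # Block only on this stream, rather than torch.cuda.synchronize(),
    # which waits on every stream on the device.
    jax_stream.synchronize()
    return out
```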
"Added jax.extend.ffi.ffi_call and jax.extend.ffi.ffi_lowering to support the use of the new ffi-tutorial to interface with custom C++ and CUDA code from JAX."
https://github.com/google/jax/releases/tag/jax-v0.4.32
There are also mentions of things that I presume torch2jax uses, like mhlo and jax.dlpack.from_dlpack, being removed / deprecated.
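For reference, the new API from that release note looks roughly like this (a sketch following the linked FFI tutorial; the `"rms_norm"` target name and the compiled `rms_norm_lib` extension module are assumptions for illustration, not part of torch2jax):

```python
import numpy as np
import jax
import jax.extend as jex

# Assumed: a compiled C++/CUDA extension exposing its handler as a capsule.
import rms_norm_lib

# Register the handler with XLA under a target name.
jex.ffi.register_ffi_target("rms_norm", rms_norm_lib.rms_norm(), platform="gpu")

def rms_norm(x, eps=1e-5):
    # ffi_call needs the output shapes/dtypes up front; attributes are
    # passed as keyword arguments and must carry explicit dtypes.
    out_type = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jex.ffi.ffi_call("rms_norm", out_type, x, eps=np.float32(eps))
```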
Ok, I had a look and we're still relying on the old custom call, but it shouldn't be deprecated for a while.
The dlpack usage now only emits a deprecation warning; it's an easy fix, but I'd rather bundle it together with the new FFI port release.
I'll keep monitoring compatibility, thanks for the heads up!
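For context, the dlpack warning is presumably the deprecation of passing explicit DLPack capsules; the fix is to hand the tensor itself to `jax.dlpack.from_dlpack` (a sketch, assuming that's the deprecation torch2jax hits):

```python
import torch
import jax.dlpack

t = torch.arange(4, device="cuda")

# Old style: converting via an explicit DLPack capsule now emits a
# deprecation warning.
x = jax.dlpack.from_dlpack(torch.utils.dlpack.to_dlpack(t))

# New style: pass the tensor directly; it implements __dlpack__.
x = jax.dlpack.from_dlpack(t)
```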
Ok cool.
We have been playing with the new FFI functionality to call some custom CUDA kernels; so far, it all seems reasonably straightforward and works well.
Yeah, I haven't tried it yet, but it's much more straightforward, which is very nice for fast prototyping especially.
I thought I would mark your card that JAX has released a new way of calling C++ / CUDA functions:
https://jax.readthedocs.io/en/latest/ffi.html
I don't know whether that will make it better / easier to run PyTorch code from JAX going forward.