rdyro / torch2jax

Wraps PyTorch code in a JIT-compatible way for JAX. Supports automatically defining gradients for reverse-mode AutoDiff.
https://rdyro.github.io/torch2jax/
MIT License

New JAX Functionality For Calling "Foreign" Code #16

Closed: adam-hartshorne closed this issue 1 month ago

adam-hartshorne commented 1 month ago

I thought I would give you a heads-up that it appears JAX has released a new way of calling C++ / CUDA functions:

https://jax.readthedocs.io/en/latest/ffi.html

I don't know if that will make it better / easier to run PyTorch code going forward.

rdyro commented 1 month ago

Thanks for letting me know!

I'm particularly interested in whether this could give a performance boost if we can sync the stream less conservatively.

I'll keep this issue open until I have a chance to look into this.

adam-hartshorne commented 1 month ago

"Added jax.extend.ffi.ffi_call and jax.extend.ffi.ffi_lowering to support the use of the new ffi-tutorial to interface with custom C++ and CUDA code from JAX."

https://github.com/google/jax/releases/tag/jax-v0.4.32
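For reference, a minimal sketch of what the new call interface looks like, roughly following the linked tutorial. The `my_op` target name, the compiled handler it would resolve to, and the assumption that the output has the same shape/dtype as the input are all placeholders, and the exact signature may shift between releases:

```python
import numpy as np
import jax
import jax.numpy as jnp
import jax.extend as jex

def my_op(x, eps=1e-5):
    # ffi_call takes the registered target name, an output shape/dtype spec,
    # the array arguments, and static attributes as keyword arguments.
    # "my_op" is a placeholder; it must match a target registered beforehand
    # via jex.ffi.register_ffi_target (see the registration sketch below).
    return jex.ffi.ffi_call(
        "my_op",
        jax.ShapeDtypeStruct(x.shape, x.dtype),  # assumes output matches input
        x,
        eps=np.float32(eps),
    )

y = jax.jit(my_op)(jnp.ones((8, 64), dtype=jnp.float32))
```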

There are also mentions of things that I presume torch2jax uses, like mhlo and jax.dlpack.from_dlpack, being removed / deprecated.

rdyro commented 1 month ago

Ok, I had a look and we're still relying on the old custom call, but it shouldn't be deprecated for a while.

The dlpack usage now emits a warning; it's an easy fix, but I'd rather bundle it together with the new FFI port release.

I'll keep monitoring compatibility, thanks for the heads up!
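A minimal sketch of what that fix presumably looks like, assuming the warning comes from passing a DLPack capsule rather than the tensor itself (and assuming a CUDA device is available for the example):

```python
import torch
from torch.utils.dlpack import to_dlpack
import jax.dlpack

t = torch.arange(4, dtype=torch.float32, device="cuda")  # assumes a GPU

# Older pattern: wrap the tensor in a DLPack capsule first; recent JAX
# versions emit a deprecation warning for this.
x_old = jax.dlpack.from_dlpack(to_dlpack(t))

# Newer pattern: pass the tensor directly and let JAX consume its
# __dlpack__ protocol method.
x_new = jax.dlpack.from_dlpack(t)
```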

adam-hartshorne commented 1 month ago

Ok cool.

We have been playing with the new FFI functionality to call some custom CUDA kernels; so far, it all seems reasonably straightforward and works well.
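For context, the registration side of that workflow is roughly the sketch below, assuming a compiled extension module (the hypothetical `my_cuda_ops`) that exposes its XLA FFI handlers through a `registrations()` dict, following the convention used in the tutorial:

```python
import jax.extend as jex
import my_cuda_ops  # hypothetical compiled extension exposing XLA FFI handlers

# Assumption borrowed from the tutorial's convention: registrations() returns
# a dict mapping target names to PyCapsules wrapping the C++/CUDA handlers.
for name, target in my_cuda_ops.registrations().items():
    jex.ffi.register_ffi_target(name, target, platform="CUDA")
```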

rdyro commented 1 month ago

Yeah, I haven't tried it yet, but it's much more straightforward, which is very nice, especially for fast prototyping.