This package enables no-copy calls to arbitrary PyTorch code from JAX, under both eager execution and JIT. It leverages the JAX C++ extension interface and works on both CPU and GPU platforms.
The intended application is efficiently running existing PyTorch code (like ML models) in JAX applications with very low overhead.
This project was inspired by the jax2torch repository https://github.com/lucidrains/jax2torch and was made possible by an excellent tutorial on extending JAX https://github.com/dfm/extending-jax. The comprehensive JAX documentation https://github.com/google/jax also contributed significantly to this work.
While this functionality likely cannot be achieved without a C++/CUDA extension, the C++ compilation is handled by PyTorch's portable CUDA & C++ compilation features, so minimal configuration is required.
Install directly from GitHub:
$ pip install git+https://github.com/rdyro/torch2jax.git
torch2jax is now available on PyPI under the alias wrap-torch2jax (imported as wrap_torch2jax):
$ pip install wrap-torch2jax
$ # then
$ python3
>>> from wrap_torch2jax import torch2jax, torch2jax_with_vjp
Tested on:
- Python 3.9, 3.10, 3.11, 3.12 & JAX versions 0.4.26, 0.4.27, 0.4.28, 0.4.29, 0.4.30, 0.4.31
- Python 3.9, 3.10, 3.11, 3.12 & JAX versions 0.4.30, 0.4.31
With a single output
import torch
import jax
from jax import numpy as jnp
from torch2jax import torch2jax # this converts a Python function to JAX
from torch2jax import Size, dtype_t2j # this is torch.Size, a tuple-like shape representation
def torch_fn(a, b):
return a + b
shape = (10, 2)
a, b = torch.randn(shape), torch.randn(shape)
jax_fn = torch2jax(torch_fn, a, b) # without output_shapes, torch_fn **will be evaluated once**
jax_fn = torch2jax(torch_fn, a, b, output_shapes=Size(a.shape)) # torch_fn will NOT be evaluated
# you can specify the whole input and output structure without instantiating the tensors
# torch_fn will NOT be evaluated
jax_fn = torch2jax(
torch_fn,
jax.ShapeDtypeStruct(a.shape, dtype_t2j(a.dtype)),
jax.ShapeDtypeStruct(b.shape, dtype_t2j(b.dtype)),
output_shapes=jax.ShapeDtypeStruct(a.shape, dtype_t2j(a.dtype)),
)
prngkey = jax.random.PRNGKey(0)
device = jax.devices("cuda")[0] # both CPU and CUDA are supported
a = jax.device_put(jax.random.normal(prngkey, shape), device)
b = jax.device_put(jax.random.normal(prngkey, shape), device)
# call the no-copy torch function
out = jax_fn(a, b)
# call the no-copy torch function **under JIT**
out = jax.jit(jax_fn)(a, b)
With multiple outputs
def torch_fn(a, b):
layer = torch.nn.Linear(2, 20).to(a)
return a + b, torch.norm(a), layer(a * b)
shape = (10, 2)
a, b = torch.randn(shape), torch.randn(shape)
jax_fn = torch2jax(torch_fn, a, b) # with example arguments
prngkey = jax.random.PRNGKey(0)
device = jax.devices("cuda")[0]
a = jax.device_put(jax.random.normal(prngkey, shape), device)
b = jax.device_put(jax.random.normal(prngkey, shape), device)
# call the no-copy torch function
x, y, z = jax_fn(a, b)
# call the no-copy torch function **under JIT**
x, y, z = jax.jit(jax_fn)(a, b)
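The intended use case, running an existing PyTorch model from JAX, follows the same pattern. The sketch below is illustrative only (the torch.nn.Sequential model, its sizes, and the variable names are not part of the package); it wraps a pre-built module with torch2jax using an example input so the output shape is inferred automatically:
# a minimal sketch: wrapping a pre-built PyTorch module (the model here is illustrative)
import torch
import jax
from torch2jax import torch2jax

model = torch.nn.Sequential(torch.nn.Linear(2, 32), torch.nn.Tanh(), torch.nn.Linear(32, 1))

def model_fn(x):
    # arbitrary PyTorch code may run here, including calls into a stateful module
    return model(x)

x_example = torch.randn(10, 2)
jax_model_fn = torch2jax(model_fn, x_example)  # evaluates model_fn once to infer the output shape

# keep the JAX inputs on the same device as the torch model (CPU in this sketch)
x = jax.device_put(jax.numpy.ones((10, 2)), jax.devices("cpu")[0])
y = jax.jit(jax_model_fn)(x)  # the torch model now runs inside a jitted JAX function
Gradients through the model can be added the same way using torch2jax_with_vjp, described below.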
For a more advanced discussion of the different ways to specify the input/output structure of the wrapped function, take a look at the input_output_specification.ipynb notebook in the examples folder.
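For instance, the multi-output function above can be wrapped without ever evaluating it; the minimal sketch below assumes output_shapes accepts the same tuple structure as the function's return value (the exact specification options are covered in the notebook):
# a minimal sketch: structured output specification, assuming output_shapes mirrors
# the (tuple) structure returned by torch_fn
import jax
import torch
from torch2jax import torch2jax, dtype_t2j

shape, dtype = (10, 2), dtype_t2j(torch.float32)
in_spec = jax.ShapeDtypeStruct(shape, dtype)

jax_fn = torch2jax(
    torch_fn,  # the multi-output function from the example above
    in_spec,   # a
    in_spec,   # b
    output_shapes=(
        jax.ShapeDtypeStruct(shape, dtype),     # a + b
        jax.ShapeDtypeStruct((), dtype),        # torch.norm(a), a scalar
        jax.ShapeDtypeStruct((10, 20), dtype),  # layer(a * b), 20 output features
    ),
)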
Automatic reverse-mode gradient definitions are now supported for wrapped PyTorch functions via torch2jax_with_vjp:
import torch
import jax
from jax import numpy as jnp
import numpy as np
from torch2jax import torch2jax_with_vjp
def torch_fn(a, b):
return torch.nn.MSELoss()(a, b)
shape = (6,)
xt, yt = torch.randn(shape), torch.randn(shape)
# `depth` determines how many times the function can be differentiated
jax_fn = torch2jax_with_vjp(torch_fn, xt, yt, depth=2)
# we can now differentiate the function (derivatives are taken using PyTorch autodiff)
g_fn = jax.grad(jax_fn, argnums=(0, 1))
x, y = jnp.array(np.random.randn(*shape)), jnp.array(np.random.randn(*shape))
print(g_fn(x, y))
# JIT works too
print(jax.jit(g_fn)(x, y))
Caveats:
- jax.hessian(f) will not work, since jax.hessian relies on forward-mode differentiation, which torch2jax does not define (only reverse-mode gradients via torch2jax_with_vjp are available); the same functionality can be achieved using jax.jacobian(jax.jacobian(f)), see the sketch after the example below
- the input shapes of a wrapped function are fixed; call torch2jax_with_vjp/torch2jax again if you need to alter them. You can deal with changing input shapes by calling torch2jax (and torch2jax_with_vjp) inside the JAX function, both under JIT and eagerly:
@jax.jit
def compute(a, b, c):
d = torch2jax_with_vjp(
torch_fn,
jax.ShapeDtypeStruct(a.shape, dtype_t2j(a.dtype)),
jax.ShapeDtypeStruct(b.shape, dtype_t2j(b.dtype)),
output_shapes=jax.ShapeDtypeStruct(a.shape, dtype_t2j(a.dtype)),
)(a, b)
return d - c
print(compute(a, b, a))
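To illustrate the first caveat, the following minimal sketch reuses jax_fn from the MSE example above (wrapped with depth=2 so it can be differentiated twice) and computes a Hessian with nested jax.jacobian calls instead of jax.hessian:
# a minimal sketch: second-order derivatives without jax.hessian
# reuses jax_fn = torch2jax_with_vjp(torch_fn, xt, yt, depth=2) and x, y from the MSE example
hess_fn = jax.jacobian(jax.jacobian(jax_fn, argnums=0), argnums=0)
print(hess_fn(x, y).shape)  # (6, 6) for the shape = (6,) inputs used above
# jax.hessian(jax_fn) would fail here because it needs forward-mode (JVP) rules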
Compared to wrapping PyTorch calls with jax.pure_callback, this package achieves much better performance when calling PyTorch code from JAX because it does not copy its input arguments and does not move CUDA data off the GPU.
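For reference, the pure_callback route looks roughly like the sketch below (this is not part of torch2jax; it only shows the host round trip and the copies that torch2jax avoids):
# a minimal sketch: calling PyTorch via jax.pure_callback, for comparison only
# (inputs are copied to the host as numpy arrays and back, unlike with torch2jax)
import numpy as np
import torch
import jax

def torch_fn(a, b):
    return a + b

def callback_fn(a, b):
    def host_fn(a_np, b_np):
        # numpy -> torch -> numpy round trip, with copies
        return torch_fn(torch.from_numpy(np.array(a_np)), torch.from_numpy(np.array(b_np))).numpy()

    out_spec = jax.ShapeDtypeStruct(a.shape, a.dtype)
    return jax.pure_callback(host_fn, out_spec, a, b)

a = jax.numpy.ones((10, 2))
b = jax.numpy.ones((10, 2))
out = jax.jit(callback_fn)(a, b)  # works under JIT, but pays for the host copies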
In torch2jax, shape representations (including those for the output_shapes= kw argument), for flexibility in input and output structure, must be wrapped in torch.Size or jax.ShapeDtypeStruct.
Changelog:
- version 0.4.11
- version 0.4.10: swapped .device() for .devices()
- no version change: added install_package_aliased.py to automatically install the package with a different name (to avoid a name conflict)
- version 0.4.7
- version 0.4.6
- version 0.4.5: torch2jax_with_vjp now automatically selects use_torch_vjp=False if the True option fails
- version 0.4.4: introduced the use_torch_vjp flag (defaulting to True) in torch2jax_with_vjp, which can be set to False to use the old torch.autograd.grad for taking gradients; it is the slower method, but it is more compatible
- version 0.4.3
- version 0.4.2: added examples/input_output_specification.ipynb showing how input/output structure can be specified
- version 0.4.1: fixed torch2jax_with_vjp, where nondiff arguments were erroneously memorized
- version 0.4.0: switched to torch.vmap, which makes jax.jacobian work; added examples/bert_from_jax.ipynb
- version 0.3.0: introduced torch2jax_with_vjp, which allows recursively defining reverse-mode gradients for the wrapped torch function that works in JAX both normally and under JIT
- version 0.2.0: torch2jax.compat.torch2jax
- version 0.1.2: torch.cuda.is_available()
- version 0.1.1
- version 0.1.0 (jax.vmap)
Our Python package wraps PyTorch code as-is (so custom code and mutating code will work!), but if you're looking for an automatic way to transcribe a supported subset of PyTorch code to JAX, take a look at https://github.com/samuela/torch2jax/tree/main.
We realize that two packages sharing the same name is not ideal. As a stop-gap while we work towards a better solution, we offer a helper script that installs our package via pip under a different (alias) name:
$ git clone https://github.com/rdyro/torch2jax.git  # clone this repo
$ python3 install_package_aliased.py new_name_torch2jax --install --test  # install and test this package under the name new_name_torch2jax
The package can then be imported as new_name_torch2jax.
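For example, assuming the alias chosen above (new_name_torch2jax is just the illustrative name from the command), the aliased install can be checked the same way as the PyPI alias:
$ python3
>>> # the API is the same as torch2jax's, only the module name changes
>>> from new_name_torch2jax import torch2jax, torch2jax_with_vjp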