rdyro / torch2jax

Wraps PyTorch code in a JIT-compatible way for JAX. Supports automatically defining gradients for reverse-mode AutoDiff.
https://rdyro.github.io/torch2jax/
MIT License

torch2jax

Documentation


This package lets you call PyTorch code from JAX without copying data, both under eager execution and under JIT. It uses the JAX C++ extension interface and works on both CPU and GPU. Arbitrary PyTorch code can be executed this way.

The intended application is efficiently running existing PyTorch code (like ML models) in JAX applications with very low overhead.

This project was inspired by the jax2torch repository https://github.com/lucidrains/jax2torch and was made possible by an amazing tutorial on extending JAX https://github.com/dfm/extending-jax. The comprehensive JAX documentation https://github.com/google/jax also contributed significantly to this work.

I am not sure this functionality could be achieved without C++/CUDA, but the C++ code is compiled using PyTorch's portable C++ and CUDA compilation features, so it requires minimal configuration.

Install

$ pip install git+https://github.com/rdyro/torch2jax.git

torch2jax is also available on PyPI under the name wrap-torch2jax (imported as wrap_torch2jax):

$ pip install wrap-torch2jax
$ # then
$ python3
>>> from wrap_torch2jax import torch2jax, torch2jax_with_vjp

Tested on:

Usage

With a single output

import torch
import jax
from jax import numpy as jnp
from torch2jax import torch2jax  # converts a PyTorch function into a JAX-callable function
from torch2jax import Size, dtype_t2j  # Size is torch.Size (a tuple-like shape); dtype_t2j maps a torch dtype to a JAX dtype

def torch_fn(a, b):
    return a + b

shape = (10, 2)
a, b = torch.randn(shape), torch.randn(shape)
jax_fn = torch2jax(torch_fn, a, b)  # without output_shapes, torch_fn **will be evaluated once**
jax_fn = torch2jax(torch_fn, a, b, output_shapes=Size(a.shape))  # torch_fn will NOT be evaluated

# you can specify the whole input and output structure without instantiating the tensors
# torch_fn will NOT be evaluated
jax_fn = torch2jax(
    torch_fn,
    jax.ShapeDtypeStruct(a.shape, dtype_t2j(a.dtype)),
    jax.ShapeDtypeStruct(b.shape, dtype_t2j(b.dtype)),
    output_shapes=jax.ShapeDtypeStruct(a.shape, dtype_t2j(a.dtype)),
)

key_a, key_b = jax.random.split(jax.random.PRNGKey(0))  # separate keys so a and b differ
device = jax.devices("cuda")[0]  # both CPU and CUDA are supported
a = jax.device_put(jax.random.normal(key_a, shape), device)
b = jax.device_put(jax.random.normal(key_b, shape), device)

# call the no-copy torch function
out = jax_fn(a, b)

# call the no-copy torch function **under JIT**
out = jax.jit(jax_fn)(a, b)
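The comments on output_shapes come down to shape inference: if the output structure is not given, the wrapper has to run torch_fn once on the example inputs to discover it. A minimal sketch of that trade-off in plain NumPy, where infer_output_shapes and CountingFn are hypothetical names made up for this illustration (they are not part of torch2jax):

```python
import numpy as np


class CountingFn:
    """Stand-in for torch_fn that records how many times it is evaluated."""

    def __init__(self):
        self.calls = 0

    def __call__(self, a, b):
        self.calls += 1
        return a + b


def infer_output_shapes(fn, *example_args, output_shapes=None):
    """Hypothetical helper: return fn's output shape, evaluating fn only if needed."""
    if output_shapes is not None:
        return output_shapes  # shapes supplied by the caller, fn is never run
    return fn(*example_args).shape  # evaluate fn once on the example inputs


fn = CountingFn()
a, b = np.ones((10, 2)), np.ones((10, 2))
print(infer_output_shapes(fn, a, b), fn.calls)  # (10, 2) 1
print(infer_output_shapes(fn, a, b, output_shapes=(10, 2)), fn.calls)  # (10, 2) 1
```

Supplying the shapes up front avoids the extra evaluation, at the cost of keeping them in sync with torch_fn yourself.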

With multiple outputs

def torch_fn(a, b):
    layer = torch.nn.Linear(2, 20).to(a)
    return a + b, torch.norm(a), layer(a * b)

shape = (10, 2)
a, b = torch.randn(shape), torch.randn(shape)
jax_fn = torch2jax(torch_fn, a, b)  # with example arguments

key_a, key_b = jax.random.split(jax.random.PRNGKey(0))  # separate keys so a and b differ
device = jax.devices("cuda")[0]
a = jax.device_put(jax.random.normal(key_a, shape), device)
b = jax.device_put(jax.random.normal(key_b, shape), device)

# call the no-copy torch function
x, y, z = jax_fn(a, b)

# call the no-copy torch function **under JIT**
x, y, z = jax.jit(jax_fn)(a, b)
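If you want to skip the shape-inference evaluation for this function by specifying its outputs yourself, you need the shapes of all three outputs. A NumPy stand-in for torch_fn (written only for illustration; its random weights do not follow torch.nn.Linear's initialization) shows them: (10, 2) for a + b, a scalar () for the norm, and (10, 20) for the linear layer.

```python
import numpy as np


def numpy_stand_in(a, b, rng):
    """NumPy mirror of the multi-output torch_fn above (an approximation, not PyTorch)."""
    weight = rng.standard_normal((20, 2))  # same shape as torch.nn.Linear(2, 20).weight
    bias = rng.standard_normal(20)  # same shape as torch.nn.Linear(2, 20).bias
    # torch.norm(a) with default arguments is the 2-norm of all entries (Frobenius norm)
    return a + b, np.linalg.norm(a), (a * b) @ weight.T + bias


rng = np.random.default_rng(0)
a, b = rng.standard_normal((10, 2)), rng.standard_normal((10, 2))
x, y, z = numpy_stand_in(a, b, rng)
print(x.shape, np.shape(y), z.shape)  # (10, 2) () (10, 20)
```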

For a more advanced discussion of the different ways to specify the input/output structure of the wrapped function, see the input_output_specification.ipynb notebook in the examples folder.

Automatically defining gradients

Automatic reverse-mode gradient definitions are now supported for wrapped PyTorch functions via torch2jax_with_vjp:

import torch
import jax
from jax import numpy as jnp
import numpy as np

from torch2jax import torch2jax_with_vjp

def torch_fn(a, b):
    return torch.nn.MSELoss()(a, b)

shape = (6,)

xt, yt = torch.randn(shape), torch.randn(shape)

# `depth` determines how many times the function can be differentiated
jax_fn = torch2jax_with_vjp(torch_fn, xt, yt, depth=2) 

# we can now differentiate the function (derivatives are taken using PyTorch autodiff)
g_fn = jax.grad(jax_fn, argnums=(0, 1))
x, y = jnp.array(np.random.randn(*shape)), jnp.array(np.random.randn(*shape))

print(g_fn(x, y))

# JIT works too
print(jax.jit(g_fn)(x, y))
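Because torch.nn.MSELoss() defaults to reduction="mean", torch_fn computes L(a, b) = (1/n) Σᵢ (aᵢ - bᵢ)², so g_fn should return ∂L/∂a = 2(a - b)/n and ∂L/∂b = -2(a - b)/n (here n = 6). The gradient is linear in a, so the second derivative that depth=2 makes available is the constant (2/n)·I. A NumPy sketch, independent of both torch2jax and PyTorch, that checks these formulas against central finite differences:

```python
import numpy as np


def mse(a, b):
    # same quantity as torch.nn.MSELoss() with its default reduction="mean"
    return np.mean((a - b) ** 2)


def mse_grad(a, b):
    # analytic gradients of mean((a - b)**2) with respect to a and b
    ga = 2.0 * (a - b) / a.size
    return ga, -ga


def finite_diff_grad(f, x, eps=1e-6):
    # central differences, perturbing one coordinate at a time
    g = np.zeros_like(x)
    for i in range(x.size):
        step = np.zeros_like(x)
        step.flat[i] = eps
        g.flat[i] = (f(x + step) - f(x - step)) / (2 * eps)
    return g


rng = np.random.default_rng(0)
x, y = rng.standard_normal(6), rng.standard_normal(6)
gx, gy = mse_grad(x, y)
print(np.allclose(gx, finite_diff_grad(lambda t: mse(t, y), x)))  # True
print(np.allclose(gy, finite_diff_grad(lambda t: mse(x, t), y)))  # True
```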

Caveats:

Dealing with Changing Shapes

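You can handle changing input shapes by calling torch2jax (or torch2jax_with_vjp) inside the JAX function, both under JIT and eagerly. Under JIT, input shapes and dtypes are static during tracing, so they can be used to build the input and output specification: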

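import jax
from jax import numpy as jnp
from torch2jax import torch2jax_with_vjp, dtype_t2j

def torch_fn(a, b):
    return a + b

@jax.jit
def compute(a, b, c):
    d = torch2jax_with_vjp(
        torch_fn,
        jax.ShapeDtypeStruct(a.shape, dtype_t2j(a.dtype)),
        jax.ShapeDtypeStruct(b.shape, dtype_t2j(b.dtype)),
        output_shapes=jax.ShapeDtypeStruct(a.shape, dtype_t2j(a.dtype)),
    )(a, b)
    return d - c

# each new input shape triggers a retrace, which re-wraps torch_fn for that shape
for shape in [(10, 2), (5, 3)]:
    a = b = jnp.ones(shape)
    print(compute(a, b, a))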
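This works because jax.jit traces compute again whenever it sees a new input shape or dtype, and tracing runs the Python body, including the torch2jax_with_vjp call, with the new shapes. The toy model below (NumPy only; toy_jit and build are hypothetical names, and it ignores everything else jax.jit does) shows the caching behaviour: one trace per distinct (shape, dtype) signature.

```python
import numpy as np


def toy_jit(build):
    """Toy model of jax.jit caching (hypothetical, not JAX's implementation):
    build(signature) runs once per distinct (shape, dtype) signature."""
    cache = {}

    def call(*args):
        signature = tuple((x.shape, x.dtype) for x in args)
        if signature not in cache:
            # "trace": this is where Python code such as the torch2jax call would run
            cache[signature] = build(signature)
        return cache[signature](*args)

    call.cache = cache
    return call


traces = []


def build(signature):
    traces.append(signature)  # record every (re)trace
    return lambda a, b, c: a + b - c  # stands in for the wrapped torch_fn followed by "- c"


compute = toy_jit(build)
for shape in [(10, 2), (10, 2), (5, 3)]:
    a = np.ones(shape, dtype=np.float32)
    compute(a, a, a)
print(len(traces))  # 2: one trace per distinct shape
```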

Timing Comparison vs pure_callback

This package achieves much better performance than jax.pure_callback when calling PyTorch code from JAX, because it does not copy its input arguments and does not move CUDA data off the GPU.
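jax.pure_callback hands its callback NumPy arrays on the host, so GPU inputs have to be transferred off the device and results copied back, while torch2jax lets PyTorch operate on JAX's buffers directly through its C++ extension. The difference between handing over a copy and handing over the same memory can be illustrated in plain NumPy, using the DLPack protocol as a stand-in for the zero-copy path (an assumption for illustration only; this is not torch2jax's mechanism, and np.from_dlpack needs NumPy >= 1.22):

```python
import numpy as np

x = np.arange(1_000_000, dtype=np.float32)

copied = np.array(x, copy=True)  # copy-based hand-off: a new buffer, x.nbytes moved
shared = np.from_dlpack(x)  # zero-copy hand-off: the same memory, nothing moved

print(np.shares_memory(x, copied))  # False
print(np.shares_memory(x, shared))  # True
print(x.nbytes)  # 4000000 bytes that the copy path moves for this one input
```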

Current Limitations of torch2jax

Changelog

Roadmap

Related Work

Our Python package wraps PyTorch code as-is (so custom code and mutating code will work!), but if you're looking for an automatic way to transcribe a supported subset of PyTorch code to JAX, take a look at https://github.com/samuela/torch2jax/tree/main.

We realize that having two packages with the same name is not ideal. As a stop-gap while we work towards a solution, we offer a helper script that installs this package with pip under an alias name:

  1. $ git clone https://github.com/rdyro/torch2jax.git - clone this repo
  2. $ python3 install_package_aliased.py new_name_torch2jax --install --test - install and test this package under the name new_name_torch2jax
  3. you can now use this package under the name new_name_torch2jax