rdyro / torch2jax

Wraps PyTorch code in a JIT-compatible way for JAX. Supports automatically defining gradients for reverse-mode AutoDiff.
https://rdyro.github.io/torch2jax/
MIT License

Great Work - Question about Gradients. #1

Closed adam-hartshorne closed 1 year ago

adam-hartshorne commented 1 year ago

This looks very cool. Great work. Once some more of the Roadmap has been ticked off, I can see this potentially getting a lot of use from the JAX community.

One question: aside from the current lack of vmap support (as mentioned in the Roadmap), looking through the code, am I correct in thinking that gradients via forward/backward autodiff (jvp/vjp) aren't yet supported, i.e. that I can't include the wrapped PyTorch function as part of an optimisation?

Also, have you run any tests looking into whether, and how much of, an overhead penalty there is?

rdyro commented 1 year ago

Great, I'm glad it might be found useful!

Yes, gradients are not defined yet, but defining them would be done via custom_vjp in Python, as done in https://github.com/rdyro/jfi-JAXFriendlyInterface. Porting this requires a little interface alignment, so I'll get to it a bit later (hopefully this week).
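
For reference, here's a minimal, generic sketch of the custom_vjp mechanism with pure-JAX stand-ins (this is not torch2jax's or jfi's actual code): the forward pass and the VJP are supplied as two separate callables, which is exactly where a PyTorch forward and backward call could be plugged in.

import jax
import jax.numpy as jnp

@jax.custom_vjp
def f(a, b):
    # forward pass; in torch2jax's case this would call into PyTorch
    return jnp.sum((a - b) ** 2)

def f_fwd(a, b):
    # run the forward pass and save the residuals needed for the backward pass
    return f(a, b), (a, b)

def f_bwd(res, g):
    # return one cotangent per primal input
    a, b = res
    return (2.0 * g * (a - b), -2.0 * g * (a - b))

f.defvjp(f_fwd, f_bwd)

print(jax.grad(f)(jnp.ones(3), jnp.zeros(3)))  # gradient w.r.t. a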

I'm currently split on supporting vmap, since calling a PyTorch function repeatedly is probably not a good idea (as opposed to calling a batched PyTorch function), but https://github.com/dfm/extending-jax/blob/main/src/kepler_jax/kepler_jax.py contains an example of how to do it, so I might also get to it later.
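
For illustration, a small standalone PyTorch sketch of the batched alternative (not torch2jax code, and it assumes a PyTorch version that provides torch.vmap); this is the kind of call an efficient batching rule would dispatch to instead of a Python loop:

import torch

def torch_fn(a, b):
    return a + b

# map over a new leading batch dimension instead of calling torch_fn in a Python loop
batched_fn = torch.vmap(torch_fn)
a, b = torch.randn(8, 3), torch.randn(8, 3)
out = batched_fn(a, b)  # shape (8, 3)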

I added an overhead plot (in the README) for a function with the signature:

def torch_fn(a, b):
    return a + b

take a look!
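
If you want a rough local measurement yourself, a sketch like the following should work (this is an assumed approach, not the README's benchmark script; the torch2jax call pattern follows the example-arguments usage shown elsewhere in this thread):

import time
import torch
import jax.numpy as jnp
from torch2jax import torch2jax, Size

def torch_fn(a, b):
    return a + b

at, bt = torch.zeros(1000), torch.zeros(1000)
jax_fn = torch2jax(torch_fn, at, bt, output_shapes=Size(at.shape))  # example args fix shapes

x, y = jnp.ones(1000), jnp.ones(1000)
jax_fn(x, y).block_until_ready()  # warm-up call
t0 = time.perf_counter()
for _ in range(100):
    out = jax_fn(x, y)
out.block_until_ready()  # account for asynchronous dispatch
print(f"mean wrapped-call time: {(time.perf_counter() - t0) / 100:.2e} s")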

rdyro commented 1 year ago

I added gradients (via vjp)!

I also ended up supporting batching rules efficiently using torch.vmap. There's also support for mixed input/output types (e.g., language models have integer inputs, which torch2jax previously did not support).

Check out version 0.4.0 now; an example wrapping a BERT model is in examples/bert_from_jax.py.

I'll add more robust overhead testing soon.
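
Here's a minimal sketch of what this enables (the only API details taken from this thread are the name torch2jax_with_vjp and the example-argument calling pattern; everything else is an assumption):

import torch
import jax
import jax.numpy as jnp
from torch2jax import torch2jax_with_vjp

def torch_loss(a, b):
    return torch.nn.functional.mse_loss(a, b)

at, bt = torch.randn(5, 3), torch.randn(5, 3)
jax_loss = torch2jax_with_vjp(torch_loss, at, bt)  # example args fix shapes and dtypes

x, y = jnp.ones((5, 3)), jnp.zeros((5, 3))
g = jax.grad(jax_loss)(x, y)  # reverse-mode gradient w.r.t. x, through the wrapped torch code
print(g.shape)  # (5, 3)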

adam-hartshorne commented 1 year ago

That's awesome. I am a bit tied up for the next few days, but then I will definitely test it out.

Of particular interest to me (and, I believe, to quite a lot of other people), and the reason I started exploring calling PyTorch from JAX, is interacting with the following library:

https://github.com/getkeops/keops

However, having quickly looked at the torch2jax source code, I am starting to think a dedicated custom operator is probably the way forward (although tutorials on this are very limited compared to those for writing custom ops in PyTorch or TensorFlow).

rdyro commented 1 year ago

Nice!

At this point in development, torch2jax and torch2jax_with_vjp should be able to handle quite general operations, including ones from third-party libraries.

I've tested both BERT and ResNet architectures with it, and both the forward and backward passes have acceptable performance. I have yet to publish these tests.

Using the examples from https://github.com/dfm/extending-jax and this torch2jax repo should make defining your own op more straightforward (although, as you said, the available tutorials are nowhere near as developed as those for PyTorch or TensorFlow).

torch2jax should give you a frictionless way of calling this third-party library for quick development testing!

adam-hartshorne commented 1 year ago

I finally found time to test out your fantastic work.

My first attempt, with the code attached below, worked perfectly straight out of the box. It uses the KeOps library, which builds these special lazy-tensor functions into optimised CUDA code that meshes with PyTorch. I still need to test this inside an optimisation loop, i.e. with gradients getting called.

import torch
import jax
import jax.random as jrandom
from torch2jax import torch2jax, torch2jax_with_vjp, Size
from pykeops.torch import LazyTensor

def test_keops(x, y):
    """(N, D), (M, D) -> (N, 1): Gaussian-kernel row sums via KeOps lazy tensors."""
    x_i = LazyTensor(x[:, None, :])       # (N, 1, D)
    y_j = LazyTensor(y[None, :, :])       # (1, M, D)
    D_ij = ((x_i - y_j) ** 2).sum(dim=2)  # squared distances, lazily (N, M)
    K_ij = (-D_ij).exp()                  # Gaussian kernel values
    a_i = K_ij.sum(dim=1)                 # reduce over M -> (N, 1)
    return a_i

shape_a = (100000, 3)
shape_b = (200000, 3)
a, b = torch.zeros(shape_a), torch.zeros(shape_b)
jax_fn = torch2jax(test_keops, a, b, output_shapes=Size((a.shape[0], 1)))  # with example arguments

prngkey = jrandom.PRNGKey(0)
a = jrandom.normal(prngkey, shape_a)
b = jrandom.normal(prngkey, shape_b)

print(a.device_buffer.device())

dist = jax_fn(a, b)
print(dist)

dist = jax.jit(jax_fn)(a, b)
print(dist)
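
For the follow-up gradient test, a hedged, untested sketch continuing the snippet above (it assumes torch2jax_with_vjp can infer the output shape from the example arguments):

grad_wrapped = torch2jax_with_vjp(test_keops, torch.zeros(shape_a), torch.zeros(shape_b))

def loss(a, b):
    return jax.numpy.sum(grad_wrapped(a, b))

da = jax.grad(loss)(a, b)  # gradient w.r.t. a, using the JAX arrays defined above
print(da.shape)            # expected (100000, 3)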

rdyro commented 1 year ago

That's super cool, I'm glad it just worked!

Performance-wise, the call into torch effectively issues a torch.cuda.synchronize() at the end of the call (to ensure the data is synchronized), whereas PyTorch's eager-mode efficiency normally comes from batching chains of CUDA calls.

I'm not sure if this is applicable to your application, but if you have, e.g., PyTorch training loops that you're calling from JAX, unrolling the PyTorch loop (to, say, 10 iterations, i.e. just having for i in range(10): inside your PyTorch function) should speed things up.
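
As a rough illustration of that suggestion (the function and shapes here are made up; the output_shapes pattern follows the earlier snippet), the unrolled loop lives entirely inside the wrapped PyTorch function, so a single torch2jax call queues ten steps' worth of CUDA work:

import torch
from torch2jax import torch2jax, Size

def ten_steps(x):
    # unrolled inner loop: many CUDA ops are issued per single torch2jax call
    for _ in range(10):
        x = x - 0.1 * torch.sin(x)
    return x

xt = torch.zeros(1024)
jax_ten_steps = torch2jax(ten_steps, xt, output_shapes=Size(xt.shape))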