adam-hartshorne closed this issue 1 year ago.
Great, I'm glad it might be found useful!
Yes, gradients are not defined yet, but defining them would be done via `custom_vjp` in Python, as in https://github.com/rdyro/jfi-JAXFriendlyInterface. Porting this requires a little interface alignment, so I'll get to it a little later (hopefully this week).
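A minimal sketch of the `custom_vjp` idea, assuming the external (non-JAX) code is reached through `jax.pure_callback`. `host_fn` and `host_fn_vjp` are hypothetical NumPy stand-ins for a PyTorch forward pass and its backward pass; this is not necessarily how torch2jax or jfi wire things up internally.

```python
import jax
import jax.numpy as jnp
import numpy as np


def host_fn(x):
    # Hypothetical stand-in for a PyTorch forward pass, run outside of JAX.
    return 2.0 * np.sin(np.asarray(x))


def host_fn_vjp(x, g):
    # Hypothetical stand-in for the PyTorch backward pass:
    # d/dx [2 sin(x)] = 2 cos(x), contracted with the incoming cotangent g.
    return 2.0 * np.cos(np.asarray(x)) * np.asarray(g)


@jax.custom_vjp
def wrapped(x):
    out_spec = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(host_fn, out_spec, x)


def wrapped_fwd(x):
    # Forward pass: compute the output and save x as the residual for backward.
    return wrapped(x), x


def wrapped_bwd(x, g):
    # Backward pass: delegate the vector-Jacobian product to the host function.
    grad_spec = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return (jax.pure_callback(host_fn_vjp, grad_spec, x, g),)


wrapped.defvjp(wrapped_fwd, wrapped_bwd)
```

With the rule registered, `jax.grad(lambda x: wrapped(x).sum())(x)` returns `2 cos(x)`, also under `jax.jit`, even though the host function itself is opaque to JAX.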
I'm currently split on supporting vmap, since calling a PyTorch function repeatedly is probably not a good idea (as opposed to calling a batched PyTorch function), but https://github.com/dfm/extending-jax/blob/main/src/kepler_jax/kepler_jax.py contains an example of how to do it, so I might get to it later.
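One way to avoid calling the wrapped function once per batch element is `jax.custom_batching.custom_vmap`, whose vmap rule can make a single call on the whole stacked batch. A sketch, assuming a hypothetical NumPy function `host_batched_add` that stands in for a batched PyTorch function; the dfm/extending-jax example linked above instead registers a batching rule on a custom primitive.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.custom_batching import custom_vmap


def host_batched_add(a, b):
    # Hypothetical stand-in for a *batched* PyTorch function: it broadcasts
    # over leading dimensions, so one call can process an entire batch.
    return np.asarray(np.asarray(a) + np.asarray(b))


def _call_host(a, b):
    out_spec = jax.ShapeDtypeStruct(jnp.broadcast_shapes(a.shape, b.shape), a.dtype)
    return jax.pure_callback(host_batched_add, out_spec, a, b)


@custom_vmap
def add(a, b):
    # Unbatched path: one host call for one example.
    return _call_host(a, b)


@add.def_vmap
def add_vmap_rule(axis_size, in_batched, a, b):
    a_batched, b_batched = in_batched
    # Give unbatched arguments an explicit leading batch axis so the host
    # function sees consistent shapes, then make ONE call for the whole batch.
    if not a_batched:
        a = jnp.broadcast_to(a, (axis_size,) + a.shape)
    if not b_batched:
        b = jnp.broadcast_to(b, (axis_size,) + b.shape)
    return _call_host(a, b), True  # the output carries the batch axis
```

For example, `jax.vmap(add, in_axes=(0, None))(a, b)` with `a` of shape `(4, 3)` and `b` of shape `(3,)` crosses into the host once, not four times.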
I added an overhead plot (in the README) for a function with the signature:

```python
def torch_fn(a, b):
    return a + b
```

Take a look!
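A rough way to measure per-call overhead for the same `a + b` signature is to compare a jitted native JAX add with the same add routed through `jax.pure_callback` to host NumPy code. This is a sketch under loud assumptions: it measures JAX's generic host-callback path, not torch2jax's own mechanism, it does not reproduce the README plot, and the numbers depend entirely on the hardware.

```python
import time

import jax
import jax.numpy as jnp
import numpy as np


def host_add(a, b):
    # Stand-in for torch_fn(a, b) = a + b, executed outside of XLA.
    return np.asarray(a) + np.asarray(b)


@jax.jit
def native_add(a, b):
    return a + b


@jax.jit
def callback_add(a, b):
    spec = jax.ShapeDtypeStruct(a.shape, a.dtype)
    return jax.pure_callback(host_add, spec, a, b)


def time_per_call(fn, *args, repeats=100):
    fn(*args).block_until_ready()  # compile and warm up outside the timed loop
    start = time.perf_counter()
    for _ in range(repeats):
        fn(*args).block_until_ready()  # wait for the result before timing the next call
    return (time.perf_counter() - start) / repeats


a = jnp.ones((1000,), dtype=jnp.float32)
b = jnp.ones((1000,), dtype=jnp.float32)
t_native = time_per_call(native_add, a, b)
t_callback = time_per_call(callback_add, a, b)
print(f"native: {t_native * 1e6:.1f} us/call, callback: {t_callback * 1e6:.1f} us/call")
```

The difference between the two timings approximates the fixed cost of leaving and re-entering XLA for a trivially cheap function, which is what an overhead plot for `a + b` isolates.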
I added gradients (via vjp)!
I also ended up supporting batching rules efficiently using `torch.vmap`. There's also support for mixed input/output types (e.g., language models have integer inputs, which were previously not supported by `torch2jax`).
Check out version 0.4.0 now; an example wrapping a BERT model is in examples/bert_from_jax.py.
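To illustrate what mixed input/output types mean at the JAX boundary, here is a sketch in which integer token ids go in and floating-point embeddings come out; the output dtype is declared independently of the input dtype in the `jax.ShapeDtypeStruct`. `EMBED` and `host_embed` are hypothetical stand-ins for a language model's embedding layer, and `jax.pure_callback` is used in place of torch2jax's own wrapping.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Hypothetical embedding table (10 tokens, 2 features) standing in for model weights.
EMBED = np.arange(20, dtype=np.float32).reshape(10, 2)


def host_embed(token_ids):
    # Integer inputs in, float32 outputs out.
    return EMBED[np.asarray(token_ids)]


@jax.jit
def embed(token_ids):
    # The output spec fixes a float32 dtype even though the input is int32.
    out_spec = jax.ShapeDtypeStruct(token_ids.shape + (EMBED.shape[1],), jnp.float32)
    return jax.pure_callback(host_embed, out_spec, token_ids)
```

Calling `embed(jnp.array([[1, 3], [0, 9]], dtype=jnp.int32))` returns a float32 array of shape `(2, 2, 2)`.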
I'll add more robust overhead testing soon.
That's awesome. I am a bit tied up for the next few days, but then I will definitely test it out.
Of particular interest to me (and, I believe, quite a lot of other people), and the reason I started exploring calling PyTorch from JAX, is interacting with the following library:
https://github.com/getkeops/keops
However, after a quick look at the source code of torch2jax, I am starting to think a dedicated custom operator is probably the way forward (although tutorials on this are very limited compared to those for writing custom ops in PyTorch or TensorFlow).
Nice!
At this point in development, `torch2jax` and `torch2jax_with_vjp` should be able to handle quite general operations, including third-party libraries.
I've tested both BERT and ResNet architectures with it and both forward and backward passes have acceptable performance. I have yet to publish these tests.
Using examples from https://github.com/dfm/extending-jax and this `torch2jax` repo should make defining your own op more straightforward (but the available tutorials are, of course, nowhere near as developed as those for PyTorch or TensorFlow, like you said).
`torch2jax` should give you a frictionless way of calling this third-party library for quick development testing!
I finally found time to test out your fantastic work.
My first attempt, with the code attached below, worked perfectly straight out of the box. It uses the KeOps library, which builds these special lazy tensor expressions into optimised CUDA code that meshes with PyTorch. I still need to test this inside an optimisation loop, i.e., with gradients being called.
```python
import torch
import jax
import jax.random as jrandom
from torch2jax import torch2jax, torch2jax_with_vjp
from torch2jax import Size
from pykeops.torch import LazyTensor


def test_keops(x, y):
    """(N, D), (M, D) -> (N, 1): a_i = sum_j exp(-||x_i - y_j||^2)"""
    x_i = LazyTensor(x[:, None, :])  # symbolic (N, 1, D)
    y_j = LazyTensor(y[None, :, :])  # symbolic (1, M, D)
    D_ij = ((x_i - y_j) ** 2).sum(dim=2)  # symbolic (N, M) squared distances, never materialised
    K_ij = (-D_ij).exp()  # Gaussian kernel
    a_i = K_ij.sum(dim=1)  # reduce over j -> dense (N, 1) torch tensor
    return a_i


shape_a = (100000, 3)
shape_b = (200000, 3)
# Example arguments only tell torch2jax the input shapes and dtypes.
a, b = torch.zeros(shape_a), torch.zeros(shape_b)
jax_fn = torch2jax(test_keops, a, b, output_shapes=Size((a.shape[0], 1)))

key_a, key_b = jrandom.split(jrandom.PRNGKey(0))  # independent keys for a and b
a = jrandom.normal(key_a, shape_a)
b = jrandom.normal(key_b, shape_b)
print(a.devices())  # device(s) holding a

dist = jax_fn(a, b)
print(dist)
dist = jax.jit(jax_fn)(a, b)
print(dist)
```
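For checking the KeOps result (and, later, gradients from `torch2jax_with_vjp`) on small inputs, a dense pure-JAX reference of the same Gaussian kernel sum can be useful. A sketch; `gaussian_kernel_sum` is a hypothetical helper, and because it materialises the full (N, M) distance matrix it is only suitable for small N and M, unlike KeOps.

```python
import jax
import jax.numpy as jnp
import numpy as np


@jax.jit
def gaussian_kernel_sum(x, y):
    """(N, D), (M, D) -> (N, 1): a_i = sum_j exp(-||x_i - y_j||^2), computed densely."""
    sq_dist = jnp.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)  # dense (N, M)
    return jnp.exp(-sq_dist).sum(axis=1, keepdims=True)


rng = np.random.default_rng(0)
x = rng.normal(size=(5, 3)).astype(np.float32)
y = rng.normal(size=(7, 3)).astype(np.float32)
out = gaussian_kernel_sum(x, y)
# Reference gradient to compare against a wrapped PyTorch/KeOps function's vjp.
grad_x = jax.grad(lambda v: gaussian_kernel_sum(v, y).sum())(x)
```

On matching small inputs, `out` and `grad_x` can be compared with `jnp.allclose` against the output of `jax_fn` and its gradient.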
That's super cool, I'm glad it just worked!
Performance-wise, the call to torch effectively issues `torch.cuda.synchronize()` at the end of the call (to ensure the data is synchronized), whereas PyTorch's eager-evaluation efficiency normally comes from batching chains of CUDA calls.
I'm not sure this applies to your application, but if you have, e.g., PyTorch training loops that you're calling from JAX, unrolling the PyTorch loop (to, say, 10 iterations, i.e., just having `for i in range(10):` inside your PyTorch function) should speed things up.
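The unrolling advice amounts to crossing the JAX/PyTorch boundary once per block of iterations instead of once per iteration. A CPU-only sketch of that structure, assuming a hypothetical `host_sgd_steps` (NumPy gradient descent on a toy quadratic) in place of a PyTorch training loop and `jax.pure_callback` in place of the torch2jax call; it shows that both schedules compute the same result, not the GPU speedup itself.

```python
import functools

import jax
import jax.numpy as jnp
import numpy as np

# Hypothetical toy objective 0.5 * ||w - TARGET||^2, whose gradient is (w - TARGET).
TARGET = np.array([1.0, -2.0, 3.0], dtype=np.float32)
LR = np.float32(0.1)


def host_sgd_steps(w, n_steps):
    # Stand-in for a PyTorch training loop: all n_steps run inside ONE host call,
    # so the synchronisation at the boundary is paid once per call, not per step.
    w = np.array(w, dtype=np.float32)
    for _ in range(n_steps):
        w = w - LR * (w - TARGET)
    return w


def run_steps(w, n_steps):
    spec = jax.ShapeDtypeStruct(w.shape, w.dtype)
    # n_steps is a static Python int bound into the callback, not a traced value.
    return jax.pure_callback(functools.partial(host_sgd_steps, n_steps=n_steps), spec, w)


w0 = jnp.zeros(3, dtype=jnp.float32)
unrolled = run_steps(w0, 10)  # 10 iterations, 1 boundary crossing
stepwise = w0
for _ in range(10):
    stepwise = run_steps(stepwise, 1)  # 10 iterations, 10 boundary crossings
print(unrolled, stepwise)
```

Both paths give `TARGET * (1 - 0.9**10)`; on a GPU, the unrolled version additionally lets PyTorch queue the CUDA kernels of all ten iterations before the single synchronisation at the end of the call.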
This looks very cool. Great work. Once some more of the Roadmap has been ticked off, I can see this potentially getting a lot of use from the JAX community.
One question: aside from the current lack of vmap support (as mentioned in the Roadmap), am I correct, from looking through the code, in thinking that gradients via reverse-/forward-mode autodiff (vjp/jvp) aren't supported yet, i.e., I can't include the wrapped PyTorch function as part of an optimisation?
Also, have you run any tests on whether there is an overhead penalty, and if so, how large it is?