rdyro / torch2jax

Wraps PyTorch code in a JIT-compatible way for JAX. Supports automatically defining gradients for reverse-mode AutoDiff.
https://rdyro.github.io/torch2jax/
MIT License

CUDA bug and JAX breaking change #15

Closed: zwei-beiner closed this 2 months ago

zwei-beiner commented 2 months ago

Hi, I came across the following bugs:

When I install the current version of torch2jax with

pip install git+https://github.com/rdyro/torch2jax.git

I get the following error

ImportError: /.cache/torch2jax/cpython-311-linux-x86_64-torch2jax-0.4.10/torch2jax_cpp.so: undefined symbol: _Z18actual_cuda_deviceRK15TorchCallDevicePv

This is potentially coming from the multi-GPU update. I'm running this on a machine with no GPU, only a CPU, so possibly the CPU-only case is not handled by the multi-GPU code?
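
For reference, a minimal sanity check (just an illustrative sketch, using standard torch calls) that the machine really exposes no CUDA devices:

# sketch: confirm this is a CPU-only machine with no visible CUDA devices
import torch

print(torch.cuda.is_available())  # False on a machine without a GPU
print(torch.cuda.device_count())  # 0 on a machine without a GPU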

So I installed the following commit (pre-multi-gpu):

pip install git+https://github.com/rdyro/torch2jax.git@04b48ee9eb3846b5829ed8aeb26c6688eb110a25

Now the first bug disappears (as expected), but I get the following error instead:

AttributeError: module 'jax.interpreters.mlir' has no attribute 'ir_constants'. Did you mean: 'ir_constant'?

Maybe this is related to the latest breaking change in jax 0.4.32 (see the point under "Breaking changes" here: https://jax.readthedocs.io/en/latest/changelog.html#jax-0-4-32).
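
For illustration, a minimal sketch of the rename suggested by the error message above (the plural ir_constants is gone, the singular ir_constant remains), assuming jax >= 0.4.32:

# sketch: on recent JAX only the singular helper is available
from jax.interpreters import mlir

print(hasattr(mlir, "ir_constants"))  # False on jax 0.4.32
print(hasattr(mlir, "ir_constant"))   # True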

System info: I ran everything in a clean venv with jax and torch installed from scratch on Python 3.11, using the example code from the README file:

import torch
import jax
from jax import numpy as jnp
import numpy as np

from torch2jax import torch2jax_with_vjp

def torch_fn(a, b):
  return torch.nn.MSELoss()(a, b)

shape = (6,)

xt, yt = torch.randn(shape), torch.randn(shape)

# `depth` determines how many times the function can be differentiated
jax_fn = torch2jax_with_vjp(torch_fn, xt, yt, depth=2) 

# we can now differentiate the function (derivatives are taken using PyTorch autodiff)
g_fn = jax.grad(jax_fn, argnums=(0, 1))
x, y = jnp.array(np.random.randn(*shape)), jnp.array(np.random.randn(*shape))

print(g_fn(x, y))

# JIT works too
print(jax.jit(g_fn)(x, y))
rdyro commented 2 months ago

You're right! I was missing this case in my tests and it is indeed broken.

I've fixed this now (and expanded the tests). Can you reinstall the package (version 0.4.11) and let me know if it works for you, please?
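
As a quick check after reinstalling, something like this should report the new version (a sketch, assuming the installed distribution is named torch2jax, matching the import name):

# sketch: confirm the reinstalled package version
from importlib.metadata import version

print(version("torch2jax"))  # expect 0.4.11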

rdyro commented 2 months ago

The new JAX release also works now; it was a deeper problem, but it should now be fixed.

zwei-beiner commented 2 months ago

Thanks a lot, the code works now. Thanks for the quick response and for making this package!

rdyro commented 2 months ago

Awesome!