samuela / torch2jax

Run PyTorch in JAX. 🤝

Interoperability with Keops? #4

adam-hartshorne opened this issue 8 months ago (status: Open)

adam-hartshorne commented 8 months ago

The KeOps library (https://www.kernel-operations.io/) has NumPy/PyTorch bindings. However, unless I am doing something wrong, I don't think torch2jax (as it stands) supports this, due to the use of LazyTensors. Is there a way around this?

import torch 
import jax.numpy as jnp
import jax.random as jr
from jax import config

from pykeops.torch import LazyTensor
from torch2jax import j2t, t2j

config.update("jax_enable_x64", True)

M, N, D = 1000, 2000, 3
x = torch.randn(M, D, requires_grad=True).cuda() 
y = torch.randn(N, D).cuda()  

def test_func(x, y):
    x_i = LazyTensor(x.view(M, 1, D))  
    y_j = LazyTensor(y.view(1, N, D)) 

    # We can now perform large-scale computations, without memory overflows:
    D_ij = ((x_i - y_j)**2).sum(dim=2) 
    K_ij = (- D_ij).exp() 
    a_i = K_ij.sum(dim=1)
    return a_i

a_i = test_func(x, y)

key = jr.PRNGKey(0)
key_x, key_y = jr.split(key)

x_jax = jr.normal(key_x, (M, D))
y_jax = jr.normal(key_y, (N, D))
jax_test_func = t2j(test_func)
a_i = jax_test_func(x_jax, y_jax)

samuela commented 8 months ago

Hi @adam-hartshorne, thanks for reaching out! I'm not familiar with the internals of KeOps, so I'm afraid I may not be able to answer directly. If KeOps uses standard PyTorch functions under the hood, then torch2jax should work fine. If it relies on custom PyTorch operators implemented in C++/CUDA, then torch2jax will not work. (Although perhaps it could! There ought to be a way to handle custom operators, but it's not something I've played around with thus far.)
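
For what it's worth, here is a minimal sketch of the standard-ops case, assuming basic calls like torch.exp and torch.sum are covered by the tracer:

import torch
import jax.numpy as jnp
from torch2jax import t2j

# Uses only standard PyTorch ops, no custom C++/CUDA operators.
def gaussian_score_torch(x, y):
    d = x - y
    return torch.sum(torch.exp(-d * d))

gaussian_score_jax = t2j(gaussian_score_torch)
out = gaussian_score_jax(jnp.ones((5, 3)), jnp.zeros((5, 3)))  # JAX arrays in, JAX array out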

It looks like LazyTensor is a concept in KeOps... Does it use custom operators under the hood?

Btw, what is the output of the code you listed? Does it work?

adam-hartshorne commented 8 months ago

   vals = jax_keops_calc_met_reduce(x, y, u, v, v_areas, kernel_params[0], kernel_params[1])
  File "/home/adam/anaconda3/envs/test_jax_keops/lib/python3.10/site-packages/torch2jax/__init__.py", line 365, in <lambda>
    t2j_function = lambda f: lambda *args: f(*jax.tree_util.tree_map(Torchish, args)).value
  File "/media/adam/shared_drive/PycharmProjects/test_jax_keops/backends/keops_fns.py", line 35, in Var_met_reduce
    x = LazyTensor(x[:, None, :])
  File "/home/adam/anaconda3/envs/test_jax_keops/lib/python3.10/site-packages/pykeops/torch/lazytensor/LazyTensor.py", line 64, in __init__
    super().__init__(x=x, axis=axis)
  File "/home/adam/anaconda3/envs/test_jax_keops/lib/python3.10/site-packages/pykeops/common/lazy_tensor.py", line 164, in __init__
    self._dtype = self.tools.dtypename(self.tools.dtype(x))
  File "/home/adam/anaconda3/envs/test_jax_keops/lib/python3.10/site-packages/pykeops/torch/utils.py", line 128, in dtypename
    raise ValueError(
ValueError: [KeOps] float64 data type incompatible with KeOps.

https://www.kernel-operations.io/keops/engine/lazy_tensors.html

Yes, KeOps is a C++/CUDA library.

The other torch2jax library (https://github.com/rdyro/torch2jax) works with KeOps because it takes a different approach: it leverages the ability to define and build custom ops. I wonder if it might be worth collaborating so that we end up with one torch2jax library with different modes that could be enabled either automatically or via a flag, covering different scenarios and use cases?

samuela commented 8 months ago

Hmm, that error is actually being raised by KeOps and not torch2jax. What happens if you run that example code in float32 instead of float64?
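
Concretely, a minimal way to try that (just a sketch): skip jax_enable_x64 so the JAX inputs default to float32, or cast them explicitly:

import jax.numpy as jnp
import jax.random as jr

# Same shapes as in the example above, without enabling x64.
key = jr.PRNGKey(0)
x_jax = jr.normal(key, (1000, 3)).astype(jnp.float32)
y_jax = jr.normal(key, (2000, 3)).astype(jnp.float32)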

> so that we end up with one torch2jax library with different modes that could be enabled either automatically or via a flag, covering different scenarios and use cases?

I think the easiest path to this future would be to allow end-users to extend torch2jax (this one) for their custom ops. That should do the trick. I don't think it would be all that hard actually, but it's not something that has been explored yet.
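
To be clear, nothing like this exists yet; a purely hypothetical sketch of what a user-side hook could look like (every name below is made up for illustration):

import jax.numpy as jnp

# Hypothetical: a user-maintained table mapping custom PyTorch ops to plain
# JAX implementations that the tracer would consult. Not a real torch2jax API.
CUSTOM_OP_TABLE = {}

def register_custom_op(op_name, jax_impl):
    # Hypothetical registration helper.
    CUSTOM_OP_TABLE[op_name] = jax_impl

# E.g. map a fictional C++/CUDA op to a dense jax.numpy equivalent:
register_custom_op(
    "my_ext::gauss_conv",
    lambda x, y: jnp.exp(-jnp.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)),
)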

adam-hartshorne commented 8 months ago

> Hmm, that error is actually being raised by KeOps and not torch2jax. What happens if you run that example code in float32 instead of float64?

ValueError: [KeOps] float32 data type incompatible with KeOps.

> I think the easiest path to this future would be to allow end-users to extend torch2jax (this one) for their custom ops. That should do the trick. I don't think it would be all that hard actually, but it's not something that has been explored yet.

The JAX method for creating custom ops is extremely confusing, poorly documented, and rather tedious (unlike with TensorFlow or PyTorch). The other torch2jax library handles this seamlessly, with minimal code required from the user. Hence my suggestion about exploring combining efforts.

samuela commented 8 months ago

> ValueError: [KeOps] float32 data type incompatible with KeOps.

Hmm, something seems fishy here. This seems like a very weird error for KeOps to be throwing.

> The JAX method for creating custom ops is extremely confusing, poorly documented, and rather tedious (unlike with TensorFlow or PyTorch). The other torch2jax library handles this seamlessly, with minimal code required from the user. Hence my suggestion about exploring combining efforts.

One option here would be to use different conversion libraries on the different pieces of your model and then merge them once they're both JAX-ified. I'm resistant to bringing in the additional complexity of having two distinct forms of conversion in a single library.
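
As a sketch of that, with keops_part_jax standing in as a placeholder for the KeOps piece wrapped into a JAX callable by whatever other tool you use:

import torch
from torch2jax import t2j

# The plain-PyTorch piece goes through t2j from this library.
def torch_piece(x):
    return torch.tanh(x)

torch_piece_jax = t2j(torch_piece)

# keops_part_jax is a placeholder: the KeOps piece, already turned into a JAX
# callable by some other conversion route.
def full_model(x, keops_part_jax):
    h = torch_piece_jax(x)
    return keops_part_jax(h)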

Also, overriding custom PyTorch ops does not necessarily require implementing them as custom JAX operators; they could instead be implemented with vanilla JAX/XLA ops. There are pros and cons either way, but at a minimum, custom ops on the PyTorch side will require some customization on the torch2jax side. There is no straightforward way to automagically convert PyTorch custom ops into XLA custom ops AFAIK.
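
For instance, the Gaussian kernel reduction from the example at the top of this issue can be written directly in jax.numpy (a dense sketch, so it gives up KeOps' memory savings):

import jax.numpy as jnp

def gaussian_kernel_sum(x, y):
    # x: (M, D), y: (N, D)
    D_ij = jnp.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)  # (M, N)
    K_ij = jnp.exp(-D_ij)
    return K_ij.sum(axis=1)  # (M,)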

adam-hartshorne commented 8 months ago

> Hmm, something seems fishy here. This seems like a very weird error for KeOps to be throwing.

It works with native PyTorch and it works with the other torch2jax library, so something is definitely failing specifically in how the conversion is being attempted.

> I'm resistant to bringing in the additional complexity of having two distinct forms of conversion in a single library.

Fair enough.