Open adam-hartshorne opened 8 months ago
Hi @adam-hartshorne, thanks for reaching out! I'm not familiar with the internals of KeOps, so I'm afraid I may not be able to answer directly. If KeOps uses standard PyTorch functions under the hood, then torch2jax should work fine. If it relies on custom PyTorch operators implemented in C++/CUDA, then torch2jax will not work. (Although perhaps it could! There ought to be a way to handle custom operators, but it's not something I've played around with thus far.)
It looks like LazyTensor is a concept in KeOps... Does it use custom operators under the hood?
Btw, what is the output of the code you listed? Does it work?
vals = jax_keops_calc_met_reduce(x, y, u, v, v_areas, kernel_params[0], kernel_params[1])
File "/home/adam/anaconda3/envs/test_jax_keops/lib/python3.10/site-packages/torch2jax/__init__.py", line 365, in <lambda>
t2j_function = lambda f: lambda *args: f(*jax.tree_util.tree_map(Torchish, args)).value
File "/media/adam/shared_drive/PycharmProjects/test_jax_keops/backends/keops_fns.py", line 35, in Var_met_reduce
x = LazyTensor(x[:, None, :])
File "/home/adam/anaconda3/envs/test_jax_keops/lib/python3.10/site-packages/pykeops/torch/lazytensor/LazyTensor.py", line 64, in __init__
super().__init__(x=x, axis=axis)
File "/home/adam/anaconda3/envs/test_jax_keops/lib/python3.10/site-packages/pykeops/common/lazy_tensor.py", line 164, in __init__
self._dtype = self.tools.dtypename(self.tools.dtype(x))
File "/home/adam/anaconda3/envs/test_jax_keops/lib/python3.10/site-packages/pykeops/torch/utils.py", line 128, in dtypename
raise ValueError(
ValueError: [KeOps] float64 data type incompatible with KeOps.
https://www.kernel-operations.io/keops/engine/lazy_tensors.html
Yes, KeOps is a C++/CUDA library.
The other torch2jax library (https://github.com/rdyro/torch2jax) works with KeOps because it takes a different approach: it leverages the ability to define and build custom ops. I wonder if it might be worth collaborating so that we get one torch2jax library which has different modes that could be enabled either automatically (or via a flag), thus covering different scenarios and use cases?
Hmm, that error is actually being raised by KeOps and not torch2jax. What happens if you run that example code in float32 instead of float64?
so that we get one torch2jax library which has different modes that could be enabled either automatically (or via a flag), thus covering different scenarios and use cases?
I think the easiest path to this future would be to allow end-users to extend torch2jax (this one) for their custom ops. That should do the trick. I don't think it would be all that hard actually, but it's not something that has been explored yet.
Hmm, that error is actually being raised by KeOps and not torch2jax. What happens if you run that example code in float32 instead of float64?
ValueError: [KeOps] float32 data type incompatible with KeOps.
I think the easiest path to this future would be to allow end-users to extend torch2jax (this one) for their custom ops. That should do the trick. I don't think it would be all that hard actually, but it's not something that has been explored yet.
JAX's method for creating custom ops is extremely confusing, poorly documented, and rather tedious (unlike with TensorFlow or PyTorch). The other torch2jax library handles this seamlessly with minimal code required by the user. Hence my suggestion regarding exploring combining efforts.
ValueError: [KeOps] float32 data type incompatible with KeOps.
Hmm, something seems fishy here. This seems like a very weird error for KeOps to be throwing.
JAX's method for creating custom ops is extremely confusing, poorly documented, and rather tedious (unlike with TensorFlow or PyTorch). The other torch2jax library handles this seamlessly with minimal code required by the user. Hence my suggestion regarding exploring combining efforts.
One option here would be to use different conversion libraries on the different pieces of your model and then merge them once they're both JAX-ified. I'm resistant to bringing in the additional complexity of having two distinct forms of conversion in a single library.
Also, overriding custom PyTorch ops does not necessarily require implementing those ops as custom JAX operators; they could instead be implemented with vanilla JAX/XLA ops. There are pros and cons either way, but at a minimum custom ops on the PyTorch side will require some customization on the torch2jax side. There is not a straightforward way to automagically convert PyTorch custom ops to XLA custom ops AFAIK.
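To make the "vanilla JAX/XLA ops" option concrete, here is a minimal sketch of a kernel reduction written directly in JAX. The Gaussian kernel, the function name `gaussian_kernel_reduce`, and its arguments are assumptions for illustration; the thread never shows what `Var_met_reduce` actually computes. Unlike a KeOps LazyTensor reduction, this version materializes the dense (N, M) kernel matrix, so it trades memory for simplicity.

```python
import jax
import jax.numpy as jnp


def gaussian_kernel_reduce(x, y, b, sigma):
    """Compute out[i] = sum_j exp(-|x_i - y_j|^2 / (2 sigma^2)) * b[j].

    Hypothetical stand-in for a KeOps reduction; not the function from
    the thread. x: (N, D), y: (M, D), b: (M, E) -> out: (N, E).
    """
    diff = x[:, None, :] - y[None, :, :]             # (N, M, D) pairwise differences
    sq_dist = jnp.sum(diff ** 2, axis=-1)            # (N, M) squared distances
    kernel = jnp.exp(-sq_dist / (2.0 * sigma ** 2))  # (N, M) dense kernel matrix
    return kernel @ b                                # (N, E) reduction over j


# Plain JAX ops compose with jit, grad, and vmap without any custom-op machinery.
gaussian_kernel_reduce_jit = jax.jit(gaussian_kernel_reduce)
```

For large N and M the dense matrix may not fit in memory, which is the case KeOps is designed for; that is one of the cons of this route.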
Hmm, something seems fishy here. This seems like a very weird error for KeOps to be throwing.
It works with native PyTorch and it works with the other torch2jax library. So it is definitely something failing specifically with regard to how the conversion is being attempted.
I'm resistant to bringing in the additional complexity of having two distinct forms of conversion in a single library.
Fair enough.
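One possible reading of the "float32 data type incompatible" error, offered as a guess and not confirmed anywhere in the thread: the traceback shows the inputs wrapped in `Torchish` before they reach `LazyTensor`, so the dtype that KeOps inspects may not be a real `torch.dtype`. A real PyTorch dtype prints as `torch.float32`, whereas the message shows a bare `float32`. The toy below uses only hypothetical stand-ins (`FakeDtype`, `TORCH_FLOAT32`, `FOREIGN_FLOAT32`, and a simplified `dtypename`) to show how a lookup that only recognizes one framework's dtype objects can reject a dtype with a matching name.

```python
class FakeDtype:
    """Hypothetical dtype object: compared by identity, printed by name."""

    def __init__(self, name, prefix=""):
        self.name = name
        self.prefix = prefix

    def __repr__(self):
        return f"{self.prefix}{self.name}"


# Stand-ins only: neither object is a real torch or JAX dtype.
TORCH_FLOAT32 = FakeDtype("float32", prefix="torch.")
FOREIGN_FLOAT32 = FakeDtype("float32")


def dtypename(dtype):
    """Toy lookup that only recognizes the torch stand-in (an assumption)."""
    if dtype is TORCH_FLOAT32:
        return "float32"
    raise ValueError(f"[KeOps] {dtype} data type incompatible with KeOps.")


print(dtypename(TORCH_FLOAT32))  # prints: float32
try:
    dtypename(FOREIGN_FLOAT32)
except ValueError as err:
    print(err)  # prints: [KeOps] float32 data type incompatible with KeOps.
```

If this is what is happening, switching between float32 and float64 would not help, which matches the behaviour reported above.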
The KeOps library (https://www.kernel-operations.io/) has NumPy and PyTorch bindings. However, unless I am doing something wrong, I don't think torch2jax (as it stands) will support this due to the usage of LazyTensors. I was wondering if there is a way around this?