I managed to reproduce it, thanks for submitting it!
It seems keops is not entirely compatible with the PyTorch functional interface (the error message seems to suggest it is a missing implementation in keops).
I made changes to allow the user of torch2jax to default to the old torch.autograd.grad interface for taking gradients. In your example, the only change would be:
```python
def jax_keops_rbf_loss(x, y):
    loss = torch2jax_with_vjp(
        keops_rbf_loss,
        jax.ShapeDtypeStruct(x.shape, dtype=x.dtype),
        jax.ShapeDtypeStruct(y.shape, dtype=y.dtype),
        output_shapes=jax.ShapeDtypeStruct((1,), dtype=x.dtype),
        depth=2,
        use_torch_vjp=False,
    )(x, y)
    return loss
```
Notice the use_torch_vjp=False flag. This switches the torch2jax_with_vjp gradient definition from using torch.func.vjp to torch.autograd.grad. The latter, importantly, is not compatible with torch.vmap, so the flag also instructs torch2jax to switch to dumb experimental looping instead of using torch.vmap; this should not be any slower as long as jax.vmap is not used over the wrapped function, something which your code doesn't do anyway.
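For concreteness, the pattern to avoid for speed would look roughly like the hypothetical sketch below, where wrapped_fn merely stands in for a torch2jax-wrapped function; calling the wrapped function directly, as your code does, is unaffected.

```python
import jax
import jax.numpy as jnp

# Hypothetical placeholder: `wrapped_fn` stands in for a torch2jax-wrapped
# function of two (N, D) arrays returning a (1,)-shaped loss.
def wrapped_fn(x, y):
    return jnp.sum((x - y) ** 2).reshape(1)

x_batch, y_batch = jnp.ones((8, 100, 3)), jnp.ones((8, 100, 3))

# Mapping over the wrapped function is the case that falls back to looping
# when use_torch_vjp=False; calling wrapped_fn(x, y) directly does not.
losses = jax.vmap(wrapped_fn)(x_batch, y_batch)
```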
Let me know if this works for you!
Thank you for such a swift response.
Also thank you for mentioning the potential torch.vmap / jax.vmap issue, as the current structure of my real code (in which I want to replace a JAX function with a Keops function) uses this.
It might be worth adding this issue to the docs, as I wouldn't be surprised if quite a few libraries that interface with torch still use the old torch.autograd.grad method for handling gradients.
No worries, I'm glad you're finding the package useful.
After reading through keops, they're defining a custom autograd function in the old way (using forward & backward instead of forward, backward, and setup_context). I experimented a bit with editing that code and it's possible to make it work with the new version (and hence torch.vjp), but I'm not sure whether my edits break higher derivatives.
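For reference, the two styles look roughly like this (a toy sketch with a made-up scaling function, not the actual keops code):

```python
import torch

# Old-style custom autograd function: forward receives ctx directly.
# This is the pattern that the torch.func transforms reject.
class ScaleOld(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, alpha):
        ctx.alpha = alpha
        return alpha * x

    @staticmethod
    def backward(ctx, grad_out):
        return ctx.alpha * grad_out, None


# New-style custom autograd function: a ctx-free forward plus a separate
# setup_context, which torch.func.vjp and torch.vmap support.
class ScaleNew(torch.autograd.Function):
    @staticmethod
    def forward(x, alpha):
        return alpha * x

    @staticmethod
    def setup_context(ctx, inputs, output):
        _, alpha = inputs
        ctx.alpha = alpha

    @staticmethod
    def backward(ctx, grad_out):
        return ctx.alpha * grad_out, None
```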
That's a good idea. For now, I made the code detect a failure when using torch.vjp and fall back to torch.autograd.grad automatically with a warning. I'll look into how to add a description of this pitfall to the documentation.
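The fallback roughly follows this pattern (an illustrative sketch only, not the actual torch2jax code):

```python
import warnings
import torch

def vjp_with_fallback(fn, args, cotangent):
    """Illustrative sketch: try torch.func.vjp first, and fall back to
    torch.autograd.grad with a warning if the function is incompatible."""
    try:
        _, vjp_fn = torch.func.vjp(fn, *args)
        return vjp_fn(cotangent)
    except RuntimeError as exc:
        warnings.warn(f"torch.func.vjp failed ({exc}); falling back to torch.autograd.grad")
        args = [a.detach().requires_grad_(True) for a in args]
        out = fn(*args)
        return torch.autograd.grad(out, args, grad_outputs=cotangent, allow_unused=True)
```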
They might be happy to accept a pull request for modernising their gradients.
I am currently experiencing a slight issue with intermittent NaN gradients in my code when using torch / Keops vs pure JAX. However, it only occurs when using much larger batch sizes than are possible with JAX (thanks to the lazy evaluation functionality of Keops). I haven't managed to narrow it down any further, so I can't say with absolute certainty where the issue lies, i.e. Keops, torch2jax, or something in my own JAX code.
I will obviously raise an issue here if I believe it is something to do with torch2jax.
I'll look into a pull request once I have a little more time; it's a good idea.
I recently spotted a possible synchronization bug in my CUDA implementation. I already pushed a fix, but it's in the C++ part of the code, so you can either try:
```python
from torch2jax import compile_and_import_module

compile_and_import_module(force_recompile=True)
```
or delete ~/.cache/torch2jax. Sorry for that.
Thanks, keep me updated!
Ok, so I have discovered the bug that is causing the NaN, but it is very sporadic, which makes debugging it hard. I think there is some sort of weird interaction between torch2jax, torch, and keops.
Basically, performing a simple division inside the PyTorch part of the code (even if it is 1.0 / 1.0) can cause these NaN gradients.
**THIS CAN RESULT IN NAN GRADIENTS**
```python
def batch_keops_ls(x, y, length_scale):
    B, N, D = x.shape  # Batch size, number of source points, features
    _, M, _ = y.shape  # Batch size, number of target points, features
    gamma = 1.0 / length_scale  # division inside the torch function
    # Encode as symbolic tensors:
    x_i = LazyTensor(x.view(B, N, 1, D))  # (B, N, 1, D)
    y_j = LazyTensor(y.view(B, 1, M, D))  # (B, 1, M, D)
    D2 = x_i.sqdist(y_j)
    K = (-D2 * gamma).exp()
    return torch.reshape(torch.sum(K.sum(2), dim=1), (-1,))


length_scale = jnp.array([1.0])
xy_cost = torch2jax_with_vjp(
    batch_keops_ls,
    jax.ShapeDtypeStruct(x.shape, x.dtype),
    jax.ShapeDtypeStruct(y.shape, x.dtype),
    jax.ShapeDtypeStruct(length_scale.shape, x.dtype),
    output_shapes=jax.ShapeDtypeStruct((x.shape[0],), x.dtype),
    depth=2,
    use_torch_vjp=False,
)(x, y, length_scale)
```
**THIS WORKS**
```python
def batch_keops_gamma(x, y, gamma):
    B, N, D = x.shape  # Batch size, number of source points, features
    _, M, _ = y.shape  # Batch size, number of target points, features
    # Encode as symbolic tensors:
    x_i = LazyTensor(x.view(B, N, 1, D))  # (B, N, 1, D)
    y_j = LazyTensor(y.view(B, 1, M, D))  # (B, 1, M, D)
    D2 = x_i.sqdist(y_j)
    K = (-D2 * gamma).exp()
    return torch.reshape(torch.sum(K.sum(2), dim=1), (-1,))


gamma = jnp.array([1.0])
xy_cost = torch2jax_with_vjp(
    batch_keops_gamma,
    jax.ShapeDtypeStruct(x.shape, x.dtype),
    jax.ShapeDtypeStruct(y.shape, x.dtype),
    jax.ShapeDtypeStruct(gamma.shape, x.dtype),
    output_shapes=jax.ShapeDtypeStruct((x.shape[0],), x.dtype),
    depth=2,
    use_torch_vjp=False,
)(x, y, gamma)
```
**THIS CAN RESULT IN NAN GRADIENTS**
```python
def torch_wrapper(x, y, length_scale):
    gamma = 1.0 / length_scale  # division inside the torch wrapper
    return batch_keops_gamma(x, y, gamma)


length_scale = jnp.array([1.0])
xy_cost = torch2jax_with_vjp(
    torch_wrapper,
    jax.ShapeDtypeStruct(x.shape, x.dtype),
    jax.ShapeDtypeStruct(y.shape, x.dtype),
    jax.ShapeDtypeStruct(length_scale.shape, x.dtype),
    output_shapes=jax.ShapeDtypeStruct((x.shape[0],), x.dtype),
    depth=2,
    use_torch_vjp=False,
)(x, y, length_scale)
```
**THIS WORKS**
```python
length_scale = jnp.array([1.0])
gamma = 1.0 / length_scale  # division on the JAX side, before wrapping
xy_cost = torch2jax_with_vjp(
    batch_keops_gamma,
    jax.ShapeDtypeStruct(x.shape, x.dtype),
    jax.ShapeDtypeStruct(y.shape, x.dtype),
    jax.ShapeDtypeStruct(gamma.shape, x.dtype),
    output_shapes=jax.ShapeDtypeStruct((x.shape[0],), x.dtype),
    depth=2,
    use_torch_vjp=False,
)(x, y, gamma)
```
I made two bug fixes thanks to #4; please let me know if those fixes also help with this.
Looks good from my end!
Really appreciate how fast you are on correcting these bugs.
I have attached minimal example code below which should allow you to reproduce the issue, along with the error log.
Everything works OK unless I attempt to use a function that contains calls to the KeOps library inside your torch2jax framework and I need gradients, i.e. when running an optimiser. The same function call works fine when simply called from JAX, and obviously everything works fine in pure PyTorch.
torch2jax_error.txt
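For reference, the failing usage pattern is roughly the following sketch (names and shapes are illustrative, and jax_keops_rbf_loss refers to the wrapped KeOps loss shown in the messages above; the full reproduction is in the attached file):

```python
import jax
import jax.numpy as jnp

# Illustrative inputs only; the real shapes are in the attached example.
x = jnp.ones((100, 3))
y = jnp.ones((100, 3))

def objective(x, y):
    return jax_keops_rbf_loss(x, y)[0]  # scalar loss for jax.grad

# Evaluating the wrapped function is fine; requesting gradients (as an
# optimiser step would) is where the error appears.
grads = jax.grad(objective)(x, y)
```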