adam-coogan closed this issue 5 months ago
Hi @adam-coogan ,
As far as I understand, KeOps already handles batch dims:
import torch
from pykeops.torch import LazyTensor
# Details of this function don't matter
def fn(x_i, x_j, y_j, batch=True):
    x_i = LazyTensor(x_i)
    x_j = LazyTensor(x_j)
    y_j = LazyTensor(y_j)
    K_ij = (-((x_i - x_j) ** 2).sum(-1)).exp()
    # The reduction axis counts the batch dims: the i axis is 1 when batched, 0 when not
    return (K_ij * y_j).sum(batch)
# Batching with KeOps: runs
x_i = torch.randn(5, 10, 1, 2)
x_j = torch.randn(5, 1, 20, 2)
y_j = torch.randn(5, 1, 20, 1)
res_keops = fn(x_i, x_j, y_j)
res_loop = torch.empty(5, 20, 1)
for batch in range(5):
    res_loop[batch] = fn(x_i[batch], x_j[batch], y_j[batch], batch=False)
print(torch.allclose(res_keops, res_loop))
If the behavior of torch.vmap(fn)(x_i, x_j, y_j) you are expecting is the same as the one that produces res_keops above, then setting generate_vmap_rule=True does not fix the issue: we have to implement a .vmap() method that calls the standard KeOps routine. Is that correct?
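For context on the generate_vmap_rule option discussed here: it only applies to "new-style" autograd.Functions, i.e. those whose forward takes no ctx and that define a separate setup_context staticmethod. A minimal sketch with a toy Square function (hypothetical example, unrelated to the KeOps internals):

```python
import torch


class Square(torch.autograd.Function):
    # Ask PyTorch to derive a vmap (batching) rule automatically.
    # This only works for new-style Functions with setup_context.
    generate_vmap_rule = True

    @staticmethod
    def forward(x):
        return x ** 2

    @staticmethod
    def setup_context(ctx, inputs, output):
        x, = inputs
        ctx.save_for_backward(x)

    @staticmethod
    def backward(ctx, grad_out):
        x, = ctx.saved_tensors
        return 2 * x * grad_out


x = torch.randn(5, 3)
out = torch.vmap(Square.apply)(x)  # batches over the leading dim
print(torch.allclose(out, x ** 2))
```

Whether KeOps can use this shortcut depends on how its Function is written internally, which is exactly the open question in this comment.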
Did this ever get resolved?
Hi @adam-coogan , @adam-hartshorne ,
I have implemented vmap support in the branch called vmap. On this branch your test script now works, but to make the two computations match you need to specify a different reduction axis: .sum(2) in the batch case and .sum(1) in the vmap case. Here is a script that does the correct comparison:
import torch
from pykeops.torch import LazyTensor
# Details of this function don't matter
def fn(x_i, x_j, y_j, nbatchdims=0):
    x_i = LazyTensor(x_i)
    x_j = LazyTensor(x_j)
    y_j = LazyTensor(y_j)
    K_ij = (-((x_i - x_j) ** 2).sum(-1)).exp()
    # Reduce over the j axis, whose index is shifted by the number of batch dims
    return (K_ij * y_j).sum(nbatchdims + 1)
# Batching with KeOps
x_i = torch.randn(5, 10, 1, 2)
x_j = torch.randn(5, 1, 20, 2)
y_j = torch.randn(5, 1, 20, 1)
res1 = fn(x_i, x_j, y_j, nbatchdims=1)
print(res1.shape)
# Batching with vmap
res2 = torch.vmap(fn)(x_i, x_j, y_j)
print(res2.shape)
print(torch.norm(res1 - res2) / torch.norm(res1))
Note that to implement the vmap method I simply redirect to the existing batch mode, so there should be no difference in terms of speed.
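The "redirect to batch mode" pattern can be sketched with a toy Function (a hypothetical stand-in, not the actual KeOps code): the custom vmap staticmethod receives the batched inputs together with the position of each batch dim, moves the batch dim to the front, calls the ordinary batch-capable routine, and reports where the output's batch dim lives:

```python
import torch


class Scale(torch.autograd.Function):
    # Toy stand-in for a KeOps-style operator: multiply by 2.

    @staticmethod
    def forward(x):
        return 2 * x

    @staticmethod
    def setup_context(ctx, inputs, output):
        pass  # nothing to save for this toy example

    @staticmethod
    def backward(ctx, grad_out):
        return 2 * grad_out

    @staticmethod
    def vmap(info, in_dims, x):
        # "Redirect to batch mode": bring the vmapped dim to the front and
        # run the ordinary, already batch-capable computation on it.
        x_bdim, = in_dims
        x = x.movedim(x_bdim, 0)
        out = 2 * x   # the batched KeOps routine would be called here
        return out, 0  # the output carries its batch dim at position 0


x = torch.randn(4, 3)
out = torch.vmap(Scale.apply)(x)
print(torch.allclose(out, 2 * x))
```

Since the vmap rule just forwards to the same batched computation, there is no extra work per batch element, which is why no speed difference is expected.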
If you can try the branch, please let me know whether it works as you expect. I also need to update the other class, KernelSolveAutograd, and then I can do the merge.
Support for vmap
is now included in the newly released v2.2 of pykeops, so we can close this issue.
Since KeOps uses torch.autograd.Function, it is not compatible with vmap by default. Here's an example of code that will not run:
The relevant part of the stack trace is:
Based on the PyTorch docs, the fix may be as easy as setting generate_vmap_rule=True in the Function definition. I haven't looked at the KeOps internals to see whether it's that simple, though. It would be great to get this working: KeOps + vmap would be quite useful in certain situations!
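The example and stack trace above did not survive in this archived copy, but the failure mode described is the standard functorch error for old-style autograd.Functions. A toy reproduction (hypothetical Square kernel, not the original KeOps snippet):

```python
import torch


class LegacySquare(torch.autograd.Function):
    # Old-style Function: forward takes ctx directly, and there is neither
    # a setup_context staticmethod nor a vmap rule, so torch.vmap rejects it.
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x ** 2

    @staticmethod
    def backward(ctx, grad_out):
        x, = ctx.saved_tensors
        return 2 * x * grad_out


x = torch.randn(5, 3)
failed = False
try:
    torch.vmap(LegacySquare.apply)(x)
except RuntimeError as err:
    failed = True
    print("vmap failed:", err)
print(failed)
```

The error message points at the missing setup_context/vmap support, which is what the fixes discussed in this thread address.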