getkeops / keops

KErnel OPerationS, on CPUs and GPUs, with autodiff and without memory overflows
https://www.kernel-operations.io
MIT License

Incompatible with vmap #303

Closed: adam-coogan closed this issue 5 months ago

adam-coogan commented 1 year ago

Since KeOps uses torch.autograd.Function, it is not compatible with vmap by default. Here's an example of code that will not run:

import torch
from pykeops.torch import LazyTensor

# Details of this function don't matter
def fn(x_i, x_j, y_j):
    x_i = LazyTensor(x_i)
    x_j = LazyTensor(x_j)
    y_j = LazyTensor(y_j)
    K_ij = (-((x_i - x_j)**2).sum(-1)).exp()
    return (K_ij * y_j).sum(1)

# Batching with KeOps: runs
x_i = torch.randn(5, 10, 1, 2)
x_j = torch.randn(5, 1, 20, 2)
y_j = torch.randn(5, 1, 20, 1)
print(fn(x_i, x_j, y_j))

# Batching with vmap: raises exception
print(torch.vmap(fn)(x_i, x_j, y_j))

The relevant part of the stack trace is:

RuntimeError: In order to use an autograd.Function with functorch transforms
(vmap, grad, jvp, jacrev, ...), it must override the setup_context staticmethod.
For more details, please see https://pytorch.org/docs/master/notes/extending.func.html

Based on the PyTorch docs, the fix may be as easy as setting generate_vmap_rule=True in the Function definition. I haven't looked at the KeOps internals to see whether it's that simple, though.
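
For reference, here is a minimal sketch of the pattern those docs describe, on a toy function rather than the KeOps internals (assumes PyTorch >= 2.0): forward is split from a setup_context staticmethod, and generate_vmap_rule = True lets functorch derive the vmap rule automatically.

import torch

class Square(torch.autograd.Function):
    # Ask PyTorch to auto-generate the vmap rule from forward/setup_context.
    generate_vmap_rule = True

    @staticmethod
    def forward(x):
        # New-style forward: no ctx argument.
        return x ** 2

    @staticmethod
    def setup_context(ctx, inputs, output):
        # Saving tensors for backward moves here, out of forward.
        (x,) = inputs
        ctx.save_for_backward(x)

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        return 2 * x * grad_output

x = torch.randn(5, 3)
print(torch.vmap(Square.apply)(x).shape)  # torch.Size([5, 3])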

It would be great to get this working: KeOps + vmap would be quite useful in certain situations!

bcharlier commented 1 year ago

Hi @adam-coogan,

As far as I understand, KeOps already handles batch dimensions:

import torch
from pykeops.torch import LazyTensor

# Details of this function don't matter
def fn(x_i, x_j, y_j, batch=True):
    x_i = LazyTensor(x_i)
    x_j = LazyTensor(x_j)
    y_j = LazyTensor(y_j)
    K_ij = (-((x_i - x_j)**2).sum(-1)).exp()
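    # True/False act as the integer axes 1/0: the reduction is over the i axis
    # in both cases (axis 1 after a leading batch dim, axis 0 without one).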
    return (K_ij * y_j).sum(batch)

# Batching with KeOps: runs
x_i = torch.randn(5, 10, 1, 2)
x_j = torch.randn(5, 1, 20, 2)
y_j = torch.randn(5, 1, 20, 1)

res_keops = fn(x_i, x_j, y_j)

res_loop = torch.empty(5, 20, 1)
for batch in range(5):
    res_loop[batch] = fn(x_i[batch], x_j[batch], y_j[batch], batch=False)

print(torch.allclose(res_keops, res_loop))

If the behavior of torch.vmap(fn)(x_i, x_j, y_j) you are expecting is the same as the one that gives res_keops above... setting generate_vmap_rule=True does not fix the issue. We would have to implement a .vmap() method that calls the standard KeOps routine. Is that correct?
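
For reference, the same PyTorch extension notes also allow a custom vmap staticmethod on an autograd.Function. A minimal sketch of what such a hook could look like, on a toy function rather than the KeOps internals (assumes PyTorch >= 2.0): the rule moves the mapped dimension to the front and reuses the ordinary forward, reporting the output as batched along dim 0.

import torch

class Scale(torch.autograd.Function):
    @staticmethod
    def forward(x):
        return 2.0 * x

    @staticmethod
    def setup_context(ctx, inputs, output):
        pass  # nothing to save for this toy backward

    @staticmethod
    def backward(ctx, grad_output):
        return 2.0 * grad_output

    @staticmethod
    def vmap(info, in_dims, x):
        # Move the vmapped dim to the front and call the function on the
        # whole batched tensor; the second return value says the output
        # is batched along dim 0.
        (x_bdim,) = in_dims
        x = x.movedim(x_bdim, 0)
        return Scale.apply(x), 0

x = torch.randn(5, 3)
print(torch.vmap(Scale.apply)(x).shape)  # torch.Size([5, 3])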

adam-hartshorne commented 1 year ago

Did this ever get resolved?

joanglaunes commented 1 year ago

Hi @adam-coogan, @adam-hartshorne, I have implemented vmap support in the branch called vmap. On this branch your test script now works, but in order to get the two computations to match, you need to specify a different reduction axis in the batch case (.sum(2)) and in the vmap case (.sum(1)). Here is a script that does the correct comparison:

import torch
from pykeops.torch import LazyTensor

# Details of this function don't matter
def fn(x_i, x_j, y_j, nbatchdims=0):
    x_i = LazyTensor(x_i)
    x_j = LazyTensor(x_j)
    y_j = LazyTensor(y_j)
    K_ij = (-((x_i - x_j) ** 2).sum(-1)).exp()
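    # Reduce over the j axis: axis nbatchdims + 1, i.e. axis 1 inside vmap
    # (3D inputs) and axis 2 when KeOps handles the batch dim itself (4D inputs).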
    return (K_ij * y_j).sum(nbatchdims + 1)

# Batching with KeOps
x_i = torch.randn(5, 10, 1, 2)
x_j = torch.randn(5, 1, 20, 2)
y_j = torch.randn(5, 1, 20, 1)
res1 = fn(x_i, x_j, y_j, nbatchdims=1)
print(res1.shape)

# Batching with vmap
res2 = torch.vmap(fn)(x_i, x_j, y_j)
print(res2.shape)

print(torch.norm(res1 - res2) / torch.norm(res1))

Note that to implement the vmap method, I just redirected to the batch mode, so it should not make any difference in terms of speed. If you can try the branch, please let me know if it works as you expect. I will also need to update the other class, KernelSolveAutograd, and then I can do the merge.

joanglaunes commented 5 months ago

Support for vmap is now included in the newly released v2.2 of pykeops, so we can close this issue.