Since `torchinterp1d` uses `torch.autograd.Function`, it is not compatible by default with `vmap`. Here's an example of code that will not run:
```python
import torch
from torchinterp1d import interp1d

def interpolate(xp):
    x = torch.linspace(-5, 5, 100)
    y = x**3
    return interp1d(x, y, torch.atleast_1d(xp))

xp = torch.rand(20) * 10 - 5
print(f"{xp=}, {torch.vmap(interpolate)(xp)=}")
```
The relevant part of the stack trace is:
```
RuntimeError: In order to use an autograd.Function with functorch transforms
(vmap, grad, jvp, jacrev, ...), it must override the setup_context staticmethod.
For more details, please see https://pytorch.org/docs/master/notes/extending.func.html
```
Based on the PyTorch docs, the fix may be as easy as setting `generate_vmap_rule = True` in `torchinterp1d`, but I haven't looked into this yet.
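For reference, here's a minimal sketch of the pattern the linked docs describe. This is a toy cubic function, not `torchinterp1d`'s actual `Function`: the key parts are moving context setup out of `forward` into a `setup_context` staticmethod and setting `generate_vmap_rule = True`, after which `vmap` works on `.apply`:

```python
import torch

class Cube(torch.autograd.Function):
    # Lets PyTorch auto-generate a vmap rule for this Function.
    generate_vmap_rule = True

    @staticmethod
    def forward(x):
        # Note: no ctx argument in this style; context is set up separately.
        return x ** 3

    @staticmethod
    def setup_context(ctx, inputs, output):
        x, = inputs
        ctx.save_for_backward(x)

    @staticmethod
    def backward(ctx, grad_out):
        x, = ctx.saved_tensors
        return 3 * x ** 2 * grad_out

x = torch.randn(8, 5)
out = torch.vmap(Cube.apply)(x)  # no longer raises the RuntimeError above
```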
It'd be great to get a fix for this since `vmap` is incredibly useful.