aliutkus / torchinterp1d

1D interpolation for PyTorch
BSD 3-Clause "New" or "Revised" License
162 stars, 19 forks

Make compatible with torch.vmap #20

Open: adam-coogan opened this issue 1 year ago

adam-coogan commented 1 year ago

Since torchinterp1d uses torch.autograd.Function, it is not compatible with vmap by default. Here's an example that fails to run:

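```python
import torch
from torchinterp1d import interp1d

def interpolate(xp):
    # Samples of y = x**3 on 100 evenly spaced points in [-5, 5]
    x = torch.linspace(-5, 5, 100)
    y = x**3
    # Inside vmap, xp is a 0-d tensor, so promote it to shape (1,)
    return interp1d(x, y, torch.atleast_1d(xp))

# 20 query points drawn uniformly from [-5, 5)
xp = torch.rand(20) * 10 - 5
# Raises the RuntimeError shown below
print(f"{xp=}, {torch.vmap(interpolate)(xp)=}")
```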

The relevant part of the stack trace is:

```text
RuntimeError: In order to use an autograd.Function with functorch transforms
(vmap, grad, jvp, jacrev, ...), it must override the setup_context staticmethod.
For more details, please see https://pytorch.org/docs/master/notes/extending.func.html
```
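The error is not specific to interpolation: any Function written in the legacy style, where `forward` takes `ctx` as its first argument and there is no `setup_context`, fails the same way. A minimal sketch, assuming PyTorch 2.x; `Cube` and `vmap_fails_for_legacy_function` are hypothetical names used only for illustration:

```python
import torch


class Cube(torch.autograd.Function):
    # Hypothetical toy Function in the legacy style: forward receives ctx directly
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x**3

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        return grad_out * 3 * x**2


def vmap_fails_for_legacy_function() -> bool:
    """Return True if vmap rejects Cube with the setup_context error."""
    try:
        torch.vmap(Cube.apply)(torch.rand(5))
    except RuntimeError as err:
        # Expected to be the same message as in the traceback above
        return "setup_context" in str(err)
    return False


# Outside of functorch transforms the legacy Function works normally
print(Cube.apply(torch.tensor([2.0])))
print(vmap_fails_for_legacy_function())
```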

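Based on the PyTorch docs, the fix may be as easy as adding a `setup_context` staticmethod (as the error message asks, which also means `forward` must stop taking `ctx`) and setting `generate_vmap_rule = True` on the Function in torchinterp1d, but I haven't looked into this yet.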
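A minimal sketch of the pattern the docs describe, assuming PyTorch 2.0 or later. `LinearInterp1d` is a hypothetical stand-in, not torchinterp1d's actual implementation: it assumes `x` is 1-D and sorted ascending, `xnew` is 1-D, extrapolates linearly beyond the end points, and returns no gradient for `x`. The changes relative to a legacy Function are that `forward` drops `ctx`, tensors are saved in `setup_context`, and `generate_vmap_rule = True` lets PyTorch derive the batching rule because every step uses plain PyTorch ops.

```python
import torch


class LinearInterp1d(torch.autograd.Function):
    """Hypothetical vmap-friendly 1-D linear interpolation (not torchinterp1d's code).

    Assumes x is 1-D and sorted ascending, y has the same shape as x,
    and xnew is 1-D. Gradients flow to y and xnew only.
    """

    # Let PyTorch derive the vmap rule by running forward/backward under vmap;
    # this works because both are written with plain PyTorch operations.
    generate_vmap_rule = True

    @staticmethod
    def forward(x, y, xnew):
        # New-style forward: no ctx argument.
        # idx is the right-hand neighbour of each query, clamped to a valid segment.
        idx = torch.searchsorted(x, xnew).clamp(1, x.shape[-1] - 1)
        x0, x1 = x[idx - 1], x[idx]
        y0, y1 = y[idx - 1], y[idx]
        return y0 + (y1 - y0) / (x1 - x0) * (xnew - x0)

    @staticmethod
    def setup_context(ctx, inputs, output):
        # Saving tensors for backward happens here instead of in forward
        x, y, xnew = inputs
        ctx.save_for_backward(x, y, xnew)

    @staticmethod
    def backward(ctx, grad_out):
        x, y, xnew = ctx.saved_tensors
        idx = torch.searchsorted(x, xnew).clamp(1, x.shape[-1] - 1)
        x0, x1 = x[idx - 1], x[idx]
        w = (xnew - x0) / (x1 - x0)  # weight on the right-hand sample y[idx]
        slope = (y[idx] - y[idx - 1]) / (x1 - x0)
        # Scatter each query's contribution back onto its two neighbouring samples
        grad_y = torch.zeros_like(y)
        grad_y = grad_y.index_add(0, idx - 1, grad_out * (1 - w))
        grad_y = grad_y.index_add(0, idx, grad_out * w)
        # No gradient for x in this sketch
        return None, grad_y, grad_out * slope


def interpolate(xp):
    # Same closure as in the issue, with the sketch swapped in for interp1d
    x = torch.linspace(-5, 5, 100)
    y = x**3
    return LinearInterp1d.apply(x, y, torch.atleast_1d(xp))


xp = torch.rand(20) * 10 - 5
out = torch.vmap(interpolate)(xp)  # one (1,)-shaped result per query, stacked to (20, 1)
print(out.shape)
```

Setting `generate_vmap_rule = True` is the lower-effort option; the docs also allow overriding a `vmap` staticmethod instead, which gives explicit control over how the batch dimension is handled.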

It'd be great to get a fix for this, since vmap is incredibly useful.

aliutkus commented 1 year ago

Hi, I'm a bit away from this project right now, but I'd gladly accept a PR.