patrick-kidger / torchcubicspline

Interpolating natural cubic splines. Includes batching, GPU support, support for missing values, evaluating derivatives of the spline, and backpropagation.
Apache License 2.0

evaluating at different time points per batch #17

Open LucaCras opened 4 months ago

LucaCras commented 4 months ago

Hi,

I have a tensor of shape batch_size x T x nr_channels, and I created a cubic spline from it.

Say the tensor has shape (2, 10, 1).

Now I want to query [[0], [1]], i.e. the value at t=0 for batch 0 and the value at t=1 for batch 1.

I know I can call spline.evaluate(torch.tensor(0)) to get the value at t=0 for both batch elements, but how can I query the above so that it returns a tensor of shape (2, 1, 1) or (2, 1)? Querying it directly returns shape (2, 2, 1), or even (2, 2, 1, 1).

patrick-kidger commented 4 months ago

I don't believe this is possible, unfortunately.

I think for this I would recommend using JAX, and in particular the interpolation routines in Diffrax, as a better, more featureful option.