patrick-kidger / torchcubicspline

Interpolating natural cubic splines. Includes batching, GPU support, support for missing values, evaluating derivatives of the spline, and backpropagation.
Apache License 2.0

Applying on typical PyTorch data format #5

Open fshamsafar opened 3 years ago

fshamsafar commented 3 years ago

Hi,

Is it possible to apply cubic spline interpolation along one dimension of data in the typical PyTorch format [BatchSize, Channel, Height, Width]?

Thanks

patrick-kidger commented 3 years ago

That is, you want to interpolate along e.g. the height dimension? Sure:

import torch
batch, channel, height, width = 2, 3, 64, 64  # example sizes
data = torch.rand(batch, channel, height, width)
data_perm = data.transpose(1, 3)  # shape (batch, width, height, channel)

Then do the interpolation. Batch and width both get treated as batch dimensions, and interpolation is performed along the height dimension. Once you've sampled from the cubic spline, transpose the dimensions back again.
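A shape-only sketch of the transpose, interpolate, transpose-back round trip, using just PyTorch. The spline call itself is omitted: `sampled` is a hypothetical placeholder with the shape that evaluating the spline at `num_samples` query times would produce, so the block only checks where each axis ends up.

```python
import torch

batch, channel, height, width = 2, 3, 8, 5
num_samples = 29  # hypothetical number of query points along height

data = torch.rand(batch, channel, height, width)
data_perm = data.transpose(1, 3)  # (batch, width, height, channel)

# The spline is built along dim -2 (the sequence axis). After the transpose,
# that is height; batch and width are leading batch dims, channel is last.
assert data_perm.shape[-2] == height

# Placeholder standing in for the spline output at num_samples query times:
# the height axis is replaced by a query axis of length num_samples.
sampled = torch.zeros(batch, width, num_samples, channel)

# transpose(1, 3) swaps the same two axes, so applying it again undoes it.
result = sampled.transpose(1, 3)  # (batch, channel, num_samples, width)
assert torch.equal(data_perm.transpose(1, 3), data)
```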

fshamsafar commented 3 years ago

Actually, I want to do the interpolation along the channel dimension. Namely, the input data is [B, C, H, W] and I wish to get values at non-integer, quarter-step positions along the channel dimension, [B, C/4, H, W]. I am confused about how to apply the natural_cubic_spline_coeffs function after I permute the dimensions to [B, H, C/4, W].

patrick-kidger commented 3 years ago

Along the channel dimension? Sounds a little odd, but sure.

natural_cubic_spline_coeffs takes in tensors of shape (..., X, Y), corresponding to a (batch of) sequences of length X, with each sequence element being a vector of size Y. There's actually no difference between the ... dimensions and the Y dimension; they're both batch dimensions as far as interpolation is concerned. This Y dimension is put at the end just as a convenience, to fit the common "batch-length-channel" data format.
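To make the (..., X, Y) convention concrete, here is a minimal pure-PyTorch sketch. `natural_spline_eval` is a hypothetical stand-in, not part of torchcubicspline and not the library's implementation (it solves a small dense system for the knot second derivatives). It follows the same layout: knots along dim -2, everything else batched. The usage at the end checks the claim that Y is just another batch dimension by moving it into the leading dims and getting the same answer.

```python
import torch


def natural_spline_eval(t, x, query):
    """Evaluate a natural cubic spline through (t, x) at the times `query`.

    t: (X,) increasing knot times; x: (..., X, Y); query: (S,).
    Returns (..., S, Y). Hypothetical stand-in, not torchcubicspline code.
    """
    n = t.shape[0]
    h = t[1:] - t[:-1]  # (X-1,) knot spacings
    slopes = (x[..., 1:, :] - x[..., :-1, :]) / h[:, None]
    # Linear system for the second derivatives M at the knots;
    # natural boundary conditions: M_0 = M_{n-1} = 0.
    A = torch.zeros(n, n, dtype=x.dtype)
    A[0, 0] = A[-1, -1] = 1.0
    for i in range(1, n - 1):
        A[i, i - 1] = h[i - 1]
        A[i, i] = 2 * (h[i - 1] + h[i])
        A[i, i + 1] = h[i]
    rhs = torch.zeros_like(x)
    rhs[..., 1:-1, :] = 6 * (slopes[..., 1:, :] - slopes[..., :-1, :])
    M = torch.linalg.solve(A, rhs)  # A broadcasts over the leading "..." dims
    # Locate the interval [t_i, t_{i+1}] containing each query time.
    idx = (torch.searchsorted(t, query, right=True) - 1).clamp(0, n - 2)
    hi = h[idx][:, None]  # (S, 1)
    a = (t[idx + 1][:, None] - query[:, None]) / hi
    b = 1 - a
    return (a * x[..., idx, :] + b * x[..., idx + 1, :]
            + ((a ** 3 - a) * M[..., idx, :]
               + (b ** 3 - b) * M[..., idx + 1, :]) * hi ** 2 / 6)


# Usage: the trailing Y axis behaves exactly like another batch axis.
t = torch.linspace(0, 6, 7)
x = torch.rand(2, 4, 7, 3)                    # (..., X, Y)
query = torch.linspace(0, 6, 25)
out = natural_spline_eval(t, x, query)        # (2, 4, 25, 3)
x_alt = x.movedim(-1, 0).unsqueeze(-1)        # (3, 2, 4, 7, 1): Y moved into "..."
out_alt = natural_spline_eval(t, x_alt, query).squeeze(-1).movedim(0, -1)
assert torch.allclose(out, out_alt, atol=1e-5)
```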

So in this case, your channel dimension corresponds to X, and all of the other dimensions should be spread among the ... or the Y. Which ones go where is up to you. If you want to treat them all consistently, you could permute them to [B, H, W, C, 1] with a dummy dimension at the end; or you could do [B, H, C, W] as you suggest. Either way, just call natural_cubic_spline_coeffs directly on that.
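A sketch of both layouts, assuming example sizes. It only performs the permutes (no torchcubicspline call); `back_from_a` and `back_from_b` are hypothetical helpers that undo each permute once dim -2 holds S query points instead of the C knots.

```python
import torch

B, C, H, W = 2, 16, 4, 5
data = torch.rand(B, C, H, W)

# Option 1: C along dim -2 (the knot axis), dummy Y dimension of size 1.
layout_a = data.permute(0, 2, 3, 1).unsqueeze(-1)  # (B, H, W, C, 1)

# Option 2: C along dim -2, W plays the role of Y.
layout_b = data.permute(0, 2, 1, 3)  # (B, H, C, W)


def back_from_a(out):
    """(B, H, W, S, 1) -> (B, S, H, W)."""
    return out.squeeze(-1).permute(0, 3, 1, 2)


def back_from_b(out):
    """(B, H, S, W) -> (B, S, H, W)."""
    return out.permute(0, 2, 1, 3)


# With S == C (no resampling), undoing the permute recovers the input exactly.
assert torch.equal(back_from_a(layout_a), data)
assert torch.equal(back_from_b(layout_b), data)
```

Either permuted tensor can then be passed as x to natural_cubic_spline_coeffs with a t of length C.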

Does that make sense?

fshamsafar commented 3 years ago

Great! Thanks! So basically, with a snippet of code similar to the one below, I guess I will get all the data at quarter steps along the channel dimension. Am I right?!

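import torch
from torchcubicspline import natural_cubic_spline_coeffs, NaturalCubicSpline

batch = 2
channel = 128
height = 512
width = 512
x = torch.rand(batch, height, channel, width)  # channel along dim -2 (the knot axis)
t = torch.linspace(0, channel/4, channel)
coeffs = natural_cubic_spline_coeffs(t, x)
spline = NaturalCubicSpline(coeffs)
out = spline.derivative(t)
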
patrick-kidger commented 3 years ago

More or less. The final call to spline.derivative will be evaluated at all the t points, and those were also specified as being the times of the knots in natural_cubic_spline_coeffs, so the values you get out will be the derivative at all of your inputs x.

I don't know if that aligns with what you're after exactly. Maybe creating natural_cubic_spline_coeffs with torch.linspace(0, channels-1, channels), and then evaluating at torch.linspace(0, channels-1, 4*(channels-1)+1), is what you're after? That will evaluate at every quarter-channel.
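A quick PyTorch check of the quarter-channel grid arithmetic, assuming `channels = 8` for illustration. Over [0, channels-1], `n` evenly spaced points are `(channels-1)/(n-1)` apart, so a spacing of exactly 0.25 needs `4*(channels-1)+1` points; `4*channels` points land slightly short of every quarter and miss most integer knots.

```python
import torch

channels = 8
knots = torch.linspace(0, channels - 1, channels)  # knot times 0, 1, ..., 7

# 4*channels points are spaced (channels-1)/(4*channels-1) apart (7/31 here),
# slightly under a quarter.
naive = torch.linspace(0, channels - 1, 4 * channels)

# 4*(channels-1)+1 points are spaced exactly 0.25 apart and hit every knot,
# so every 4th sample lines up with an original channel.
quarter = torch.linspace(0, channels - 1, 4 * (channels - 1) + 1)
```

The `quarter` times would then go to the spline's evaluation method (e.g. `spline.evaluate(quarter)` for a NaturalCubicSpline built on `knots`), giving `4*(channels-1)+1` samples along dim -2.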