patrick-kidger / torchcubicspline

Interpolating natural cubic splines. Includes batching, GPU support, support for missing values, evaluating derivatives of the spline, and backpropagation.
Apache License 2.0
198 stars, 18 forks

Switch to pytorch native search #6

Closed. jhrmnn closed this 3 years ago.

jhrmnn commented 3 years ago

I actually don't know what it does to speed, though I assume it won't get any slower. My motivation was memory: because of the broadcasting, the previous approach used an unnecessary amount of it.
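The following minimal sketch makes the memory argument concrete. It compares two ways of finding which spline interval each query time falls into. The broadcasting version is a hypothetical stand-in for "the previous approach", since the PR discussion does not show that code, and `interval_index_broadcast` / `interval_index_bucketize` are illustrative names, not torchcubicspline functions. Both should return the same indices, but only the broadcasting version materialises a comparison tensor of shape `(*t.shape, len(knots))`.

```python
import torch


def interval_index_broadcast(t, knots):
    # Hypothetical broadcasting-based search (not the library's code):
    # every query is compared against every knot, so this allocates a
    # bool tensor of shape (*t.shape, len(knots)).
    counts = (t.unsqueeze(-1) > knots).sum(dim=-1)
    # counts = number of knots strictly below each query; subtract 1 for the
    # interval index and clamp so out-of-range queries use an end interval.
    return (counts - 1).clamp(0, knots.numel() - 2)


def interval_index_bucketize(t, knots):
    # PyTorch's native binary search: the result has the same shape as t,
    # so no (*t.shape, len(knots)) intermediate is created.
    counts = torch.bucketize(t, knots)
    return (counts - 1).clamp(0, knots.numel() - 2)


knots = torch.linspace(0.0, 1.0, 1001)
t = torch.rand(64, 100)

idx_a = interval_index_broadcast(t, knots)
idx_b = interval_index_bucketize(t, knots)
print(torch.equal(idx_a, idx_b))  # True

# Number of booleans the broadcasting version materialises along the way.
print(t.numel() * knots.numel())
```

With the default `right=False`, `torch.bucketize` counts the knots strictly less than each query, which is exactly what the broadcast `>` comparison followed by `.sum(dim=-1)` computes. The clamping, which keeps extrapolated queries in the first or last interval, is an assumption made for this sketch.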

patrick-kidger commented 3 years ago

Thanks for the PR!

If you can make those tweaks then I'd be happy to accept this PR.

jhrmnn commented 3 years ago

Done.

Btw, `searchsorted` and `bucketize` call the same function under the hood.

https://github.com/pytorch/pytorch/blob/2ecb2c79312eac6924182fe343dc3cd9a205305c/aten/src/ATen/native/cuda/Bucketization.cu
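The equivalence can also be checked from the Python side. This is a sketch that tests behaviour, not the CUDA source linked above. With the default `right=False`, `torch.bucketize(input, boundaries)` and `torch.searchsorted(sorted_sequence, values)` return the same indices, with the arguments in the opposite order. One practical difference is that `searchsorted` also accepts a batch of sorted sequences (one per leading index), whereas `bucketize` requires 1-D boundaries. The knot and query values below are arbitrary.

```python
import torch

knots = torch.tensor([0.0, 0.25, 0.5, 0.75, 1.0])
queries = torch.tensor([[-0.1, 0.0, 0.3], [0.5, 0.9, 1.2]])

# bucketize(input, boundaries): queries first, 1-D boundaries second.
via_bucketize = torch.bucketize(queries, knots)
# searchsorted(sorted_sequence, values): sorted sequence first.
via_searchsorted = torch.searchsorted(knots, queries)
print(torch.equal(via_bucketize, via_searchsorted))  # True

# The same holds with right=True (ties go to the right-hand bucket).
print(torch.equal(torch.bucketize(queries, knots, right=True),
                  torch.searchsorted(knots, queries, right=True)))  # True

# searchsorted can also search each row against its own sorted sequence;
# the leading dimensions of the sequences and the queries must match.
batched_knots = torch.stack([knots, 2 * knots])  # shape (2, 5)
per_row = torch.searchsorted(batched_knots, queries)
```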

patrick-kidger commented 3 years ago

Excellent, LGTM!