patrick-kidger / torchcubicspline

Interpolating natural cubic splines. Includes batching, GPU support, support for missing values, evaluating derivatives of the spline, and backpropagation.
Apache License 2.0

Speed of tridiagonal_solve() #7

Open jhrmnn opened 3 years ago

jhrmnn commented 3 years ago

Hi. As you warn in the code, tridiagonal_solve() is quite slow. I've compared it against plain torch.solve(), which is much faster, so I'll be using that in my application. Would you be interested in a patch, or did you have other reasons for using the Thomas algorithm?
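For context on the two approaches being compared, here is a minimal, dependency-free sketch of the Thomas algorithm next to a dense solve of the same system. It is plain Python, not the library's actual `tridiagonal_solve()`; `thomas_solve` and `dense_solve` are hypothetical names, and `dense_solve` (Gaussian elimination on the full matrix) only stands in for what `torch.solve` (since deprecated in favour of `torch.linalg.solve`) does with a dense input. Pure-Python timings say nothing about PyTorch speed; the sketch only shows the structural difference: the Thomas algorithm stores and sweeps the three diagonals, while the dense path materialises all n² entries.

```python
def thomas_solve(lower, diag, upper, rhs):
    """Solve a tridiagonal system A x = rhs with the Thomas algorithm.

    lower: sub-diagonal A[i+1][i], length n - 1
    diag:  main diagonal A[i][i], length n
    upper: super-diagonal A[i][i+1], length n - 1
    O(n) time and memory. No pivoting, so this assumes a well-conditioned
    (e.g. diagonally dominant) system.
    """
    n = len(diag)
    c = [0.0] * (n - 1)  # modified super-diagonal
    d = [0.0] * n        # modified right-hand side
    # Forward sweep: eliminate the sub-diagonal one row at a time.
    if n > 1:
        c[0] = upper[0] / diag[0]
    d[0] = rhs[0] / diag[0]
    for i in range(1, n):
        denom = diag[i] - lower[i - 1] * c[i - 1]
        if i < n - 1:
            c[i] = upper[i] / denom
        d[i] = (rhs[i] - lower[i - 1] * d[i - 1]) / denom
    # Back substitution.
    x = [0.0] * n
    x[-1] = d[-1]
    for i in range(n - 2, -1, -1):
        x[i] = d[i] - c[i] * x[i + 1]
    return x


def dense_solve(lower, diag, upper, rhs):
    """Build the full n x n matrix and solve it with Gaussian elimination
    (partial pivoting): O(n^3) time and O(n^2) memory."""
    n = len(diag)
    a = [[0.0] * n + [float(rhs[i])] for i in range(n)]  # augmented matrix
    for i in range(n):
        a[i][i] = float(diag[i])
        if i > 0:
            a[i][i - 1] = float(lower[i - 1])
        if i < n - 1:
            a[i][i + 1] = float(upper[i])
    for col in range(n):
        pivot = max(range(col, n), key=lambda r: abs(a[r][col]))
        a[col], a[pivot] = a[pivot], a[col]
        for r in range(col + 1, n):
            factor = a[r][col] / a[col][col]
            for k in range(col, n + 1):
                a[r][k] -= factor * a[col][k]
    x = [0.0] * n
    for i in range(n - 1, -1, -1):
        s = sum(a[i][k] * x[k] for k in range(i + 1, n))
        x[i] = (a[i][n] - s) / a[i][i]
    return x


# A small diagonally dominant system whose exact solution is [1, 1, 1, 1].
lower, diag, upper = [1.0] * 3, [4.0] * 4, [1.0] * 3
rhs = [5.0, 6.0, 6.0, 5.0]
print(thomas_solve(lower, diag, upper, rhs))
print(dense_solve(lower, diag, upper, rhs))
```

Both calls should agree up to floating-point rounding; the difference is only in how much work and memory each needs as n grows.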

patrick-kidger commented 3 years ago

So, I frequently use this on relatively long sequences. Linear scaling in time is good for peace of mind, and linear scaling in memory might be necessary for it to fit in memory at all.
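To put rough numbers on the memory point: a tridiagonal system only needs its 3n − 2 nonzero entries, whereas a dense solve has to hold the full n × n matrix. The sketch below assumes float32 (4 bytes per element) and ignores the right-hand side, any batch dimension, and whatever workspace a solver allocates internally, so real usage would be higher. `storage_bytes` is a hypothetical helper for illustration.

```python
def storage_bytes(n, bytes_per_element=4):
    """Rough storage for one n x n tridiagonal system (float32 by default).

    Returns (tridiagonal_bytes, dense_bytes). Counts the matrix entries only.
    """
    tridiagonal = (3 * n - 2) * bytes_per_element  # the three diagonals
    dense = n * n * bytes_per_element              # every entry of the matrix
    return tridiagonal, dense


for n in (1_000, 10_000, 100_000):
    tri, dense = storage_bytes(n)
    print(f"n={n:>7}: tridiagonal {tri / 1e6:.3f} MB, dense {dense / 1e9:.3f} GB")
```

At n = 100,000 this is about 1.2 MB for the diagonals versus 40 GB for the dense matrix. Time scales similarly: a forward sweep plus back substitution is a handful of operations per row, while dense Gaussian elimination needs on the order of n³ operations.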

That's just theoretically speaking, though. I've not run time or memory benchmarks against torch.solve. If you can demonstrate that it's more time efficient, and not too memory inefficient, then I'd be happy to accept a patch that dispatches to torch.solve in the regime for which those are true. (Or if you prefer, just an argument to switch from one to the other, as I appreciate that'll be less hassle to put together.)
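One way the suggested switch could look is a `method` argument with an `"auto"` mode that only picks the dense path for short sequences whose full matrices fit in a memory budget. This is a hypothetical sketch: `choose_solver`, its parameters, and the default thresholds are placeholders, not part of torchcubicspline, and the actual cut-offs would have to come from the benchmarks discussed above.

```python
def choose_solver(n, method="auto", batch_size=1, bytes_per_element=4,
                  dense_max_n=1_000, memory_budget_bytes=2**30):
    """Hypothetical dispatch between a dense solve and the Thomas algorithm.

    method: "thomas" or "dense" forces that solver; "auto" applies the
    heuristic below. The default thresholds are placeholders only.
    """
    if method in ("thomas", "dense"):
        return method
    if method != "auto":
        raise ValueError(f"unknown method: {method!r}")
    # Memory for the full matrices of every system in the batch.
    dense_bytes = batch_size * n * n * bytes_per_element
    if n <= dense_max_n and dense_bytes <= memory_budget_bytes:
        return "dense"
    return "thomas"  # linear time and memory regardless of n


print(choose_solver(500))                     # short sequence, small matrix
print(choose_solver(10_000))                  # long sequence
print(choose_solver(500, batch_size=10_000))  # large batch exceeds the budget
```

Keeping the explicit `"thomas"`/`"dense"` options alongside `"auto"` means users can override the heuristic if it turns out to be wrong for their hardware.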