Open odstrcilt opened 1 month ago
Do you have an example of the code above being used, compared with interpax? Also, are you running on CPU or GPU? On CPU the tridiagonal solve is likely faster, but on GPU its sequential loops add a lot of overhead compared with simply calling cuSOLVER on the full matrix. If that were the cause, though, I would expect to see a performance difference in the forward pass as well, not just in the gradient.
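The trade-off in that reply, a sequential tridiagonal sweep versus a single dense solve, can be sketched for a natural cubic spline. This is a minimal sketch, not the code referred to in the thread (which is not shown here). It assumes natural boundary conditions (zero second derivative at both ends), and `thomas_solve`, `dense_solve`, `natural_spline_moments` and `spline_eval` are hypothetical names. The Thomas algorithm is written as two `jax.lax.scan` loops (cheap on CPU, but each step is a sequential kernel launch on GPU), while the dense variant builds the full matrix and calls `jnp.linalg.solve`. Both are differentiable, so forward and gradient timings can be compared directly.

```python
import jax
import jax.numpy as jnp


def thomas_solve(lower, diag, upper, rhs):
    """Tridiagonal solve via the Thomas algorithm, as two sequential lax.scan loops.

    Row j reads: lower[j]*x[j-1] + diag[j]*x[j] + upper[j]*x[j+1] = rhs[j].
    lower[0] and upper[-1] are never used.
    """
    def forward_sweep(carry, row):
        c_prev, d_prev = carry
        a, b, c, d = row
        denom = b - a * c_prev
        c_new, d_new = c / denom, (d - a * d_prev) / denom
        return (c_new, d_new), (c_new, d_new)

    zero = jnp.zeros((), rhs.dtype)
    _, (c_star, d_star) = jax.lax.scan(
        forward_sweep, (zero, zero), (lower, diag, upper, rhs))

    def back_substitute(x_next, row):
        c, d = row
        x = d - c * x_next
        return x, x

    # reverse=True walks from the last row to the first; outputs keep the original order.
    _, x = jax.lax.scan(back_substitute, zero, (c_star, d_star), reverse=True)
    return x


def dense_solve(lower, diag, upper, rhs):
    """Same system, assembled as a full matrix and handed to a dense solver."""
    A = jnp.diag(diag) + jnp.diag(lower[1:], k=-1) + jnp.diag(upper[:-1], k=1)
    return jnp.linalg.solve(A, rhs)


def natural_spline_moments(x, y, solver):
    """Second derivatives M at the knots of a natural cubic spline (M[0] = M[-1] = 0)."""
    h = jnp.diff(x)                  # interval widths h_i = x_{i+1} - x_i
    slope = jnp.diff(y) / h          # secant slopes on each interval
    rhs = 6.0 * jnp.diff(slope)      # one equation per interior knot
    diag = 2.0 * (h[:-1] + h[1:])
    m_inner = solver(h[:-1], diag, h[1:], rhs)
    return jnp.concatenate([jnp.zeros(1), m_inner, jnp.zeros(1)])


def spline_eval(xq, x, y, M):
    """Evaluate the cubic spline defined by knots (x, y) and moments M at xq."""
    i = jnp.clip(jnp.searchsorted(x, xq) - 1, 0, x.size - 2)
    h = x[i + 1] - x[i]
    t0, t1 = x[i + 1] - xq, xq - x[i]
    return ((M[i] * t0**3 + M[i + 1] * t1**3) / (6.0 * h)
            + (y[i] / h - M[i] * h / 6.0) * t0
            + (y[i + 1] / h - M[i + 1] * h / 6.0) * t1)


x = jnp.linspace(0.0, 2.0 * jnp.pi, 32)
y = jnp.sin(x)
xq = jnp.linspace(0.0, 2.0 * jnp.pi, 500)


def loss(y, solver):
    M = natural_spline_moments(x, y, solver)
    return jnp.sum(spline_eval(xq, x, y, M) ** 2)


# Gradients with respect to the data y through both solvers.
g_scan = jax.grad(lambda yy: loss(yy, thomas_solve))(y)
g_dense = jax.grad(lambda yy: loss(yy, dense_solve))(y)
```

Both variants should agree numerically; only their speed differs, and which one wins is expected to depend on the backend and problem size. JAX also provides `jax.lax.linalg.tridiagonal_solve`, which could be tried as a third variant.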
interpax is a very nice package. However, when it is used inside a function wrapped in jax.grad, it is about 10x slower than my simple cubic spline interpolation. The speed of the interpolation itself (without jax.grad) is not significantly different. My code was written with the help of ChatGPT. This is not a high-priority issue.
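A forward-versus-gradient comparison like this is easy to skew, because JAX compiles on the first call and dispatches work asynchronously. Here is a minimal benchmarking sketch that jits each function, discards the compiling call, and waits on results with `jax.block_until_ready`. The `interp` function uses `jnp.interp` (linear) only as a placeholder for the spline or interpax call being compared, and `bench`, `interp` and `loss` are hypothetical names.

```python
import time

import jax
import jax.numpy as jnp


def bench(fn, *args, repeats=20):
    """Median wall time of a jitted call, excluding compilation."""
    jitted = jax.jit(fn)
    jax.block_until_ready(jitted(*args))  # first call compiles; not timed
    times = []
    for _ in range(repeats):
        t0 = time.perf_counter()
        jax.block_until_ready(jitted(*args))  # wait for async dispatch to finish
        times.append(time.perf_counter() - t0)
    return sorted(times)[len(times) // 2]


# Placeholder interpolant: swap in the spline or interpax call under test.
def interp(y, x, xq):
    return jnp.interp(xq, x, y)


def loss(y, x, xq):
    return jnp.sum(interp(y, x, xq) ** 2)


x = jnp.linspace(0.0, 1.0, 64)
y = jnp.sin(2.0 * jnp.pi * x)
xq = jnp.linspace(0.0, 1.0, 10_000)

t_fwd = bench(loss, y, x, xq)
t_grad = bench(jax.grad(loss), y, x, xq)  # gradient with respect to y
print("backend:", jax.default_backend())
print(f"forward {t_fwd * 1e6:.1f} us, grad {t_grad * 1e6:.1f} us, "
      f"ratio {t_grad / t_fwd:.1f}")
```

Reporting the backend alongside the ratio answers the CPU-or-GPU question, and running the same harness for both interpolants keeps the comparison like-for-like.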