f0uriest / interpax

Interpolation and function approximation with JAX
MIT License

Avoid recompilation after model surgery #36

Closed ToshiyukiBandai closed 4 months ago

ToshiyukiBandai commented 4 months ago

Hi @f0uriest Thank you so much for sharing this amazing library! It makes my life easier. I have a question regarding jit compilation of the interpolated function (generated by Interpolator1D, for example). In my problem, I want to update the coefficients of the interpolation function via model surgery using equinox:

interpolated_fun = eqx.tree_at(lambda m: m.f, interpolated_fun, new_coefficient)

The thing is, this interpolated_fun is embedded in a large training step. If I update the interpolation function this way, the update is not reflected as expected, and I would have to redefine the whole training step, which is not an option. Do you have any suggestions for this kind of thing? What I am thinking is to implement the interpolated function as a normal function and treat the coefficients as an argument to it.
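Roughly, the situation looks like this (a minimal sketch for illustration only; x_grid, f_vals, and train_step are made-up names, not my actual code):

```python
import jax
import jax.numpy as jnp
import equinox as eqx
from interpax import Interpolator1D

x_grid = jnp.linspace(0.0, 1.0, 64)
f_vals = jnp.sin(2 * jnp.pi * x_grid)
interpolated_fun = Interpolator1D(x_grid, f_vals, method="cubic")

@jax.jit
def train_step(params, xq):
    # interpolated_fun is captured as a closure, so its leaves are traced
    # as constants the first time train_step is compiled.
    return jnp.mean((interpolated_fun(xq) - params) ** 2)

# Model surgery: swap in new values ...
new_coefficient = jnp.cos(2 * jnp.pi * x_grid)
interpolated_fun = eqx.tree_at(lambda m: m.f, interpolated_fun, new_coefficient)
# ... but the already-compiled train_step still sees the old interpolator,
# and I would have to redefine (and retrace) the whole step to pick up the change.
```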

f0uriest commented 4 months ago

Interpolator1D generally assumes that the values being interpolated are fixed, and hence it precomputes the slopes needed for interpolation. If you update f, it will still use the old slopes, which is wrong.

I think what you want is interpax.interp1d, which is basically a functional interface like what you describe (under the hood, Interpolator1D.__call__ just calls interp1d with some cached values).
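Something along these lines should work (just a rough sketch; the grid, target, and train_step names are placeholders), passing the values as a regular argument so new coefficients flow through without retracing:

```python
import jax
import jax.numpy as jnp
from interpax import interp1d

x_grid = jnp.linspace(0.0, 1.0, 64)

@jax.jit
def train_step(f_vals, xq, target):
    # f_vals is an ordinary argument, so updated coefficients are used
    # on every call without rebuilding or recompiling the step.
    pred = interp1d(xq, x_grid, f_vals, method="cubic")
    return jnp.mean((pred - target) ** 2)

xq = jnp.linspace(0.0, 1.0, 16)
target = jnp.zeros_like(xq)

f_vals = jnp.sin(2 * jnp.pi * x_grid)
loss_old = train_step(f_vals, xq, target)       # compiles once

new_f_vals = jnp.cos(2 * jnp.pi * x_grid)
loss_new = train_step(new_f_vals, xq, target)   # reuses the compiled step
```

The trade-off is that interp1d has to recompute the derivative coefficients from f on every call rather than caching them like Interpolator1D does, which is usually a small price for being able to change f freely inside a compiled training loop.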

ToshiyukiBandai commented 4 months ago

Thank you! I will test it and get back to you.

ToshiyukiBandai commented 4 months ago

@f0uriest Yes, it worked well. Thank you so much!