Open mfschubert opened 9 months ago
I am seeing some issues with super-long compile times in the optimization context, which are eliminated when we use a `stop_gradient` before the vector field calculation. I am thinking we should just add this `stop_gradient`, and then restore the ability to backpropagate through vector field generation via the method mentioned above. This might be fairly involved, and would take time. fyi @smartalecH
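For context, `jax.lax.stop_gradient` blocks differentiation through a subcomputation so the autodiff tracer never enters it. A minimal sketch (the `loss_with_stop` function and its field computation are made up for illustration, not the actual vector-field code):

```python
import jax
import jax.numpy as jnp

def loss_with_stop(params):
    # Hypothetical stand-in for the vector-field calculation.
    field = jnp.tanh(params) ** 2
    # Treat the field as a constant during differentiation; autodiff
    # will not trace (or compile) the backward pass of this block.
    field = jax.lax.stop_gradient(field)
    return jnp.sum(field * params)

params = jnp.array([0.5, -1.0])
grad = jax.grad(loss_with_stop)(params)
# Since `field` is held constant, grad equals tanh(params) ** 2.
```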
Yep this sounds like a good plan to me. How hard do we anticipate the manual adjoint will be?
I am looking at it a bit. It might actually be relatively straightforward. Here's a reference that seems nice; it even includes JAX code: https://implicit-layers-tutorial.org/implicit_functions/
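The core trick from that tutorial is to wrap the solve in `jax.custom_vjp` and define the backward pass via the implicit function theorem, so autodiff never unrolls (or compiles through) the Newton iterations. A minimal sketch on a toy scalar root-find; the function `f` and the solver here are illustrative, not the project's actual vector-field solve:

```python
import jax
import jax.numpy as jnp

# Toy implicit relation: f(z, a) = 0 defines z*(a); here z* = sqrt(a).
def f(z, a):
    return z ** 2 - a

@jax.custom_vjp
def newton_solve(a):
    # Plain Newton iteration. Gradients are supplied by the rule below,
    # so autodiff never traces through this loop.
    z = a  # initial guess
    for _ in range(20):
        z = z - f(z, a) / jax.grad(f, argnums=0)(z, a)
    return z

def newton_solve_fwd(a):
    z = newton_solve(a)
    return z, (z, a)

def newton_solve_bwd(res, g):
    z, a = res
    # Implicit function theorem at f(z*, a) = 0:
    #   dz*/da = -(df/dz)^{-1} df/da.
    dfdz = jax.grad(f, argnums=0)(z, a)
    dfda = jax.grad(f, argnums=1)(z, a)
    return (-g * dfda / dfdz,)

newton_solve.defvjp(newton_solve_fwd, newton_solve_bwd)
```

For example, `jax.grad(newton_solve)(4.0)` recovers d(sqrt a)/da = 1/(2 sqrt a) = 0.25 without differentiating through any Newton step.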
@smartalecH @Luochenghuang I have things working here: all it needed was a bit of regularization.
I think we may want to put this on hold for now: the potential accuracy improvement is small, and there is a speed penalty.
`stop_gradient` in the vector field calculation.

mewtax
I tried it to solve for the vector fields, but this seems to make the tests much slower (2x the time to complete all tests). I suspect there is a significant compile time penalty.
Currently, we directly backpropagate through the tangent vector field calculation, which involves a Newton solve to find the minimum of a convex quadratic objective. It may be more efficient to define a custom gradient for this operation, in a manner similar to what is done for differentiable optimization.
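As a sketch of what such a custom gradient could look like for the argmin of a convex quadratic (the matrix `A` and all names here are hypothetical, not the actual vector-field objective): at the optimum the stationarity condition `A x* - theta = 0` holds, and the implicit function theorem turns the backward pass into a single linear solve against the Hessian instead of backpropagation through the Newton iterations.

```python
import jax
import jax.numpy as jnp

# Hypothetical fixed SPD matrix defining the quadratic objective
# q(x; theta) = 0.5 x^T A x - theta^T x.
A = jnp.array([[3.0, 1.0], [1.0, 2.0]])

@jax.custom_vjp
def argmin_quadratic(theta):
    # For a quadratic, the minimizer is the solution of A x = theta.
    # (A real implementation would run the Newton solve here.)
    return jnp.linalg.solve(A, theta)

def argmin_fwd(theta):
    x = argmin_quadratic(theta)
    return x, x

def argmin_bwd(x, g):
    # Stationarity: grad_x q(x*, theta) = A x* - theta = 0.
    # IFT: dx*/dtheta = (d^2 q / dx^2)^{-1} = A^{-1},
    # so the VJP is a single Hessian-transpose solve.
    return (jnp.linalg.solve(A.T, g),)

argmin_quadratic.defvjp(argmin_fwd, argmin_bwd)
```

With this rule, `jax.jacobian(argmin_quadratic)(theta)` equals `A^{-1}` exactly, and XLA only has to compile one extra linear solve for the backward pass.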