facebookresearch / fmmax

Fourier modal method with Jax

Use "differentiable optimization trick" for backpropagation through tangent vector field calculation #74

Open mfschubert opened 9 months ago

mfschubert commented 9 months ago

Currently, we directly backpropagate through the tangent vector field calculation, which involves a Newton solve to find the minimum of a convex quadratic objective. It may be more efficient to define a custom gradient for this operation, in a manner similar to what is done for differentiable optimization.
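
For reference, the trick is to differentiate the minimizer implicitly rather than unrolling the Newton iterations. Sketching the standard argument (generic, not specific to our particular objective): if the field is $x^{\ast}(\theta) = \operatorname{arg\,min}_x f(x, \theta)$ for a smooth objective $f$, then stationarity gives $\nabla_x f(x^{\ast}(\theta), \theta) = 0$, and differentiating this identity in $\theta$ yields

$$\frac{\partial x^{\ast}}{\partial \theta} = -\left[\nabla_{xx}^2 f(x^{\ast}, \theta)\right]^{-1} \nabla_{x\theta}^2 f(x^{\ast}, \theta),$$

so the backward pass reduces to a single linear solve against the Hessian at the solution, independent of how many Newton iterations were used in the forward pass.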

mfschubert commented 9 months ago

I am seeing some issues with very long compile times in the optimization context, which are eliminated when we use a stop_gradient before the vector field calculation. I am thinking we should just add this stop_gradient now, and later restore the ability to backpropagate through vector field generation via the method mentioned above. That restoration might be fairly involved and would take time. FYI @smartalecH
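
To be concrete, here is a minimal sketch of what I mean by adding the stop_gradient. The `tangent_field` function below is a hypothetical stand-in, not the actual fmmax calculation:

```python
import jax
import jax.numpy as jnp


def tangent_field(permittivity):
    # Stand-in for the real Newton-solved tangent vector field calculation;
    # hypothetical, for illustration only.
    gx, gy = jnp.gradient(permittivity)
    norm = jnp.sqrt(gx**2 + gy**2) + 1e-6
    return gx / norm, gy / norm


def tangent_field_no_grad(permittivity):
    # stop_gradient keeps the field values in the forward pass but blocks
    # backpropagation through the calculation, so its reverse pass is never
    # traced or compiled.
    return tangent_field(jax.lax.stop_gradient(permittivity))


tx, ty = tangent_field_no_grad(jnp.ones((32, 32)))
```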

smartalecH commented 9 months ago

Yep, this sounds like a good plan to me. How hard do we anticipate the manual adjoint will be?

mfschubert commented 9 months ago

I am looking at it a bit. It might actually be relatively straightforward. Here's a reference that seems nice; it even includes JAX code: https://implicit-layers-tutorial.org/implicit_functions/
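
To make it concrete, here's a minimal standalone sketch of the pattern from that tutorial: a Newton solve wrapped in `jax.custom_vjp`, with the backward pass given by the implicit function theorem. The objective here is a toy one, not the actual tangent vector field objective:

```python
import jax
import jax.numpy as jnp


def objective(x, theta):
    # Toy smooth objective standing in for the tangent-field objective.
    return jnp.sum((x - theta) ** 2) + 0.1 * jnp.sum(x**4)


def newton_solve(theta, x0, num_steps=20):
    # Plain Newton iteration; backpropagating through this directly is what
    # we want to avoid.
    grad_fn = jax.grad(objective, argnums=0)
    hess_fn = jax.hessian(objective, argnums=0)
    x = x0
    for _ in range(num_steps):
        x = x - jnp.linalg.solve(hess_fn(x, theta), grad_fn(x, theta))
    return x


@jax.custom_vjp
def argmin(theta, x0):
    return newton_solve(theta, x0)


def argmin_fwd(theta, x0):
    x_star = newton_solve(theta, x0)
    return x_star, (theta, x0, x_star)


def argmin_bwd(res, x_star_bar):
    theta, x0, x_star = res
    # Implicit function theorem: at the solution, grad_x objective = 0, so
    # d(x*)/d(theta) = -H^{-1} d(grad_x objective)/d(theta). The VJP is a
    # single linear solve against the Hessian, regardless of how many Newton
    # steps were taken in the forward pass.
    grad_fn = jax.grad(objective, argnums=0)
    hess = jax.hessian(objective, argnums=0)(x_star, theta)
    w = jnp.linalg.solve(hess, x_star_bar)
    _, vjp_theta = jax.vjp(lambda t: grad_fn(x_star, t), theta)
    (theta_bar,) = vjp_theta(-w)
    return theta_bar, jnp.zeros_like(x0)  # no gradient w.r.t. the initial guess


argmin.defvjp(argmin_fwd, argmin_bwd)

theta = jnp.array([0.5, -1.0, 2.0])
x0 = jnp.zeros(3)
loss = lambda t: jnp.sum(argmin(t, x0))
print(jax.value_and_grad(loss)(theta))
```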

mfschubert commented 8 months ago

@smartalecH @Luochenghuang I have things working here; all it needed was a bit of regularization.

https://github.com/mfschubert/mewtax
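
For context, one common form such regularization takes is damping the Hessian in the backward linear solve; an illustrative sketch only, not the actual mewtax code:

```python
import jax.numpy as jnp


def regularized_solve(hess, cotangent, eps=1e-6):
    # Hypothetical: add a small multiple of the identity before the
    # implicit-differentiation linear solve (Tikhonov-style damping) so the
    # backward pass stays well conditioned when the Hessian is nearly singular.
    return jnp.linalg.solve(hess + eps * jnp.eye(hess.shape[-1]), cotangent)
```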

mfschubert commented 8 months ago

I think we may want to put this on hold for now: the potential accuracy improvement is small, and there is a speed penalty.