TuringLang / Turing.jl

Bayesian inference with probabilistic programming.
https://turinglang.org
MIT License

Interface for specifying adjoints of FEM models #2303

Open mhauru opened 4 weeks ago

mhauru commented 4 weeks ago

@llfung talked to me today about his work on adjoint-accelerated programmable inference for large PDEs, and about what would be needed on Turing's part to support it. As I understand it (and I know very little about FEM, so bear with me if my explanation/understanding here is poor), using the adjoint allows you to compute the gradient with respect to the parameters of the FEM problem efficiently. What they would need is some mechanism in Turing for specifying the function that computes the gradient of the FEM part, and for having MLE/MAP, and maybe also other things like sampling, make use of it.

My first question for @llfung was whether we could hijack the AD mechanisms by specifying an AD rule for the FEM solver, and have the implementation of that rule use the adjoint. Lloyd tells me that the adjoint requires knowledge of the surrounding objective function (log density). I don't understand this well enough, but maybe in reverse-mode AD we already have what we need, since the pullback computes a vector-Jacobian product?
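
To make the rule idea concrete, here is a minimal sketch using ChainRulesCore with a toy linear "FEM" solve; `fem_solve`, the 2×2 stiffness matrix, and the hand-coded derivatives are all illustrative stand-ins for what the FEM package would provide:

```julia
using ChainRulesCore, LinearAlgebra

# Toy stand-in for a black-box FEM solve: K(p) u = load, with a
# parameter-dependent stiffness matrix K(p).
stiffness(p) = [2p[1] -p[1]; -p[1] p[1]+p[2]]
const load = [1.0, 0.0]
fem_solve(p) = stiffness(p) \ load

function ChainRulesCore.rrule(::typeof(fem_solve), p)
    K = stiffness(p)
    u = K \ load
    function fem_solve_pullback(ū)
        # The incoming cotangent ū is exactly the dg/du the adjoint method
        # needs, so reverse mode already supplies it. Adjoint solve: K'λ = ū.
        λ = K' \ collect(unthunk(ū))
        # ∂u/∂pᵢ = -K⁻¹ (∂K/∂pᵢ) u, hence p̄ᵢ = ū'(∂u/∂pᵢ) = -λ'(∂K/∂pᵢ)u.
        ∂K∂p₁ = [2.0 -1.0; -1.0 1.0]
        ∂K∂p₂ = [0.0 0.0; 0.0 1.0]
        return NoTangent(), [-dot(λ, ∂K∂p₁ * u), -dot(λ, ∂K∂p₂ * u)]
    end
    return u, fem_solve_pullback
end
```

With such a rule in place, any ChainRules-aware reverse-mode backend differentiating a log density that calls `fem_solve(p)` would invoke the pullback, which is exactly where a solver-provided adjoint could be plugged in.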

If that doesn't pan out, maybe we need some sort of new context, for which we can say "when computing a gradient, once you hit the point of calling the FEM solver, use this adjoint thing instead".

I'll leave @llfung to expand on this and to correct all the bits I got wrong.

devmotion commented 4 weeks ago

Naive question, but isn't this solved by SciMLSensitivity, with the adjoint method typically applied automatically (with possible manual adjustments) when ADing through solve(...) for a DE system? I'm not 100% sure about PDEs, but for ODEs, SDEs, etc. this should definitely be the case.
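
For reference, that pattern in sketch form (the toy decay model, priors, and data here are illustrative):

```julia
using Turing, OrdinaryDiffEq, SciMLSensitivity

# Toy exponential-decay ODE; putting the solve inside the model is all
# SciMLSensitivity needs to take over when the log density is differentiated.
decay!(du, u, p, t) = (du[1] = -p[1] * u[1])

@model function fit_decay(data, ts)
    k ~ truncated(Normal(1.0, 0.5); lower=0)
    σ ~ truncated(Normal(0.0, 1.0); lower=0)
    prob = ODEProblem(decay!, [1.0], (0.0, last(ts)), [k])
    # sensealg selects the adjoint method; if omitted, a default is chosen
    # automatically based on problem size and other heuristics.
    sol = solve(prob, Tsit5(); saveat=ts, sensealg=InterpolatingAdjoint())
    for i in eachindex(ts)
        data[i] ~ Normal(sol.u[i][1], σ)
    end
end

# Adjoints pay off with reverse-mode AD, e.g.:
# sample(fit_decay(data, ts), NUTS(; adtype=AutoReverseDiff()), 1000)
```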

llfung commented 4 weeks ago

@devmotion You're right. There's the adjoint_sensitivities interface that taps into the AD. I'll look into that. This tutorial gives a good summary of how it can be used. I will follow up on the mathematical details of what we aim to do with the adjoint in a few weeks' time. It might be that we can exploit the adjoint_sensitivities interface in SciMLSensitivity with AD.

llfung commented 4 weeks ago

To illustrate the math of the adjoint problem, let's consider a simple ODE in time $$t \in [t_0,T]$$:

$$ \frac{du}{dt}= f(u,p) $$

with given initial condition

$$ u(t_0) =u_0 $$

where $$u(t,p)$$ is the solution given parameter $$p$$ and user-function $$f(u,p)$$.

The adjoint is used when we try to optimise the loss function (i.e. the log-density in Turing)

$$ G(u,p)=G(u(t,p))=\int_{t_{0}}^{T}g(u(t,p),p)\,dt $$

which requires the gradient

$$ \frac{dG}{dp}=\int_{t_{0}}^{T}\left(\lambda^{\star}(t)f_{p}(t)+g_{p}(t)\right)dt+\lambda^{\star}(t_{0})u_{p}(t_{0}) $$

computed using the Lagrange multiplier $\lambda(t)$ found through backward integration of

$$ \frac{d\lambda^{\star}}{dt}=g_{u}(u(t,p),p)-\lambda^{\star}(t)f_{u}(t,u(t,p),p),\qquad\lambda^{\star}(T)=0 $$

Therefore, to use the adjoint method, we need access to dgdu = $$g_{u}(u(t,p),p)$$ and dgdp = $$g_{p}(t)$$. Currently, Turing only specifies $$G$$ and expects AD to work out $$\frac{dG}{dp}$$ in the background. The adjoint accelerates the computation of $$\frac{dG}{dp}$$ by hijacking AD, but it requires the extra information dgdu = $$g_{u}(u(t,p),p)$$ and dgdp = $$g_{p}(t)$$.
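
For a concrete example: with a quadratic misfit against data $$d(t)$$, i.e. $$g(u,p)=\|u(t,p)-d(t)\|^{2}$$, the required inputs are simply $$g_{u}=2(u-d)$$ and $$g_{p}=0$$.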

llfung commented 4 weeks ago

Did a bit of digging around [SciMLSensitivity](https://docs.sciml.ai/SciMLSensitivity/):

  1. There are two levels of algorithms in SciMLSensitivity. The lower-level algorithm, adjoint_sensitivities, takes as inputs the gradient of the objective/cost/loss function g with respect to the solution u (i.e. dgdu) and with respect to the parameters (dgdp), solves the adjoint problem, and returns $$dG/dp$$ (see the sketch after this list).
  2. The higher-level algorithm does not require any specific inputs, other than specifying in sensealg that the adjoint shall be used. When the gradient is required, the sensitivity algorithm kicks in automatically to hijack the AD process and use the adjoint method. It works out the dgdu and dgdp functions automatically using AD and calls adjoint_sensitivities automatically.
  3. adjoint_sensitivities is designed specifically to work with the internals of DifferentialEquations, NonlinearProblem, and other SciML solvers, and is therefore not compatible with arbitrary solvers. To use adjoint_sensitivities, the forward problem must be cast in terms of solvers within the SciML ecosystem. As for the higher-level algorithm that automatically finds dgdu and dgdp, we are not sure whether it is specific to SciML or works generally.
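
A rough sketch of the lower-level interface, per the SciMLSensitivity docs (the toy ODE and cost function here are illustrative, and the exact keyword names may vary between versions):

```julia
using OrdinaryDiffEq, SciMLSensitivity

f!(du, u, p, t) = (du[1] = -p[1] * u[1])
prob = ODEProblem(f!, [1.0], (0.0, 10.0), [0.5])
sol = solve(prob, Tsit5())

# Continuous cost G = ∫ g(u,p,t) dt with g = u₁²; g_u and g_p supplied by hand.
g(u, p, t) = u[1]^2
dgdu(out, u, p, t) = (out[1] = 2u[1])  # g_u, written in place
dgdp(out, u, p, t) = (out .= 0.0)      # g_p, written in place

du0, dp = adjoint_sensitivities(sol, Tsit5();
                                dgdu_continuous=dgdu, dgdp_continuous=dgdp, g=g)
```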

For the purposes of our project, we'll likely use an FEM solver outside the SciML ecosystem. The FEM solver will provide the equivalent of adjoint_sensitivities: it will take dgdu and dgdp, solve the adjoint problem, and output $$dG/dp$$.

Therefore, there are two ways to implement the adjoint acceleration with our custom FEM solver.

  1. We overload adjoint_sensitivities in SciMLSensitivity with our own solver. This way, we can exploit the higher-level algorithm in SciMLSensitivity to work out dgdu and dgdp for us.
  2. Since Turing specifies $$G$$ anyway, we implement the analytical forms of dgdu and dgdp within Turing and call the FEM adjoint directly.

Method 1 is more complete and works better with the whole SciML ecosystem, but it requires close collaboration with the SciMLSensitivity project and is the harder engineering challenge. Method 2 is more specific to what we need and easier to implement, but slightly detached from the SciML ecosystem.

yebai commented 4 weeks ago

Cc @ChrisRackauckas

ChrisRackauckas commented 4 weeks ago

You don't have to do any of that. You're overthinking it. If you use forward/reverse mode AD on a solve call, it will automatically use forward/adjoint sensitivities. Thus if you stick a solve call in Turing, it will use adjoints automatically with no other work required on your end. That's mentioned in the tutorial:

https://turing.ml/v0.22/tutorials/10-bayesian-differential-equations/#scaling-to-large-models-adjoint-sensitivities

Now the next thing is extending to PDEs. PDE discretizations like FEM are simply transformations of PDEs into computable forms. These computable forms are the SciMLBase interfaces, such as LinearProblem, NonlinearProblem, ODEProblem, etc. For example, a method-of-lines discretization that leaves time intact gives you an ODEProblem, while a finite element collocation in time gives you a NonlinearProblem. It does not matter how you do your PDE discretization; you end up with one of the canonical mathematical problems to solve for the coefficients of the representation. In that case, applying the adjoint method follows automatically: you perform your discretization, build a LinearProblem/NonlinearProblem/ODEProblem, and solve it, and when Turing's automatic differentiation applies, it will automatically know (given the size of the problem and other heuristics) to apply adjoint differentiation to the LinearProblem/NonlinearProblem/ODEProblem solve.

That means the only thing you have to do to make this work is ensure that your FEM discretization, i.e. the matrix assembly, is compatible with automatic differentiation, preferably reverse mode. If you handle that, all of the other adjoint rules follow, and all of the implicit differentiation tricks are applied automatically behind the scenes via other rule definitions.
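
A sketch of that pattern, with a hypothetical assemble(p) standing in for an AD-compatible FEM matrix assembly:

```julia
using Turing, LinearSolve, LinearAlgebra

# Hypothetical AD-compatible assembly: parameters p -> stiffness matrix and
# load vector. In practice this is the FEM package's matrix assembly.
assemble(p) = ([2p[1] -p[1]; -p[1] p[1]+p[2]], [1.0, 0.0])

@model function fem_posterior(data)
    p ~ filldist(LogNormal(), 2)
    K, b = assemble(p)
    # The discretized PDE is just a LinearProblem; the adjoint rules attach
    # to this solve call when the log density is differentiated.
    u = solve(LinearProblem(K, b)).u
    data ~ MvNormal(u, 0.1^2 * I)
end
```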

ChrisRackauckas commented 4 weeks ago

And for reference, we have some projects using Ferrite.jl, which seems to be Enzyme-compatible, so if Turing can use Enzyme these days, then using Ferrite for the semi-discretization would be a good optimization. We should probably make a tutorial along these lines.

llfung commented 3 weeks ago

Thanks @ChrisRackauckas. Yes, we understand that if everything is done under SciML's solve call, then the adjoint is invoked automatically. We also understand that FEM can be cast as a LinearProblem or NonlinearProblem. The main point is to create an interface for solvers that sit outside the SciML ecosystem, such as FEniCS.jl and Gridap.jl, or other custom FEM solvers we have, which may or may not be compatible with AD.

ChrisRackauckas commented 3 weeks ago

Only the semi-discretization would have to be AD-compatible; then it would all flow. That just takes a few adjoint overloads. Someone tested that for FEniCS.jl back around 2020. Ferrite.jl should be directly compatible with Enzyme.