SciML / DifferentialEquations.jl

Multi-language suite for high-performance solvers of differential equations and scientific machine learning (SciML) components. Ordinary differential equations (ODEs), stochastic differential equations (SDEs), delay differential equations (DDEs), differential-algebraic equations (DAEs), and more in Julia.
https://docs.sciml.ai/DiffEqDocs/stable/

The end of concrete_solve #610

Closed ChrisRackauckas closed 4 years ago

ChrisRackauckas commented 4 years ago

Let's end concrete_solve. The questions "why do we have concrete_solve" and "how do we get rid of concrete_solve" are essentially the same question, so let me lay out what we need to fix about solve in order to deprecate concrete_solve:

Adjoints are not safe with interpolation.

This is not something that can be or ever will be fixed, because it's just not a good idea. You could add it by adding discrete sensitivity analysis as a way to calculate the derivatives of sol.k w.r.t. parameters, but if you look at what's going on... no. Discrete sensitivity analysis is not only equivalent to AD through the solver, it would also require an implementation per method. So you might as well AD through the solver if that's what you're looking for. That would be much more efficient, since adjoints + continuous outputs essentially computes two different versions of reverse-mode AD when one would do, so it's never a good way, computationally or memory-wise, to compute the derivative there.

If dense=false, then a linear interpolation is used and that would be safe, but concrete_solve currently does not allow this safe option. However, this is not something we could blindly do to the user: if outside of an AD context they have a 9th order algorithm but when doing AD it's a 1st order algorithm, that would introduce so many numerical issues it's not even funny. So we can't just set dense=false for the user. But if we don't and they do use interpolation, the interpolated values carry a zero gradient, which Zygote then brings all the way back as zero gradients instead of erroring. Training is then essentially turned off for loss functions that depend on the interpolation, and only when dense=true. That is a major trap, and the big reason why concrete_solve was added in the first place.

Solution

Allow for passing dense = NullInterpolation(). When the solvers see this, they create a post-solution interpolation that errors if you try to use it, saying that this interpolation is not compatible with usage inside of AD and suggesting that you use saveat or resort to AD on the solver itself for this functionality. Downstream packages need to be updated for this. I think most algorithms just pass dense=dense through to build_solution, so it can be handled in DiffEqBase+OrdinaryDiffEq and that should cover everything.
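To make the idea concrete, here is a minimal sketch in plain Julia of an interpolation object that errors on use. `ErroringInterpolation` and `ToySolution` are hypothetical names made up for illustration; this is not how DiffEqBase implements it, just the shape of the behavior described above.

```julia
# Hypothetical stand-in; not the type DiffEqBase actually defines.
struct ErroringInterpolation end

# Calling the interpolation at any time point throws, instead of returning
# a value whose derivative would silently be zero.
function (::ErroringInterpolation)(t)
    error("Dense interpolation is not compatible with adjoint-based AD. " *
          "Use `saveat` to save at the needed time points, " *
          "or differentiate through the solver itself.")
end

# Toy solution: saved time points and values plus an interpolation object.
struct ToySolution{I}
    t::Vector{Float64}
    u::Vector{Float64}
    interp::I
end

# sol(t) forwards to the interpolation, as continuous output would.
(sol::ToySolution)(t) = sol.interp(t)

sol = ToySolution([0.0, 0.5, 1.0], [1.0, 0.6, 0.37], ErroringInterpolation())

saved_value = sol.u[2]   # indexing saved values still works

threw = try
    sol(0.25)            # interpolating throws
    false
catch err
    err isa ErrorException
end
```

The point of the design is that a loud error replaces a silent zero gradient: saved values remain usable, and only the continuous output is blocked.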

AD pass through

So okay, this brings up the second issue: if we want to make this work out, then we need to have an option so that AD can still work on the solver, otherwise we recursively keep capturing it to send it to an adjoint method.

Solution

This means we need a sensealg choice SensitivityPassThrough where when this is seen, the adjoint continues the original AD call on the adjoint. I think this can be done in ChainRules by having !(SensitivityPassThrough <: AbstractSensitivityAlgorithm) and then only defining the rrule to dispatch on Union{Nothing,AbstractSensitivityAlgorithm}, since then the AD should just ignore the adjoint when it's SensitivityPassThrough (@oxinabox can you confirm?)
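The dispatch trick can be sketched without ChainRules at all. In the toy below, `custom_adjoint` stands in for the rrule and `differentiate` stands in for an AD system that uses a custom rule when one applies and otherwise keeps differentiating the underlying code. `ToyAdjoint`, `custom_adjoint`, and `differentiate` are hypothetical names; only the type relationship mirrors the proposal.

```julia
abstract type AbstractSensitivityAlgorithm end

# Hypothetical adjoint choice that takes the custom-adjoint path.
struct ToyAdjoint <: AbstractSensitivityAlgorithm end

# Deliberately NOT a subtype, so methods restricted to the abstract type
# never match it.
struct SensitivityPassThrough end

# Stand-in for the rrule: only defined for Nothing or a sensitivity algorithm.
custom_adjoint(::Union{Nothing,AbstractSensitivityAlgorithm}) = :use_adjoint_rule

# Stand-in for the AD system: use the custom rule if a method applies,
# otherwise differentiate through the solver code itself.
differentiate(sensealg) =
    applicable(custom_adjoint, sensealg) ? custom_adjoint(sensealg) : :ad_through_solver

default_path = differentiate(nothing)
adjoint_path = differentiate(ToyAdjoint())
passthrough_path = differentiate(SensitivityPassThrough())
```

Because no method matches `SensitivityPassThrough`, the pass-through case never re-enters the adjoint, which is what avoids the recursive capture.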

Differentiation w.r.t. input fields

Lastly, we need to figure out how to differentiate w.r.t. input fields. concrete_solve specifically does:

concrete_solve(prob,alg,u0=prob.u0,p=prob.p;...)

so that (a) it's easier for us to write dispatches to differentiate w.r.t. u0 and p but also (b) it's easier for users to change u0 and p.

Solution

One more recent change is that we now have a system for allowing overrides:

solve(prob,alg,u0=...,p=...)

I think we just need to have a lowering process that's like solve -> _solve_up (which then gets the adjoint definitions) -> internal stuff for handling distributions and all of that.
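A rough sketch of that lowering, with toy types: the user-facing function takes u0 and p as keyword overrides that default to the problem's fields, then forwards them positionally to a core function where adjoint definitions could attach. `ToyProblem` and `toy_solve` are hypothetical, and the body of `_solve_up` here is a placeholder, not the real DiffEqBase internals.

```julia
# Toy problem type; a real ODEProblem carries much more.
struct ToyProblem
    u0::Vector{Float64}
    p::Vector{Float64}
end

# Hypothetical positional core. Since u0 and p are positional here, a single
# adjoint rule on this function could differentiate with respect to both.
_solve_up(prob::ToyProblem, alg, u0, p; kwargs...) = (u0 = u0, p = p, alg = alg)

# User-facing entry point: keyword overrides default to the stored values.
toy_solve(prob::ToyProblem, alg; u0 = prob.u0, p = prob.p, kwargs...) =
    _solve_up(prob, alg, u0, p; kwargs...)

prob = ToyProblem([1.0], [2.0, 3.0])

default_call = toy_solve(prob, :alg)                    # uses prob.u0 and prob.p
override_call = toy_solve(prob, :alg; p = [5.0, 6.0])   # overrides p only
```

Users keep the convenient keyword form, while the differentiable arguments pass through one positional choke point.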

@oxinabox is there something special that could help here?

ChrisRackauckas commented 4 years ago

PRs:

https://github.com/SciML/DiffEqBase.jl/pull/520
https://github.com/SciML/DiffEqSensitivity.jl/pull/261
https://github.com/SciML/DiffEqSensitivity.jl/pull/262
https://github.com/SciML/Sundials.jl/pull/265
https://github.com/SciML/OrdinaryDiffEq.jl/pull/1149

oxinabox commented 4 years ago

> I think this can be done in ChainRules by having !(SensitivityPassThrough <: AbstractSensitivityAlgorithm) and then only defining the rrule to dispatch on Union{Nothing,AbstractSensitivityAlgorithm}, since then the AD should just ignore the adjoint when it's SensitivityPassThrough (@oxinabox can you confirm?)

Sounds right to me.

> @oxinabox is there something special that could help here?

Not sure, are there particular problems you have still? Seems like you have a solution. https://github.com/JuliaDiff/ChainRulesCore.jl/issues/68 would open up some more options, but not necessarily useful ones.

ChrisRackauckas commented 4 years ago

https://github.com/SciML/OrdinaryDiffEq.jl/pull/1152

ChrisRackauckas commented 4 years ago

https://github.com/SciML/DiffEqFlux.jl/pull/273

ChrisRackauckas commented 4 years ago

> Not sure, are there particular problems you have still? Seems like you have a solution.

Yeah, I got something working. It's a bit complex but it solves all of these problems so we're good.

ChrisRackauckas commented 4 years ago

Done. concrete_solve is no more: solve does it all and is safe.

alexlenail commented 3 years ago

concrete_solve still appears in the latest docs/tutorials, e.g.

ChrisRackauckas commented 3 years ago

Thanks, fixed.