SciML / NonlinearSolve.jl

High-performance and differentiation-enabled nonlinear solvers (Newton methods), bracketed rootfinding (bisection, Falsi), with sparsity and Newton-Krylov support.
https://docs.sciml.ai/NonlinearSolve/stable/
MIT License

Adjoint overload #29

Closed. ChrisRackauckas closed this 1 year ago

ChrisRackauckas commented 3 years ago

It's just the implicit function theorem

https://github.com/mitmath/18335/blob/spring20/notes/adjoint/adjoint.pdf

There is already a generic one for SteadyStateProblem in DiffEqSensitivity, which could be extended to NonlinearProblem. https://github.com/SciML/DiffEqSensitivity.jl/pull/389 will work on anything with __solve.
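For readers following along, the implicit function theorem result referred to above can be sketched as follows. The notation is an assumption made here for illustration and is not taken from the linked notes: $f(u, p)$ is the residual, $u^*(p)$ is the root, and $g$ is a scalar function of the solution.

```latex
% Root condition, which holds for all p near the solution:
f\bigl(u^*(p),\, p\bigr) = 0
% Differentiating with respect to p (chain rule):
\frac{\partial f}{\partial u}\,\frac{d u^*}{d p} + \frac{\partial f}{\partial p} = 0
\quad\Longrightarrow\quad
\frac{d u^*}{d p} = -\left(\frac{\partial f}{\partial u}\right)^{-1} \frac{\partial f}{\partial p}
% Adjoint (reverse-mode) form for a scalar g(u^*): one linear solve with the transposed Jacobian
\left(\frac{\partial f}{\partial u}\right)^{\!T} \lambda = \left(\frac{\partial g}{\partial u}\right)^{\!T},
\qquad
\frac{d g}{d p} = -\lambda^{T}\,\frac{\partial f}{\partial p}
```

The derivative depends only on the converged root and the partial derivatives of $f$ evaluated there, so nothing needs to be differentiated through the solver iterations themselves.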

YingboMa commented 3 years ago

Currently, NonlinearSolve doesn't work with Zygote due to build_solution.

julia> using ModelingToolkit, Zygote

julia> Zygote.gradient((u, p)->ModelingToolkit.StructuralTransformations.numerical_nlsolve((u,p)->hypot(u, p)-cos(u), u, p), 0.1, 0.2)
ERROR: MethodError: no method matching ndims(::Tuple{Float64})
Closest candidates are:
  ndims(::AbstractAlgebra.MatrixElem{T} where T) at /Users/scheme/.julia/packages/AbstractAlgebra/Boo1X/src/generic/Matrix.jl:441
  ndims(::Base.Iterators.ProductIterator) at iterators.jl:967
  ndims(::Base.Generator) at generator.jl:53
  ...
Stacktrace:
  [1] build_solution(prob::NonlinearProblem{Float64, false, Float64, NonlinearFunction{false, var"#44#46", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, alg::NonlinearSolve.NewtonRaphson{12, true, DataType, NonlinearSolve.DefaultLinSolve}, u::Tuple{Float64}, resid::Float64; calculate_error::Bool, retcode::Symbol, original::Nothing, left::Nothing, right::Nothing, kwargs::Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ SciMLBase ~/src/julia/SciMLBase/src/solutions/nonlinear_solutions.jl:26
  [2] build_solution(prob::NonlinearProblem{Float64, false, Float64, NonlinearFunction{false, var"#44#46", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, alg::NonlinearSolve.NewtonRaphson{12, true, DataType, NonlinearSolve.DefaultLinSolve}, u::Tuple{Float64}, resid::Float64)
    @ SciMLBase ~/src/julia/SciMLBase/src/solutions/nonlinear_solutions.jl:25
  [3] (::DiffEqBase.var"#solu_adjoint#171"{SciMLBase.NonlinearSolution{Float64, 0, Float64, Float64, NonlinearProblem{Float64, false, Float64, NonlinearFunction{false, var"#44#46", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, NonlinearSolve.NewtonRaphson{12, true, DataType, NonlinearSolve.DefaultLinSolve}, Nothing, Nothing}})(Δ::Float64)
    @ DiffEqBase ~/src/julia/DiffEqBase/src/zygote.jl:33
  [4] (::DiffEqBase.var"#150#back#172"{DiffEqBase.var"#solu_adjoint#171"{SciMLBase.NonlinearSolution{Float64, 0, Float64, Float64, NonlinearProblem{Float64, false, Float64, NonlinearFunction{false, var"#44#46", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, NonlinearSolve.NewtonRaphson{12, true, DataType, NonlinearSolve.DefaultLinSolve}, Nothing, Nothing}}})(Δ::Float64)
    @ DiffEqBase ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59
  [5] Pullback
    @ ~/src/julia/ModelingToolkit/src/structural_transformation/utils.jl:308 [inlined]
  [6] (::typeof(∂(numerical_nlsolve)))(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface2.jl:0
  [7] Pullback
    @ ./REPL[60]:1 [inlined]
  [8] (::typeof(∂(#43)))(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface2.jl:0
  [9] (::Zygote.var"#41#42"{typeof(∂(#43))})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface.jl:41
 [10] gradient(::Function, ::Float64, ::Vararg{Float64, N} where N)
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface.jl:59
 [11] top-level scope
    @ REPL[60]:1

julia>

@DhairyaLGandhi could you take a look?

ChrisRackauckas commented 3 years ago

using DiffEqSensitivity?

YingboMa commented 3 years ago

How is DiffEqSensitivity related? This is a nonlinear system.

ChrisRackauckas commented 3 years ago

It should use the SteadyStateProblem adjoint IIRC.

ChrisRackauckas commented 3 years ago

https://github.com/SciML/DiffEqSensitivity.jl/blob/dfa7ba71909a7fcee2de327a324512b8e83b420b/src/concrete_solve.jl#L26-L30

YingboMa commented 3 years ago

I derived frule and rrule here: https://gist.github.com/YingboMa/4e4496f828c6a3179004f6d0ca224d2a

Someone just needs to write a performant implementation of it.

ChrisRackauckas commented 3 years ago

What's wrong with the current implementation of the vjp?

YingboMa commented 3 years ago

For performance reasons, the adjoint for numerical_nlsolve should just be a dozen lines of code that's non-allocating.

ChrisRackauckas commented 3 years ago

I see, so it's just the small-problem issue, so it'll need a specialized form?

YingboMa commented 3 years ago

Yeah, exactly.
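As a rough illustration of what a small, non-allocating scalar adjoint might look like, here is a minimal sketch in plain Julia using the residual from the REPL example above. It is not the code from the gist and not what NonlinearSolve or ModelingToolkit implement. The names `scalar_newton` and `scalar_newton_with_pullback` are hypothetical, and the partial derivatives are written by hand here, where a real implementation would presumably obtain them from AD.

```julia
# Residual from the thread's example: f(u, p) = hypot(u, p) - cos(u).
f(u, p) = hypot(u, p) - cos(u)

# Hand-written partial derivatives (assumption: a real implementation
# would get these from automatic differentiation instead).
df_du(u, p) = u / hypot(u, p) + sin(u)
df_dp(u, p) = p / hypot(u, p)

# Hypothetical scalar Newton iteration, for illustration only.
function scalar_newton(u0, p; tol = 1e-12, maxiters = 100)
    u = u0
    for _ in 1:maxiters
        du = f(u, p) / df_du(u, p)
        u -= du
        abs(du) < tol && break
    end
    return u
end

# Reverse rule from the implicit function theorem:
# du*/dp = -(df/dp) / (df/du), evaluated at the converged root.
# The root does not depend on the initial guess, so its cotangent is zero.
function scalar_newton_with_pullback(u0, p)
    u = scalar_newton(u0, p)
    dudp = -df_dp(u, p) / df_du(u, p)
    pullback(Δ) = (zero(Δ), Δ * dudp)
    return u, pullback
end

u_star, back = scalar_newton_with_pullback(0.1, 0.2)
u0_bar, p_bar = back(1.0)
```

Everything here operates on scalars, so neither the forward solve nor the pullback needs to allocate arrays. The pullback costs one division on top of the solve.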