ChrisRackauckas closed this issue 1 year ago.
Currently, NonlinearSolve doesn't work with Zygote due to `build_solution`.
```julia
julia> using ModelingToolkit, Zygote

julia> Zygote.gradient((u, p)->ModelingToolkit.StructuralTransformations.numerical_nlsolve((u,p)->hypot(u, p)-cos(u), u, p), 0.1, 0.2)
ERROR: MethodError: no method matching ndims(::Tuple{Float64})
Closest candidates are:
  ndims(::AbstractAlgebra.MatrixElem{T} where T) at /Users/scheme/.julia/packages/AbstractAlgebra/Boo1X/src/generic/Matrix.jl:441
  ndims(::Base.Iterators.ProductIterator) at iterators.jl:967
  ndims(::Base.Generator) at generator.jl:53
  ...
Stacktrace:
  [1] build_solution(prob::NonlinearProblem{Float64, false, Float64, NonlinearFunction{false, var"#44#46", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, alg::NonlinearSolve.NewtonRaphson{12, true, DataType, NonlinearSolve.DefaultLinSolve}, u::Tuple{Float64}, resid::Float64; calculate_error::Bool, retcode::Symbol, original::Nothing, left::Nothing, right::Nothing, kwargs::Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ SciMLBase ~/src/julia/SciMLBase/src/solutions/nonlinear_solutions.jl:26
  [2] build_solution(prob::NonlinearProblem{Float64, false, Float64, NonlinearFunction{false, var"#44#46", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, alg::NonlinearSolve.NewtonRaphson{12, true, DataType, NonlinearSolve.DefaultLinSolve}, u::Tuple{Float64}, resid::Float64)
    @ SciMLBase ~/src/julia/SciMLBase/src/solutions/nonlinear_solutions.jl:25
  [3] (::DiffEqBase.var"#solu_adjoint#171"{SciMLBase.NonlinearSolution{Float64, 0, Float64, Float64, NonlinearProblem{Float64, false, Float64, NonlinearFunction{false, var"#44#46", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, NonlinearSolve.NewtonRaphson{12, true, DataType, NonlinearSolve.DefaultLinSolve}, Nothing, Nothing}})(Δ::Float64)
    @ DiffEqBase ~/src/julia/DiffEqBase/src/zygote.jl:33
  [4] (::DiffEqBase.var"#150#back#172"{DiffEqBase.var"#solu_adjoint#171"{SciMLBase.NonlinearSolution{Float64, 0, Float64, Float64, NonlinearProblem{Float64, false, Float64, NonlinearFunction{false, var"#44#46", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, NonlinearSolve.NewtonRaphson{12, true, DataType, NonlinearSolve.DefaultLinSolve}, Nothing, Nothing}}})(Δ::Float64)
    @ DiffEqBase ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59
  [5] Pullback
    @ ~/src/julia/ModelingToolkit/src/structural_transformation/utils.jl:308 [inlined]
  [6] (::typeof(∂(numerical_nlsolve)))(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface2.jl:0
  [7] Pullback
    @ ./REPL[60]:1 [inlined]
  [8] (::typeof(∂(#43)))(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface2.jl:0
  [9] (::Zygote.var"#41#42"{typeof(∂(#43))})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface.jl:41
 [10] gradient(::Function, ::Float64, ::Vararg{Float64, N} where N)
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface.jl:59
 [11] top-level scope
    @ REPL[60]:1

julia>
```
@DhairyaLGandhi could you take a look?
`using DiffEqSensitivity`?
How is DiffEqSensitivity related? This is a nonlinear system.
It should use the SteadyStateProblem adjoint IIRC.
I derived frule and rrule here: https://gist.github.com/YingboMa/4e4496f828c6a3179004f6d0ca224d2a
Someone just needs to write a performant implementation of it.
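As a rough illustration of what such rules compute (this is not the code from the gist), here is a minimal scalar sketch in plain Julia. The names `newton_solve`, `nlsolve_frule`, and `nlsolve_rrule` are hypothetical, and the partials `f_u` and `f_p` are written by hand instead of coming from an AD package. The residual is the one from the failing call above, assuming the second argument of `numerical_nlsolve` is the initial guess.

```julia
# Hypothetical scalar sketch: u*(p) is defined by f(u*, p) = 0.
# The partials f_u = ∂f/∂u and f_p = ∂f/∂p are supplied by hand (assumption).

"Scalar Newton iteration for f(u, p) = 0, starting from u0."
function newton_solve(f, f_u, u0, p; abstol = 1e-12, maxiters = 100)
    u = u0
    for _ in 1:maxiters
        r = f(u, p)
        abs(r) < abstol && return u
        u -= r / f_u(u, p)
    end
    return u
end

"Forward mode: tangent of u* for a tangent dp, du = -f_p * dp / f_u."
function nlsolve_frule(f, f_u, f_p, u0, p, dp)
    u = newton_solve(f, f_u, u0, p)
    du = -f_p(u, p) * dp / f_u(u, p)
    return u, du
end

"Reverse mode: returns u* and a pullback mapping ubar to (ubar0, pbar)."
function nlsolve_rrule(f, f_u, f_p, u0, p)
    u = newton_solve(f, f_u, u0, p)
    # A converged root does not depend on the initial guess, so ubar0 = 0.
    pullback(ubar) = (zero(u0), -ubar * f_p(u, p) / f_u(u, p))
    return u, pullback
end

# Residual from the issue: hypot(u, p) - cos(u), with hand-written partials.
f(u, p) = hypot(u, p) - cos(u)
f_u(u, p) = u / hypot(u, p) + sin(u)
f_p(u, p) = p / hypot(u, p)

u, back = nlsolve_rrule(f, f_u, f_p, 0.1, 0.2)
ubar0, pbar = back(1.0)
_, du = nlsolve_frule(f, f_u, f_p, 0.1, 0.2, 1.0)
```

In a real package the reverse rule would be registered as a `ChainRulesCore.rrule` so that Zygote picks it up; that wiring is left out of this sketch.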
What's wrong with the current implementation of the vjp?
For performance reasons, the adjoint for `numerical_nlsolve` should be just a dozen lines of non-allocating code.
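For a small system the vjp reduces to one transposed linear solve: solve `J_u' λ = ubar`, then `pbar = -J_p' λ`. Below is a sketch on a hypothetical two-variable residual `F` with hand-written Jacobians, using only the `LinearAlgebra` stdlib; `newton_system` and `system_pullback` are illustrative names, not existing APIs. This dense version allocates small arrays on every call, so a genuinely non-allocating one would need fixed-size arrays (e.g. StaticArrays.jl) or preallocated buffers, which the sketch does not attempt.

```julia
using LinearAlgebra

# Hypothetical residual F(u, p) = 0 with u, p in R^2 and hand-written Jacobians.
F(u, p) = [2 * u[1] + sin(u[2]) - p[1], 2 * u[2] + sin(u[1]) - p[2]]
J_u(u, p) = [2.0 cos(u[2]); cos(u[1]) 2.0]
J_p(u, p) = -Matrix{Float64}(I, 2, 2)

"Newton iteration for the system F(u, p) = 0."
function newton_system(F, J_u, u0, p; abstol = 1e-12, maxiters = 50)
    u = copy(u0)
    for _ in 1:maxiters
        r = F(u, p)
        norm(r) < abstol && return u
        u -= J_u(u, p) \ r
    end
    return u
end

# Implicit-function-theorem adjoint: one solve with the transposed Jacobian.
function system_pullback(J_u, J_p, u, p, ubar)
    λ = J_u(u, p)' \ ubar
    return -(J_p(u, p)' * λ)
end

p = [1.0, 0.5]
u = newton_system(F, J_u, zeros(2), p)
ubar = [1.0, 0.0]  # seed: gradient of u[1] with respect to p
pbar = system_pullback(J_u, J_p, u, p, ubar)
```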
I see, so it's just the small-problem issue, and it'll need a specialized form?
Yeah, exactly.
It's just the implicit function theorem:
https://github.com/mitmath/18335/blob/spring20/notes/adjoint/adjoint.pdf
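Applied to the scalar example from the issue, the theorem gives the derivative of the root directly. This is a worked sketch assuming the solver has converged to a root $u^*$ with $\partial f/\partial u \neq 0$:

```math
f(u^*(p), p) = 0
\;\Rightarrow\;
\frac{\partial f}{\partial u}\frac{du^*}{dp} + \frac{\partial f}{\partial p} = 0
\;\Rightarrow\;
\frac{du^*}{dp} = -\left(\frac{\partial f}{\partial u}\right)^{-1}\frac{\partial f}{\partial p}

f(u, p) = \sqrt{u^2 + p^2} - \cos u,
\qquad
\frac{\partial f}{\partial u} = \frac{u}{\sqrt{u^2 + p^2}} + \sin u,
\qquad
\frac{\partial f}{\partial p} = \frac{p}{\sqrt{u^2 + p^2}}

\frac{du^*}{dp} = -\frac{p}{u^* + \sqrt{u^{*2} + p^2}\,\sin u^*}
```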
There is also a generic one for `SteadyStateProblem` in DiffEqSensitivity (which could be extended to `NonlinearProblem`), and https://github.com/SciML/DiffEqSensitivity.jl/pull/389 will work on anything with `__solve`.