FluxML / Zygote.jl

21st century AD
https://fluxml.ai/Zygote.jl/

Error when differentiating solutions of ModelingToolkit models #1325

Open · Antomek opened this issue 1 year ago

Antomek commented 1 year ago

Package Version

Zygote v0.6.49, ModelingToolkit v8.29.1

Julia Version

Julia Version 1.8.2

OS / Environment

macOS Monterey 12.6

Describe the bug

Recently, the scripts I was using to compute derivatives of ODE solutions from models built with ModelingToolkit.jl started failing with the error:

ERROR: LoadError: Compiling Tuple{Type{Dict}, Base.Iterators.Zip{Tuple{Vector{Sym{Real, Base.ImmutableDict{DataType, Any}}}, Vector{Float64}}}}: try/catch is not supported.
Refer to the Zygote documentation for fixes.
https://fluxml.ai/Zygote.jl/latest/limitations

I imagine this is related to this issue.

Steps to Reproduce

MWE:

using DifferentialEquations, ModelingToolkit

@variables t          # independent variable (time)
D = Differential(t)   # time-derivative operator used in the equations

function lotka_volterra(; name)
    states = @variables x(t)=1.0 y(t)=1.0
    params = @parameters p1=1.5 p2=1.0 p3=3.0 p4=1.0

    eqs = [
        D(x) ~ p1 * x - p2 * x * y,
        D(y) ~ -p3 * y + p4 * x * y
    ]

    return ODESystem(eqs, t, states, params; name = name)
end

@named lotka_volterra_sys = lotka_volterra()

# Empty u0/p maps: fall back to the defaults declared in the system
prob = ODEProblem(lotka_volterra_sys, [], (0.0, 10.0), [])
sol = solve(prob, Tsit5(), reltol = 1e-6, abstol = 1e-6)

using Zygote, SciMLSensitivity

# Scalar loss: sum of the solution saved every 0.1 time units
function sum_of_solution(u0, p)
    _prob = remake(prob, u0 = u0, p = p)
    sum(solve(_prob, Tsit5(), reltol = 1e-6, abstol = 1e-6, saveat = 0.1,
              sensealg = BacksolveAdjoint(autojacvec = ZygoteVJP())))
end

u0 = [1.0 1.0]
p = [1.5 1. 1. 1.]
du01, dp1 = Zygote.gradient(sum_of_solution, u0, p)
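A side note on the inputs above: in Julia, space-separated array literals such as `[1.0 1.0]` build 1×N row matrices rather than vectors, which is why the stacktrace reports `Matrix{Float64}` for `u0` and `p`, while the state of `prob` itself is a `Vector{Float64}`. Comma-separated literals (or `vec`) give vectors. This sketch only illustrates the shapes; whether passing vectors changes the `try/catch` error is not established here.

```julia
# Space-separated literals build 1×N row matrices
u0_row = [1.0 1.0]
p_row = [1.5 1. 1. 1.]

# Comma-separated literals build column vectors
u0_vec = [1.0, 1.0]
p_vec = [1.5, 1.0, 1.0, 1.0]

println(typeof(u0_row), " ", size(u0_row))  # Matrix{Float64} (1, 2)
println(typeof(u0_vec), " ", size(u0_vec))  # Vector{Float64} (2,)

# vec reshapes a row matrix into a vector with the same entries
@assert vec(p_row) == p_vec
```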

Expected Results

I expected to get the gradient of the summed solution with respect to u0 and p.

Observed Results

I get the following error:

ERROR: Compiling Tuple{Type{Dict}, Base.Iterators.Zip{Tuple{Vector{Sym{Real, Base.ImmutableDict{DataType, Any}}}, Vector{Float64}}}}: try/catch is not supported.
Refer to the Zygote documentation for fixes.
https://fluxml.ai/Zygote.jl/latest/limitations

Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:35
  [2] instrument(ir::IRTools.Inner.IR)
    @ Zygote ~/.julia/packages/Zygote/dABKa/src/compiler/reverse.jl:121
  [3] #Primal#23
    @ ~/.julia/packages/Zygote/dABKa/src/compiler/reverse.jl:205 [inlined]
  [4] Zygote.Adjoint(ir::IRTools.Inner.IR; varargs::Nothing, normalise::Bool)
    @ Zygote ~/.julia/packages/Zygote/dABKa/src/compiler/reverse.jl:330
  [5] _generate_pullback_via_decomposition(T::Type)
    @ Zygote ~/.julia/packages/Zygote/dABKa/src/compiler/emit.jl:101
  [6] #s2924#1068
    @ ~/.julia/packages/Zygote/dABKa/src/compiler/interface2.jl:28 [inlined]
  [7] var"#s2924#1068"(::Any, ctx::Any, f::Any, args::Any)
    @ Zygote ./none:0
  [8] (::Core.GeneratedFunctionStub)(::Any, ::Vararg{Any})
    @ Core ./boot.jl:582
  [9] _pullback
    @ ~/.julia/packages/SciMLBase/m11uN/src/utils.jl:477 [inlined]
 [10] _pullback(::Zygote.Context{false}, ::typeof(SciMLBase.mergedefaults), ::Dict{Any, Any}, ::Vector{Float64}, ::Vector{Sym{Real, Base.ImmutableDict{DataType, Any}}})
    @ Zygote ~/.julia/packages/Zygote/dABKa/src/compiler/interface2.jl:0
 [11] _pullback
    @ ~/.julia/packages/SciMLBase/m11uN/src/remake.jl:57 [inlined]
 [12] _pullback(::Zygote.Context{false}, ::SciMLBase.var"##remake#527", ::Missing, ::Matrix{Float64}, ::Missing, ::Matrix{Float64}, ::Missing, ::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, ::typeof(remake), ::ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, true, Vector{Float64}, ODEFunction{true, SciMLBase.AutoSpecialize, ModelingToolkit.var"#f#465"{RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0x7fe24678, 0xefbf7ae3, 0x14077d65, 0xd38b0358, 0xca1226cf)}, RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋out, :ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0x2b3c6fd1, 0xac2f72a0, 0xdcafd855, 0x30fb2acf, 0x61128de0)}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Vector{Symbol}, Symbol, Vector{Symbol}, ModelingToolkit.var"#488#generated_observed#472"{Bool, ODESystem, Dict{Any, Any}}, Nothing, ODESystem}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem})
    @ Zygote ~/.julia/packages/Zygote/dABKa/src/compiler/interface2.jl:0
 [13] _pullback
    @ ~/.julia/packages/SciMLBase/m11uN/src/remake.jl:45 [inlined]
 [14] _pullback(::Zygote.Context{false}, ::SciMLBase.var"#remake##kw", ::NamedTuple{(:u0, :p), Tuple{Matrix{Float64}, Matrix{Float64}}}, ::typeof(remake), ::ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, true, Vector{Float64}, ODEFunction{true, SciMLBase.AutoSpecialize, ModelingToolkit.var"#f#465"{RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0x7fe24678, 0xefbf7ae3, 0x14077d65, 0xd38b0358, 0xca1226cf)}, RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋out, :ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0x2b3c6fd1, 0xac2f72a0, 0xdcafd855, 0x30fb2acf, 0x61128de0)}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Vector{Symbol}, Symbol, Vector{Symbol}, ModelingToolkit.var"#488#generated_observed#472"{Bool, ODESystem, Dict{Any, Any}}, Nothing, ODESystem}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem})
    @ Zygote ~/.julia/packages/Zygote/dABKa/src/compiler/interface2.jl:0
 [15] _pullback
    @ ./REPL[15]:2 [inlined]
 [16] _pullback(::Zygote.Context{false}, ::typeof(sum_of_solution), ::Matrix{Float64}, ::Matrix{Float64})
    @ Zygote ~/.julia/packages/Zygote/dABKa/src/compiler/interface2.jl:0
 [17] pullback(::Function, ::Zygote.Context{false}, ::Matrix{Float64}, ::Vararg{Matrix{Float64}})
    @ Zygote ~/.julia/packages/Zygote/dABKa/src/compiler/interface.jl:44
 [18] pullback(::Function, ::Matrix{Float64}, ::Matrix{Float64})
    @ Zygote ~/.julia/packages/Zygote/dABKa/src/compiler/interface.jl:42
 [19] gradient(::Function, ::Matrix{Float64}, ::Vararg{Matrix{Float64}})
    @ Zygote ~/.julia/packages/Zygote/dABKa/src/compiler/interface.jl:96
 [20] top-level scope
    @ REPL[18]:1

Relevant log output

No response

ToucheSir commented 1 year ago

Looks like this was caused by a change on the SciML side: https://github.com/SciML/SciMLBase.jl/commit/62494689ad8f198fb75b7170aeb56f837b4590e3. You could ask them to implement the suggestion in https://github.com/FluxML/Zygote.jl/issues/1293#issuecomment-1243051361 and see if that works.
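Per the stacktrace, the failing call is a `Dict` constructor applied to a `zip` iterator (Base's `Dict` constructor uses `try/catch` internally), reached from `SciMLBase.mergedefaults` inside `remake`. One general way to get Zygote past a `try/catch` in code you don't control is to give it a ChainRulesCore `rrule` for the offending function, so Zygote calls the rule instead of compiling the body. The sketch below shows that technique on a hypothetical toy function, `double_or_zero`; it is not SciML code and not necessarily what the linked comment proposes.

```julia
using Zygote, ChainRulesCore

# Hypothetical toy function whose body uses try/catch.
# With Zygote v0.6.49 (the version in this report), compiling a
# pullback for a body like this fails with "try/catch is not supported".
function double_or_zero(x)
    try
        return 2.0 * x
    catch
        return 0.0
    end
end

# Custom reverse rule: Zygote dispatches to this instead of tracing the body.
function ChainRulesCore.rrule(::typeof(double_or_zero), x)
    y = double_or_zero(x)
    # d(2x)/dx = 2; NoTangent() is the tangent for the function object itself
    double_or_zero_pullback(ȳ) = (NoTangent(), 2.0 * ȳ)
    return y, double_or_zero_pullback
end

grad = Zygote.gradient(double_or_zero, 3.0)
println(grad)  # (2.0,)
```

For the case in this issue, such a rule would have to target the SciML function that builds the `Dict`, which is why the change is best made on the SciML side.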

cefitzg commented 1 month ago

@Antomek, did you find a way around this? I need to compute the Zygote.hessian of a loss function involving forward simulations of an ODE system defined with ModelingToolkit, and I ran into the same error.