SciML / SciMLSensitivity.jl

A component of the DiffEq ecosystem for enabling sensitivity analysis for scientific machine learning (SciML). Optimize-then-discretize, discretize-then-optimize, adjoint methods, and more for ODEs, SDEs, DDEs, DAEs, etc.
https://docs.sciml.ai/SciMLSensitivity/stable/

[FR] Gradient calculation does not work for integrator interface with `init` and `SavingCallback`. #613

Open JinraeKim opened 2 years ago

JinraeKim commented 2 years ago

Hi, my goal is described in SciML/DiffEqFlux.jl#661. I tried to narrow the whole problem down to a small example to determine whether it is possible to backpropagate through the integrator interface.

The following example is from this commit (branch `iss11` of SimulationLogger.jl). I found that Zygote.jl's `gradient` fails when the function builds an integrator with `init` and returns `integrator.p`, whereas the gradient can be taken if the output is changed to `prob.p` (where `typeof(prob) <: ODEProblem`).

What would be the best workaround for this case? If this requires a new feature, I would like to request it, since it would make DiffEqFlux.jl more flexible to use.
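A possible workaround, offered as a sketch rather than a confirmed fix: the stack traces in this issue show Zygote trying to differentiate the internals of `init` itself, whereas SciMLSensitivity.jl supplies adjoints for `solve`. One option is therefore to build the loss from `solve` and recompute any logged quantity from the saved states, instead of collecting it through a mutating `SavingCallback`. The sketch assumes the logged quantity depends only on the state, wraps `p` in a one-element vector (an assumption, since SciMLSensitivity's adjoints are typically used with array parameters), and uses a hypothetical `loss` function; it has not been tried with SimulationLogger.jl.

```julia
using OrdinaryDiffEq, SciMLSensitivity, Zygote

# Same dynamics as in the report; p is a one-element vector here (assumption).
function dynamics!(dx, x, p, t)
    dx .= -p[1] .* x
end

# Hypothetical loss: solve the ODE and reduce the saved states.
function loss(p)
    x0 = [1.0, 2.0, 3.0]
    prob = ODEProblem(dynamics!, x0, (0.0, 1.0), p)
    # Differentiate through `solve`, which has an adjoint in SciMLSensitivity,
    # instead of through an integrator created with `init`.
    sol = solve(prob, Tsit5(); saveat = 0.1, abstol = 1e-10, reltol = 1e-10)
    # Columns of Array(sol) are the states at t = 0.0, 0.1, ..., 1.0; a logged
    # quantity that is a function of the state can be computed from them here
    # instead of being stored by a SavingCallback.
    return sum(Array(sol))
end

g = Zygote.gradient(loss, [2.0])[1]  # one-element vector d(loss)/dp
```

If the default sensitivity algorithm chosen by `solve` does not work for a given setup, passing one explicitly (for example `sensealg = InterpolatingAdjoint()`) is another option to try.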

```julia
using OrdinaryDiffEq, DiffEqBase, Zygote
using SimulationLogger  # @Loggable, @log, __LOG_INDICATOR__ (branch iss11)

p = 2.0
function main(p)
    @Loggable function dynamics!(dx, x, p, t)
        @log x
        dx .= -p*x
    end
    # if hasmethod(dynamics!, Tuple{Any, Any, Any, Any, __LOG_INDICATOR__})
        # to avoid undefined error when not adding @Loggable
        log_func(x, t, integrator::DiffEqBase.DEIntegrator; kwargs...) = feedback_dynamics!(zero.(x), copy(x), integrator.p, t, __LOG_INDICATOR__(); kwargs...)
        x = [1, 2, 3.0]
        t = 0.0
        tspan = (0.0, 1.0)
        prob = ODEProblem(dynamics!, x, tspan, p)
        integrator = init(prob, Tsit5())
        @show integrator.p
        integrator.p  # returned value; differentiating this fails
        # @show result = log_func(x, t, integrator)
        # result.x |> sum
    # end
end
# @run main(p)  # optional: step through with Debugger.jl or the VS Code debugger
@show gradient(main, p)
```


Result:
```julia
julia> include("test/log_func.jl")
integrator.p = 2.0
ERROR: LoadError: MethodError: no method matching haskey(::IRTools.Inner.Undefined, ::Symbol)
Closest candidates are:
  haskey(::Union{Tables.AbstractColumns, Tables.AbstractRow}, ::Symbol) at ~/.julia/packages/Tables/M26tI/src/Tables.jl:186
  haskey(::DataStructures.SortedMultiDict, ::Any) at ~/.julia/packages/DataStructures/vSp4s/src/sorted_multi_dict.jl:328
  haskey(::DataStructures.OrderedRobinDict, ::Any) at ~/.julia/packages/DataStructures/vSp4s/src/ordered_robin_dict.jl:305
  ...
Stacktrace:
  [1] rrule(#unused#::typeof(haskey), 855::IRTools.Inner.Undefined, 856::Symbol)
    @ ChainRules ~/.julia/packages/ChainRules/kkDLd/src/rulesets/Base/nondiff.jl:185
  [2] rrule(::Zygote.ZygoteRuleConfig{Zygote.Context}, ::Function, ::IRTools.Inner.Undefined, ::Symbol)
    @ ChainRulesCore ~/.julia/packages/ChainRulesCore/sHMAp/src/rules.jl:134
  [3] chain_rrule
    @ ~/.julia/packages/Zygote/bJn8I/src/compiler/chainrules.jl:216 [inlined]
  [4] macro expansion
    @ ~/.julia/packages/Zygote/bJn8I/src/compiler/interface2.jl:0 [inlined]
  [5] _pullback(::Zygote.Context, ::typeof(haskey), ::IRTools.Inner.Undefined, ::Symbol)
    @ Zygote ~/.julia/packages/Zygote/bJn8I/src/compiler/interface2.jl:9
  [6] macro expansion
    @ ~/.julia/dev/SimulationLogger/src/macros.jl:123 [inlined]
  [7] _pullback
    @ ~/.julia/dev/SimulationLogger/test/log_func.jl:10 [inlined]
  [8] _pullback(::Zygote.Context, ::var"#dynamics!#90", ::Vector{Float64}, ::Vector{Float64}, ::Float64, ::Float64)
    @ Zygote ~/.julia/packages/Zygote/bJn8I/src/compiler/interface2.jl:0
  [9] _pullback
    @ ~/.julia/packages/Zygote/bJn8I/src/compiler/interface.jl:34 [inlined]
 [10] adjoint
    @ ~/.julia/packages/DiffEqBase/8q10H/src/chainrules.jl:134 [inlined]
 [11] _pullback
    @ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:65 [inlined]
 [12] _pullback
    @ ~/.julia/packages/OrdinaryDiffEq/WD8cC/src/perform_step/low_order_rk_perform_step.jl:627 [inlined]
 [13] _pullback(::Zygote.Context, ::typeof(initialize!), ::OrdinaryDiffEq.ODEIntegrator{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, true, Vector{Float64}, Nothing, Float64, Float64, Float64, Float64, Float64, Float64, Vector{Vector{Float64}}, ODESolution{Float64, 2, Vector{Vector{Float64}}, Nothing, Nothing, Vector{Float64}, Vector{Vector{Vector{Float64}}}, ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, true, Float64, ODEFunction{true, var"#dynamics!#90", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, OrdinaryDiffEq.InterpolationData{ODEFunction{true, var"#dynamics!#90", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Vector{Vector{Float64}}, Vector{Float64}, Vector{Vector{Vector{Float64}}}, OrdinaryDiffEq.Tsit5Cache{Vector{Float64}, Vector{Float64}, Vector{Float64}, OrdinaryDiffEq.Tsit5ConstantCache{Float64, Float64}, typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}}, DiffEqBase.DEStats}, ODEFunction{true, var"#dynamics!#90", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, OrdinaryDiffEq.Tsit5Cache{Vector{Float64}, Vector{Float64}, Vector{Float64}, OrdinaryDiffEq.Tsit5ConstantCache{Float64, Float64}, typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, OrdinaryDiffEq.DEOptions{Float64, Float64, 
Float64, Float64, PIController{Rational{Int64}}, typeof(DiffEqBase.ODE_DEFAULT_NORM), typeof(LinearAlgebra.opnorm), Nothing, CallbackSet{Tuple{}, Tuple{}}, typeof(DiffEqBase.ODE_DEFAULT_ISOUTOFDOMAIN), typeof(DiffEqBase.ODE_DEFAULT_PROG_MESSAGE), typeof(DiffEqBase.ODE_DEFAULT_UNSTABLE_CHECK), DataStructures.BinaryHeap{Float64, DataStructures.FasterForward}, DataStructures.BinaryHeap{Float64, DataStructures.FasterForward}, Nothing, Nothing, Int64, Tuple{}, Tuple{}, Tuple{}}, Vector{Float64}, Float64, Nothing, OrdinaryDiffEq.DefaultInit}, ::OrdinaryDiffEq.Tsit5Cache{Vector{Float64}, Vector{Float64}, Vector{Float64}, OrdinaryDiffEq.Tsit5ConstantCache{Float64, Float64}, typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False})
    @ Zygote ~/.julia/packages/Zygote/bJn8I/src/compiler/interface2.jl:0
 [14] _pullback
    @ ~/.julia/packages/OrdinaryDiffEq/WD8cC/src/solve.jl:456 [inlined]
 [15] _pullback(::Zygote.Context, ::OrdinaryDiffEq.var"##__init#501", ::Tuple{}, ::Tuple{}, ::Tuple{}, ::Nothing, ::Bool, ::Bool, ::Bool, ::Nothing, ::Nothing, ::Bool, ::Bool, ::Float64, ::Nothing, ::Float64, ::Bool, ::Bool, ::Rational{Int64}, ::Nothing, ::Nothing, ::Rational{Int64}, ::Int64, ::Int64, ::Int64, ::Nothing, ::Nothing, ::Rational{Int64}, ::Nothing, ::Bool, ::Int64, ::Int64, ::typeof(DiffEqBase.ODE_DEFAULT_NORM), ::typeof(LinearAlgebra.opnorm), ::typeof(DiffEqBase.ODE_DEFAULT_ISOUTOFDOMAIN), ::typeof(DiffEqBase.ODE_DEFAULT_UNSTABLE_CHECK), ::Bool, ::Bool, ::Bool, ::Bool, ::Bool, ::Bool, ::Bool, ::Int64, ::String, ::typeof(DiffEqBase.ODE_DEFAULT_PROG_MESSAGE), ::Nothing, ::Bool, ::Bool, ::Bool, ::Bool, ::OrdinaryDiffEq.DefaultInit, ::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, ::typeof(SciMLBase.__init), ::ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, true, Float64, ODEFunction{true, var"#dynamics!#90", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, ::Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, ::Tuple{}, ::Tuple{}, ::Tuple{}, ::Type{Val{true}})
    @ Zygote ~/.julia/packages/Zygote/bJn8I/src/compiler/interface2.jl:0
 [16] _pullback
    @ ~/.julia/packages/OrdinaryDiffEq/WD8cC/src/solve.jl:67 [inlined]
 [17] _pullback(::Zygote.Context, ::typeof(SciMLBase.__init), ::ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, true, Float64, ODEFunction{true, var"#dynamics!#90", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, ::Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, ::Tuple{}, ::Tuple{}, ::Tuple{}, ::Type{Val{true}})
    @ Zygote ~/.julia/packages/Zygote/bJn8I/src/compiler/interface2.jl:0
 [18] _pullback (repeats 4 times)
    @ ~/.julia/packages/OrdinaryDiffEq/WD8cC/src/solve.jl:67 [inlined]
 [19] _apply
    @ ./boot.jl:814 [inlined]
 [20] adjoint
    @ ~/.julia/packages/Zygote/bJn8I/src/lib/lib.jl:200 [inlined]
 [21] _pullback
    @ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:65 [inlined]
 [22] _pullback
    @ ~/.julia/packages/DiffEqBase/8q10H/src/solve.jl:28 [inlined]
 [23] _pullback(::Zygote.Context, ::DiffEqBase.var"##init_call#35", ::Bool, ::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, ::typeof(DiffEqBase.init_call), ::ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, true, Float64, ODEFunction{true, var"#dynamics!#90", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, ::Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False})
    @ Zygote ~/.julia/packages/Zygote/bJn8I/src/compiler/interface2.jl:0
 [24] _apply(::Function, ::Vararg{Any})
    @ Core ./boot.jl:814
 [25] adjoint
    @ ~/.julia/packages/Zygote/bJn8I/src/lib/lib.jl:200 [inlined]
 [26] _pullback
    @ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:65 [inlined]
 [27] _pullback
    @ ~/.julia/packages/DiffEqBase/8q10H/src/solve.jl:15 [inlined]
 [28] _pullback(::Zygote.Context, ::typeof(DiffEqBase.init_call), ::ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, true, Float64, ODEFunction{true, var"#dynamics!#90", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, ::Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False})
    @ Zygote ~/.julia/packages/Zygote/bJn8I/src/compiler/interface2.jl:0
 [29] _apply(::Function, ::Vararg{Any})
    @ Core ./boot.jl:814
 [30] adjoint
    @ ~/.julia/packages/Zygote/bJn8I/src/lib/lib.jl:200 [inlined]
 [31] _pullback
    @ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:65 [inlined]
 [32] _pullback
    @ ~/.julia/packages/DiffEqBase/8q10H/src/solve.jl:40 [inlined]
 [33] _pullback(::Zygote.Context, ::DiffEqBase.var"##init#36", ::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, ::typeof(init), ::ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, true, Float64, ODEFunction{true, var"#dynamics!#90", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, ::Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False})
    @ Zygote ~/.julia/packages/Zygote/bJn8I/src/compiler/interface2.jl:0
 [34] _apply(::Function, ::Vararg{Any})
    @ Core ./boot.jl:814
 [35] adjoint
    @ ~/.julia/packages/Zygote/bJn8I/src/lib/lib.jl:200 [inlined]
 [36] _pullback
    @ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:65 [inlined]
 [37] _pullback
    @ ~/.julia/packages/DiffEqBase/8q10H/src/solve.jl:33 [inlined]
 [38] _pullback(::Zygote.Context, ::typeof(init), ::ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, true, Float64, ODEFunction{true, var"#dynamics!#90", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, ::Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False})
    @ Zygote ~/.julia/packages/Zygote/bJn8I/src/compiler/interface2.jl:0
 [39] _pullback
    @ ~/.julia/dev/SimulationLogger/test/log_func.jl:20 [inlined]
 [40] _pullback(ctx::Zygote.Context, f::typeof(main), args::Float64)
    @ Zygote ~/.julia/packages/Zygote/bJn8I/src/compiler/interface2.jl:0
 [41] _pullback(f::Function, args::Float64)
    @ Zygote ~/.julia/packages/Zygote/bJn8I/src/compiler/interface.jl:34
 [42] pullback(f::Function, args::Float64)
    @ Zygote ~/.julia/packages/Zygote/bJn8I/src/compiler/interface.jl:40
 [43] gradient(f::Function, args::Float64)
    @ Zygote ~/.julia/packages/Zygote/bJn8I/src/compiler/interface.jl:75
 [44] top-level scope
    @ show.jl:1047
 [45] include(fname::String)
    @ Base.MainInclude ./client.jl:451
 [46] top-level scope
    @ REPL[5]:1
in expression starting at /Users/jinrae/.julia/dev/SimulationLogger/test/log_func.jl:28
```
JinraeKim commented 2 years ago

A different example, which returns `prob.p` instead of `integrator.p`, works:

```julia
using OrdinaryDiffEq, DiffEqBase, Zygote
using SimulationLogger  # @Loggable, @log, __LOG_INDICATOR__ (branch iss11)

p = 2.0
function main(p)
    @Loggable function dynamics!(dx, x, p, t)
        @log x
        dx .= -p*x
    end
    # if hasmethod(dynamics!, Tuple{Any, Any, Any, Any, __LOG_INDICATOR__})
        # to avoid undefined error when not adding @Loggable
        log_func(x, t, integrator::DiffEqBase.DEIntegrator; kwargs...) = feedback_dynamics!(zero.(x), copy(x), integrator.p, t, __LOG_INDICATOR__(); kwargs...)
        x = [1, 2, 3.0]
        t = 0.0
        tspan = (0.0, 1.0)
        prob = ODEProblem(dynamics!, x, tspan, p)
        prob.p  # returning prob.p instead of integrator.p: the gradient works
        # integrator = init(prob, Tsit5())
        # @show integrator.p
        # integrator.p
        # @show result = log_func(x, t, integrator)
        # result.x |> sum
    # end
end
# @run main(p)  # optional: step through with Debugger.jl or the VS Code debugger
@show gradient(main, p)
```


Result:
```julia
julia> include("test/log_func.jl")
gradient(main, p) = (1.0,)
(1.0,)
```
JinraeKim commented 2 years ago

I suspected that the error above might be caused by my custom macros (e.g., `@log`), so I rewrote the test code without them, and it gave me a different error.

```julia
using OrdinaryDiffEq, Zygote

p = 2.0
function main(p)
    function dynamics!(dx, x, p, t)
        # @log x
        dx .= -p*x
    end
    # if hasmethod(dynamics!, Tuple{Any, Any, Any, Any, __LOG_INDICATOR__})
        # to avoid undefined error when not adding @Loggable
        # log_func(x, t, integrator::DiffEqBase.DEIntegrator; kwargs...) = feedback_dynamics!(zero.(x), copy(x), integrator.p, t, __LOG_INDICATOR__(); kwargs...)
        x = [1, 2, 3.0]
        t = 0.0
        tspan = (0.0, 1.0)
        prob = ODEProblem(dynamics!, x, tspan, p)
        integrator = init(prob, Tsit5())
        @show integrator.p
        integrator.p
        # @show result = log_func(x, t, integrator)
        # result.x |> sum
    # end
end
# @run main(p)  # optional: step through with Debugger.jl or the VS Code debugger
@show gradient(main, p)
```


Result:
```julia
julia> includet("test/log_func.jl")
integrator.p = 2.0
ERROR: MethodError: Cannot `convert` an object of type Zygote.CompileError to an object of type Exception
Closest candidates are:
  convert(::Type{T}, ::T) where T at /Applications/Julia-1.7.app/Contents/Resources/julia/share/julia/base/essentials.jl:218
Stacktrace:
 [1] includet(mod::Module, file::String)
   @ Revise ~/.julia/packages/Revise/WHZdV/src/packagedef.jl:1018
 [2] includet(file::String)
   @ Revise ~/.julia/packages/Revise/WHZdV/src/packagedef.jl:1023
 [3] top-level scope
   @ REPL[4]:1

caused by: MethodError: Cannot `convert` an object of type Zygote.CompileError to an object of type Exception
Closest candidates are:
  convert(::Type{T}, ::T) where T at /Applications/Julia-1.7.app/Contents/Resources/julia/share/julia/base/essentials.jl:218
Stacktrace:
 [1] Revise.ReviseEvalException(loc::String, exc::Zygote.CompileError, stacktrace::Vector{Any})
   @ Revise ~/.julia/packages/Revise/WHZdV/src/types.jl:236
 [2] parse_source!(mod_exprs_sigs::OrderedCollections.OrderedDict{Module, OrderedCollections.OrderedDict{Revise.RelocatableExpr, Union{Nothing, Vector{Any}}}}, src::String, filename::String, mod::Module; mode::Symbol)
   @ Revise ~/.julia/packages/Revise/WHZdV/src/parsing.jl:55
 [3] parse_source!(mod_exprs_sigs::OrderedCollections.OrderedDict{Module, OrderedCollections.OrderedDict{Revise.RelocatableExpr, Union{Nothing, Vector{Any}}}}, filename::String, mod::Module; kwargs::Base.Pairs{Symbol, Symbol, Tuple{Symbol}, NamedTuple{(:mode,), Tuple{Symbol}}})
   @ Revise ~/.julia/packages/Revise/WHZdV/src/parsing.jl:27
 [4] #parse_source#11
   @ ~/.julia/packages/Revise/WHZdV/src/parsing.jl:10 [inlined]
 [5] track(mod::Module, file::String; mode::Symbol, kwargs::Base.Pairs{Symbol, Bool, Tuple{Symbol}, NamedTuple{(:skip_include,), Tuple{Bool}}})
   @ Revise ~/.julia/packages/Revise/WHZdV/src/packagedef.jl:901
 [6] includet(mod::Module, file::String)
   @ Revise ~/.julia/packages/Revise/WHZdV/src/packagedef.jl:1001
 [7] includet(file::String)
   @ Revise ~/.julia/packages/Revise/WHZdV/src/packagedef.jl:1023
 [8] top-level scope
   @ REPL[4]:1

caused by: Compiling Tuple{typeof(OrdinaryDiffEq.ode_determine_initdt), Vector{Float64}, Float64, Float64, Float64, Float64, Float64, typeof(DiffEqBase.ODE_DEFAULT_NORM), ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, true, Float64, ODEFunction{true, var"#dynamics!#102", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, OrdinaryDiffEq.ODEIntegrator{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, true, Vector{Float64}, Nothing, Float64, Float64, Float64, Float64, Float64, Float64, Vector{Vector{Float64}}, ODESolution{Float64, 2, Vector{Vector{Float64}}, Nothing, Nothing, Vector{Float64}, Vector{Vector{Vector{Float64}}}, ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, true, Float64, ODEFunction{true, var"#dynamics!#102", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, OrdinaryDiffEq.InterpolationData{ODEFunction{true, var"#dynamics!#102", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Vector{Vector{Float64}}, Vector{Float64}, Vector{Vector{Vector{Float64}}}, OrdinaryDiffEq.Tsit5Cache{Vector{Float64}, Vector{Float64}, Vector{Float64}, OrdinaryDiffEq.Tsit5ConstantCache{Float64, Float64}, typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}}, 
DiffEqBase.DEStats}, ODEFunction{true, var"#dynamics!#102", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, OrdinaryDiffEq.Tsit5Cache{Vector{Float64}, Vector{Float64}, Vector{Float64}, OrdinaryDiffEq.Tsit5ConstantCache{Float64, Float64}, typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, OrdinaryDiffEq.DEOptions{Float64, Float64, Float64, Float64, PIController{Rational{Int64}}, typeof(DiffEqBase.ODE_DEFAULT_NORM), typeof(LinearAlgebra.opnorm), Nothing, CallbackSet{Tuple{}, Tuple{}}, typeof(DiffEqBase.ODE_DEFAULT_ISOUTOFDOMAIN), typeof(DiffEqBase.ODE_DEFAULT_PROG_MESSAGE), typeof(DiffEqBase.ODE_DEFAULT_UNSTABLE_CHECK), DataStructures.BinaryHeap{Float64, DataStructures.FasterForward}, DataStructures.BinaryHeap{Float64, DataStructures.FasterForward}, Nothing, Nothing, Int64, Tuple{}, Tuple{}, Tuple{}}, Vector{Float64}, Float64, Nothing, OrdinaryDiffEq.DefaultInit}}: promotion of types Static.False and Int64 failed to change any arguments
Stacktrace:
  [1] error(::String, ::String, ::String)
    @ Base ./error.jl:42
  [2] sametype_error(input::Tuple{Static.False, Int64})
    @ Base ./promotion.jl:374
  [3] not_sametype(x::Tuple{Static.False, Int64}, y::Tuple{Static.False, Int64})
    @ Base ./promotion.jl:368
  [4] promote
    @ ./promotion.jl:351 [inlined]
  [5] ==(x::Static.False, y::Int64)
    @ Base ./promotion.jl:418
  [6] hash(x::Static.False, h::UInt64)
    @ Base ./float.jl:581
  [7] hash(A::Vector{Any}, h::UInt64)
    @ Base ./abstractarray.jl:2970
  [8] hash(x::Expr, h::UInt64)
    @ Base ./hashing.jl:93
  [9] hash(A::Vector{Any}, h::UInt64)
    @ Base ./abstractarray.jl:2970
 [10] hash
    @ ./hashing.jl:93 [inlined]
 [11] hash
    @ ./hashing.jl:20 [inlined]
 [12] hashindex
    @ ./dict.jl:169 [inlined]
 [13] ht_keyindex(h::Dict{Any, Any}, key::Expr)
    @ Base ./dict.jl:284
 [14] haskey
    @ ./dict.jl:552 [inlined]
 [15] #12
    @ ~/.julia/packages/IRTools/isLV2/src/ir/wrap.jl:112 [inlined]
 [16] prewalk(f::IRTools.Inner.Wrap.var"#12#14"{Core.CodeInfo, Dict{Any, Any}}, x::Expr)
    @ MacroTools ~/.julia/packages/MacroTools/gME9C/src/utils.jl:134
 [17] (::IRTools.Inner.Wrap.var"#rename#13"{Core.CodeInfo, Dict{Any, Any}})(ex::Expr)
    @ IRTools.Inner.Wrap ~/.julia/packages/IRTools/isLV2/src/ir/wrap.jl:111
 [18] IRTools.Inner.IR(ci::Core.CodeInfo, nargs::Int64; meta::IRTools.Inner.Meta)
    @ IRTools.Inner.Wrap ~/.julia/packages/IRTools/isLV2/src/ir/wrap.jl:135
 [19] #IR#15
    @ ~/.julia/packages/IRTools/isLV2/src/ir/wrap.jl:142 [inlined]
 [20] IR
    @ ~/.julia/packages/IRTools/isLV2/src/ir/wrap.jl:142 [inlined]
 [21] _generate_pullback_via_decomposition(T::Type)
    @ Zygote ~/.julia/packages/Zygote/bJn8I/src/compiler/emit.jl:101
 [22] #s3063#1218
    @ ~/.julia/packages/Zygote/bJn8I/src/compiler/interface2.jl:28 [inlined]
 [23] var"#s3063#1218"(::Any, ctx::Any, f::Any, args::Any)
    @ Zygote ./none:0
 [24] (::Core.GeneratedFunctionStub)(::Any, ::Vararg{Any})
    @ Core ./boot.jl:580
 [25] _pullback
    @ ~/.julia/packages/OrdinaryDiffEq/WD8cC/src/integrators/integrator_interface.jl:329 [inlined]
 [26] _pullback(ctx::Zygote.Context, f::typeof(auto_dt_reset!), args::OrdinaryDiffEq.ODEIntegrator{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, true, Vector{Float64}, Nothing, Float64, Float64, Float64, Float64, Float64, Float64, Vector{Vector{Float64}}, ODESolution{Float64, 2, Vector{Vector{Float64}}, Nothing, Nothing, Vector{Float64}, Vector{Vector{Vector{Float64}}}, ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, true, Float64, ODEFunction{true, var"#dynamics!#102", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, OrdinaryDiffEq.InterpolationData{ODEFunction{true, var"#dynamics!#102", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Vector{Vector{Float64}}, Vector{Float64}, Vector{Vector{Vector{Float64}}}, OrdinaryDiffEq.Tsit5Cache{Vector{Float64}, Vector{Float64}, Vector{Float64}, OrdinaryDiffEq.Tsit5ConstantCache{Float64, Float64}, typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}}, DiffEqBase.DEStats}, ODEFunction{true, var"#dynamics!#102", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, OrdinaryDiffEq.Tsit5Cache{Vector{Float64}, Vector{Float64}, Vector{Float64}, OrdinaryDiffEq.Tsit5ConstantCache{Float64, Float64}, typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, OrdinaryDiffEq.DEOptions{Float64, 
Float64, Float64, Float64, PIController{Rational{Int64}}, typeof(DiffEqBase.ODE_DEFAULT_NORM), typeof(LinearAlgebra.opnorm), Nothing, CallbackSet{Tuple{}, Tuple{}}, typeof(DiffEqBase.ODE_DEFAULT_ISOUTOFDOMAIN), typeof(DiffEqBase.ODE_DEFAULT_PROG_MESSAGE), typeof(DiffEqBase.ODE_DEFAULT_UNSTABLE_CHECK), DataStructures.BinaryHeap{Float64, DataStructures.FasterForward}, DataStructures.BinaryHeap{Float64, DataStructures.FasterForward}, Nothing, Nothing, Int64, Tuple{}, Tuple{}, Tuple{}}, Vector{Float64}, Float64, Nothing, OrdinaryDiffEq.DefaultInit})
    @ Zygote ~/.julia/packages/Zygote/bJn8I/src/compiler/interface2.jl:0
 [27] _pullback
    @ ~/.julia/packages/OrdinaryDiffEq/WD8cC/src/solve.jl:504 [inlined]
 [28] _pullback(ctx::Zygote.Context, f::typeof(OrdinaryDiffEq.handle_dt!), args::OrdinaryDiffEq.ODEIntegrator{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, true, Vector{Float64}, Nothing, Float64, Float64, Float64, Float64, Float64, Float64, Vector{Vector{Float64}}, ODESolution{Float64, 2, Vector{Vector{Float64}}, Nothing, Nothing, Vector{Float64}, Vector{Vector{Vector{Float64}}}, ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, true, Float64, ODEFunction{true, var"#dynamics!#102", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, OrdinaryDiffEq.InterpolationData{ODEFunction{true, var"#dynamics!#102", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Vector{Vector{Float64}}, Vector{Float64}, Vector{Vector{Vector{Float64}}}, OrdinaryDiffEq.Tsit5Cache{Vector{Float64}, Vector{Float64}, Vector{Float64}, OrdinaryDiffEq.Tsit5ConstantCache{Float64, Float64}, typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}}, DiffEqBase.DEStats}, ODEFunction{true, var"#dynamics!#102", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, OrdinaryDiffEq.Tsit5Cache{Vector{Float64}, Vector{Float64}, Vector{Float64}, OrdinaryDiffEq.Tsit5ConstantCache{Float64, Float64}, typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, 
OrdinaryDiffEq.DEOptions{Float64, Float64, Float64, Float64, PIController{Rational{Int64}}, typeof(DiffEqBase.ODE_DEFAULT_NORM), typeof(LinearAlgebra.opnorm), Nothing, CallbackSet{Tuple{}, Tuple{}}, typeof(DiffEqBase.ODE_DEFAULT_ISOUTOFDOMAIN), typeof(DiffEqBase.ODE_DEFAULT_PROG_MESSAGE), typeof(DiffEqBase.ODE_DEFAULT_UNSTABLE_CHECK), DataStructures.BinaryHeap{Float64, DataStructures.FasterForward}, DataStructures.BinaryHeap{Float64, DataStructures.FasterForward}, Nothing, Nothing, Int64, Tuple{}, Tuple{}, Tuple{}}, Vector{Float64}, Float64, Nothing, OrdinaryDiffEq.DefaultInit})
    @ Zygote ~/.julia/packages/Zygote/bJn8I/src/compiler/interface2.jl:0
 [29] _pullback
    @ ~/.julia/packages/OrdinaryDiffEq/WD8cC/src/solve.jl:466 [inlined]
 [30] _pullback(::Zygote.Context, ::OrdinaryDiffEq.var"##__init#501", ::Tuple{}, ::Tuple{}, ::Tuple{}, ::Nothing, ::Bool, ::Bool, ::Bool, ::Nothing, ::Nothing, ::Bool, ::Bool, ::Float64, ::Nothing, ::Float64, ::Bool, ::Bool, ::Rational{Int64}, ::Nothing, ::Nothing, ::Rational{Int64}, ::Int64, ::Int64, ::Int64, ::Nothing, ::Nothing, ::Rational{Int64}, ::Nothing, ::Bool, ::Int64, ::Int64, ::typeof(DiffEqBase.ODE_DEFAULT_NORM), ::typeof(LinearAlgebra.opnorm), ::typeof(DiffEqBase.ODE_DEFAULT_ISOUTOFDOMAIN), ::typeof(DiffEqBase.ODE_DEFAULT_UNSTABLE_CHECK), ::Bool, ::Bool, ::Bool, ::Bool, ::Bool, ::Bool, ::Bool, ::Int64, ::String, ::typeof(DiffEqBase.ODE_DEFAULT_PROG_MESSAGE), ::Nothing, ::Bool, ::Bool, ::Bool, ::Bool, ::OrdinaryDiffEq.DefaultInit, ::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, ::typeof(SciMLBase.__init), ::ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, true, Float64, ODEFunction{true, var"#dynamics!#102", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, ::Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, ::Tuple{}, ::Tuple{}, ::Tuple{}, ::Type{Val{true}})
    @ Zygote ~/.julia/packages/Zygote/bJn8I/src/compiler/interface2.jl:0
 [31] _pullback
    @ ~/.julia/packages/OrdinaryDiffEq/WD8cC/src/solve.jl:67 [inlined]
 [32] _pullback(::Zygote.Context, ::typeof(SciMLBase.__init), ::ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, true, Float64, ODEFunction{true, var"#dynamics!#102", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, ::Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, ::Tuple{}, ::Tuple{}, ::Tuple{}, ::Type{Val{true}})
    @ Zygote ~/.julia/packages/Zygote/bJn8I/src/compiler/interface2.jl:0
 [33] _pullback (repeats 4 times)
    @ ~/.julia/packages/OrdinaryDiffEq/WD8cC/src/solve.jl:67 [inlined]
 [34] _apply
    @ ./boot.jl:814 [inlined]
 [35] adjoint
    @ ~/.julia/packages/Zygote/bJn8I/src/lib/lib.jl:200 [inlined]
 [36] _pullback
    @ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:65 [inlined]
 [37] _pullback
    @ ~/.julia/dev/DiffEqBase/src/solve.jl:29 [inlined]
 [38] _pullback(::Zygote.Context, ::DiffEqBase.var"##init_call#244", ::Bool, ::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, ::typeof(DiffEqBase.init_call), ::ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, true, Float64, ODEFunction{true, var"#dynamics!#102", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, ::Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False})
    @ Zygote ~/.julia/packages/Zygote/bJn8I/src/compiler/interface2.jl:0
 [39] _apply(::Function, ::Vararg{Any})
    @ Core ./boot.jl:814
 [40] adjoint
    @ ~/.julia/packages/Zygote/bJn8I/src/lib/lib.jl:200 [inlined]
 [41] _pullback
    @ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:65 [inlined]
 [42] _pullback
    @ ~/.julia/dev/DiffEqBase/src/solve.jl:15 [inlined]
 [43] _pullback(::Zygote.Context, ::typeof(DiffEqBase.init_call), ::ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, true, Float64, ODEFunction{true, var"#dynamics!#102", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, ::Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False})
    @ Zygote ~/.julia/packages/Zygote/bJn8I/src/compiler/interface2.jl:0
 [44] _apply(::Function, ::Vararg{Any})
    @ Core ./boot.jl:814
 [45] adjoint
    @ ~/.julia/packages/Zygote/bJn8I/src/lib/lib.jl:200 [inlined]
 [46] _pullback
    @ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:65 [inlined]
 [47] _pullback
    @ ~/.julia/dev/DiffEqBase/src/solve.jl:40 [inlined]
 [48] _pullback(::Zygote.Context, ::DiffEqBase.var"##init#233", ::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, ::typeof(init), ::ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, true, Float64, ODEFunction{true, var"#dynamics!#102", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, ::Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False})
    @ Zygote ~/.julia/packages/Zygote/bJn8I/src/compiler/interface2.jl:0
 [49] _apply(::Function, ::Vararg{Any})
    @ Core ./boot.jl:814
 [50] adjoint
    @ ~/.julia/packages/Zygote/bJn8I/src/lib/lib.jl:200 [inlined]
 [51] _pullback
    @ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:65 [inlined]
 [52] _pullback
    @ ~/.julia/dev/DiffEqBase/src/solve.jl:33 [inlined]
 [53] _pullback(::Zygote.Context, ::typeof(init), ::ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, true, Float64, ODEFunction{true, var"#dynamics!#102", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, ::Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False})
    @ Zygote ~/.julia/packages/Zygote/bJn8I/src/compiler/interface2.jl:0
 [54] _pullback
    @ ~/.julia/dev/SimulationLogger/test/log_func.jl:22 [inlined]
 [55] _pullback(ctx::Zygote.Context, f::typeof(main), args::Float64)
    @ Zygote ~/.julia/packages/Zygote/bJn8I/src/compiler/interface2.jl:0
 [56] _pullback(f::Function, args::Float64)
    @ Zygote ~/.julia/packages/Zygote/bJn8I/src/compiler/interface.jl:34
 [57] pullback(f::Function, args::Float64)
    @ Zygote ~/.julia/packages/Zygote/bJn8I/src/compiler/interface.jl:40
 [58] gradient(f::Function, args::Float64)
    @ Zygote ~/.julia/packages/Zygote/bJn8I/src/compiler/interface.jl:75
 [59] macro expansion
    @ show.jl:1047 [inlined]
 [60] top-level scope
    @ ~/.julia/dev/SimulationLogger/test/log_func.jl:30
 [61] eval
    @ ./boot.jl:373 [inlined]
 [62] parse_source!(mod_exprs_sigs::OrderedCollections.OrderedDict{Module, OrderedCollections.OrderedDict{Revise.RelocatableExpr, Union{Nothing, Vector{Any}}}}, src::String, filename::String, mod::Module; mode::Symbol)
    @ Revise ~/.julia/packages/Revise/WHZdV/src/parsing.jl:50
 [63] parse_source!(mod_exprs_sigs::OrderedCollections.OrderedDict{Module, OrderedCollections.OrderedDict{Revise.RelocatableExpr, Union{Nothing, Vector{Any}}}}, filename::String, mod::Module; kwargs::Base.Pairs{Symbol, Symbol, Tuple{Symbol}, NamedTuple{(:mode,), Tuple{Symbol}}})
    @ Revise ~/.julia/packages/Revise/WHZdV/src/parsing.jl:27
 [64] #parse_source#11
    @ ~/.julia/packages/Revise/WHZdV/src/parsing.jl:10 [inlined]
 [65] track(mod::Module, file::String; mode::Symbol, kwargs::Base.Pairs{Symbol, Bool, Tuple{Symbol}, NamedTuple{(:skip_include,), Tuple{Bool}}})
    @ Revise ~/.julia/packages/Revise/WHZdV/src/packagedef.jl:901
 [66] includet(mod::Module, file::String)
    @ Revise ~/.julia/packages/Revise/WHZdV/src/packagedef.jl:1001
 [67] includet(file::String)
    @ Revise ~/.julia/packages/Revise/WHZdV/src/packagedef.jl:1023
```
JinraeKim commented 2 years ago

I also tried the following detour toward the same goal, which also failed:

Is it impossible to use `SavingCallback` with DiffEqSensitivity?

```julia
using OrdinaryDiffEq, DiffEqCallbacks, DiffEqSensitivity, Zygote
import DiffEqBase
using Debugger  # provides @run

p = [2.0]
function main(p)
    # out-of-place right-hand side (returns the derivative instead of mutating dx)
    function dynamics!(x, p, t)
        # dx .= -p[1]*x
        -p[1]*x
    end
    # if hasmethod(dynamics!, Tuple{Any, Any, Any, Any, __LOG_INDICATOR__})
    # to avoid undefined error when not adding @Loggable
    # log_func(x, t, integrator::DiffEqBase.DEIntegrator; kwargs...) = dynamics!(zero.(x), copy(x), integrator.p, t, __LOG_INDICATOR__(); kwargs...)
    log_func(x, t, integrator::DiffEqBase.DEIntegrator; kwargs...) = (; x=copy(x))
    x = [1, 2, 3.0]
    t = 0.0
    tspan = (0.0, 1.0)
    prob = ODEProblem(dynamics!, x, tspan, p)
    saved_values = SavedValues(Float64, NamedTuple)
    cb = SavingCallback(log_func, saved_values; saveat=0:0.01:1.0)
    _ = solve(prob, Tsit5(); callback=cb, p=p)
    saved_values.saveval[end].x |> sum
    # sol = solve(prob, Tsit5(), p=p; saveat=0.01)
    # sol.u[end] |> sum
    # integrator = init(prob, Tsit5(); p=p)
    # integrator.p
    # solve!(integrator)
    # integrator.sol.u[end] |> sum
    # @show result = log_func(x, t, integrator)
    # result.x |> sum
    # end
end
@run main(p)
@show gradient(main, p)
```


- result
```julia
julia> include("test/log_func.jl")
ERROR: LoadError: Mutating arrays is not supported -- called setindex!(::Vector{Float64}, _...)
Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:33
  [2] (::Zygote.var"#443#444"{Vector{Float64}})(#unused#::Nothing)
    @ Zygote ~/.julia/packages/Zygote/bJn8I/src/lib/array.jl:71
  [3] (::Zygote.var"#2342#back#445"{Zygote.var"#443#444"{Vector{Float64}}})(Δ::Nothing)
    @ Zygote ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67
  [4] Pullback
    @ ~/.julia/packages/DataStructures/vSp4s/src/heaps/arrays_as_heaps.jl:29 [inlined]
  [5] (::typeof(∂(percolate_down!)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/bJn8I/src/compiler/interface2.jl:0
  [6] Pullback
    @ ~/.julia/packages/DataStructures/vSp4s/src/heaps/arrays_as_heaps.jl:32 [inlined]
  [7] (::typeof(∂(percolate_down!)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/bJn8I/src/compiler/interface2.jl:0
  [8] Pullback
    @ ~/.julia/packages/DataStructures/vSp4s/src/heaps/arrays_as_heaps.jl:32 [inlined]
  [9] (::typeof(∂(percolate_down!)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/bJn8I/src/compiler/interface2.jl:0
 [10] Pullback
    @ ~/.julia/packages/DataStructures/vSp4s/src/heaps/arrays_as_heaps.jl:86 [inlined]
 [11] (::typeof(∂(heapify!)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/bJn8I/src/compiler/interface2.jl:0
 [12] Pullback
    @ ~/.julia/packages/DataStructures/vSp4s/src/heaps/arrays_as_heaps.jl:116 [inlined]
 [13] (::typeof(∂(heapify)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/bJn8I/src/compiler/interface2.jl:0
 [14] Pullback
    @ ~/.julia/packages/DataStructures/vSp4s/src/heaps/binary_heap.jl:42 [inlined]
 [15] (::typeof(∂(DataStructures.BinaryHeap{Float64})))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/bJn8I/src/compiler/interface2.jl:0
 [16] Pullback
    @ ~/.julia/packages/DataStructures/vSp4s/src/heaps/binary_heap.jl:51 [inlined]
 [17] (::typeof(∂(DataStructures.BinaryMinHeap{Float64})))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/bJn8I/src/compiler/interface2.jl:0
 [18] Pullback
    @ ~/.julia/packages/DataStructures/vSp4s/src/heaps/binary_heap.jl:61 [inlined]
 [19] (::typeof(∂(DataStructures.BinaryMinHeap)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/bJn8I/src/compiler/interface2.jl:0
 [20] Pullback
    @ ~/.julia/packages/DiffEqCallbacks/YQhGJ/src/saving.jl:121 [inlined]
 [21] (::typeof(∂(#SavingCallback#29)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/bJn8I/src/compiler/interface2.jl:0
 [22] Pullback
    @ ~/.julia/packages/DiffEqCallbacks/YQhGJ/src/saving.jl:119 [inlined]
 [23] (::typeof(∂(SavingCallback##kw)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/bJn8I/src/compiler/interface2.jl:0
 [24] Pullback
    @ ~/.julia/dev/SimulationLogger/test/log_func.jl:25 [inlined]
 [25] (::typeof(∂(main)))(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/bJn8I/src/compiler/interface2.jl:0
 [26] (::Zygote.var"#57#58"{typeof(∂(main))})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/bJn8I/src/compiler/interface.jl:41
 [27] gradient(f::Function, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/bJn8I/src/compiler/interface.jl:76
 [28] top-level scope
    @ show.jl:1047
 [29] include(fname::String)
    @ Base.MainInclude ./client.jl:451
 [30] top-level scope
    @ REPL[1]:1
in expression starting at /Users/jinrae/.julia/dev/SimulationLogger/test/log_func.jl:39
```
ChrisRackauckas commented 2 years ago

Indeed, Zygote will not work with the integrator interface. Instead, one should use callbacks, which support anything the integrator interface can do. This is an upstream issue with Zygote and not fixable with the current AD systems, and even if it were, it would not be compatible with all adjoint methods. I'll leave this issue open as a marker.

JinraeKim commented 2 years ago

> Indeed, Zygote will not work with the integrator interface. Instead, one should use callbacks, which support anything the integrator interface can do. This is an upstream issue with Zygote and not fixable with the current AD systems, and even if it were, it would not be compatible with all adjoint methods. I'll leave this issue open as a marker.

Got it. Thank you, @ChrisRackauckas. Then my strategy should be to use callbacks.

ChrisRackauckas commented 2 years ago

`SavingCallback` and `FunctionCallingCallback` are going to be rough to work with the differentiation system because they explicitly avoid saving the outputs directly. I guess the question is: what are you trying to do with them, and why can't you do a solve and then modify the outputs? Is it a memory thing?
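A minimal sketch of the "solve, then modify the outputs" approach, assuming current OrdinaryDiffEq, SciMLSensitivity (the newer name of DiffEqSensitivity) and Zygote. The `control` function is a hypothetical stand-in for any derived quantity `g(u, p, t)`: instead of being logged by a callback, it is recomputed from the saved states after a plain `solve`, so only the standard adjoint of `solve` is needed and no callback has to be differentiated.

```julia
using OrdinaryDiffEq, SciMLSensitivity, Zygote

# Out-of-place dynamics from the thread's example: du/dt = -p[1] * u
dynamics(u, p, t) = -p[1] .* u

# Hypothetical derived quantity (e.g. a control command) cmd = control(u, p, t);
# it is recomputed from the saved states instead of being logged by a callback
control(u, p, t) = -p[1] .* u

const ts = 0.0:0.01:1.0                      # times at which outputs are needed
prob = ODEProblem(dynamics, [1.0, 2.0, 3.0], (0.0, 1.0), [2.0])

function loss(p)
    # Plain solve: SciMLSensitivity's adjoint rule for `solve` handles this call
    sol = solve(prob, Tsit5(); p = p, saveat = ts,
                sensealg = InterpolatingAdjoint())
    U = Array(sol)                           # column i is the state at ts[i]
    # Post-process: evaluate g(u, p, t) at every saved time and reduce
    cmds = [sum(control(U[:, i], p, ts[i])) for i in eachindex(ts)]
    return sum(cmds)
end

grad = Zygote.gradient(loss, [2.0])[1]
```

Here `p` enters the loss both through the ODE solution and directly through `control`; Zygote picks up both paths. The trade-off is memory: every state needed by `control` has to be saved, which is presumably what `SavingCallback` is meant to avoid.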

JinraeKim commented 2 years ago

> `SavingCallback` and `FunctionCallingCallback` are going to be rough to work with the differentiation system because they explicitly avoid saving the outputs directly. I guess the question is: what are you trying to do with them, and why can't you do a solve and then modify the outputs? Is it a memory thing?

OK, let me clarify what I want to do; sorry for the confusion. When solving ODEs, I often need to record some data that are functions of the DE parameter `p` and state `u`.

For example, a control law produces a control command that is a function of the time `t`, state `u`, and parameter `p` (i.e., `cmd = control(u, p, t)`).

With this in mind, I wrote SimulationLogger.jl, which lets you define the right-hand side of an ODE and, in the same place, indicate which data should be logged (I think this is often convenient for simulating hierarchical dynamical systems).

Now I find DiffEqFlux.jl a promising package for combining (O)DE simulation with deep learning (especially for simulations involving control systems). So I want SimulationLogger.jl not only to "log data" but also to "differentiate" it. If that is possible, users could construct hierarchical data at each timestamp and easily build a loss function with respect to (possibly parts of) the data.

So I'm looking for a way to construct hierarchical data and differentiate it via DiffEqFlux.jl.

Since I lack background in how backpropagation is implemented, this may come across as a bit messy. Sorry about that :)

The code below sketches what I would like to be able to write:

```julia
@Loggable function dynamics!(du, u, p, t)  # @Loggable marks the function for logging (SimulationLogger.jl syntax)
    ctrl = control_law(u, p, t)  # e.g., a neural network
    @log ctrl  # log the control input command
    @log u  # log the ODE state
    du .= ctrl
end
function predict(p)
    data = solve(...)  # with appropriate arguments
end
function loss(p)
    pred = predict(p)
    # e.g., minimise the sum of the control input and the state at the terminal time (regardless of practical meaning)
    sum(pred[end].ctrl) + sum(pred[end].u)
end
```
JinraeKim commented 2 years ago

One more thing: there has already been an attempt to log hierarchical data with convenient macros via post-processing, e.g., SimulationLogs.jl. It would work very well for deterministic ODEs and would probably be compatible with AD tools given a carefully written custom logging function. But I found that users have to log "stochastic data" in several cases (e.g., reinforcement learning), where the logged values cannot simply be recomputed after the solve. That's why I first used the callbacks exported from DiffEq.jl.
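One possible way around the stochastic case (a sketch under stated assumptions, not something proposed in this thread) is to sample the random quantities before the solve, so that post-processing can recompute exactly what a logging callback would have recorded. `HeldNoise`, `control` and `simulate` below are hypothetical names: the noise is held constant on a fixed grid, and a hand-written explicit Euler loop stands in for the ODE solver plus logging callback so that the block runs with only the standard library.

```julia
using Random

# Hypothetical exploration noise, sampled ONCE on a fixed time grid before the
# simulation and then held constant on each interval (zero-order hold)
struct HeldNoise
    grid::Vector{Float64}   # left endpoints of the hold intervals
    ξ::Vector{Float64}      # one noise sample per interval
end

HeldNoise(rng::AbstractRNG, grid, σ) =
    HeldNoise(collect(Float64, grid), σ .* randn(rng, length(grid)))

# Look up the sample whose interval contains t (a deterministic function of t)
function (n::HeldNoise)(t)
    k = clamp(searchsortedlast(n.grid, t), 1, length(n.grid))
    return n.ξ[k]
end

# Hypothetical stochastic control law: deterministic part plus held noise
control(u, p, t, noise) = -p[1] * u + noise(t)

# Minimal explicit-Euler simulation of du/dt = cmd that also "logs" cmd,
# standing in for a solver with a logging callback
function simulate(u0, p, noise; dt = 0.01, nsteps = 100)
    us, ts, logged = Float64[], Float64[], Float64[]
    u = u0
    for k in 0:nsteps-1
        t = k * dt
        cmd = control(u, p, t, noise)
        push!(us, u); push!(ts, t); push!(logged, cmd)
        u += dt * cmd
    end
    return us, ts, logged
end

noise = HeldNoise(MersenneTwister(42), 0:0.1:1, 0.1)
us, ts, logged = simulate(1.0, [2.0], noise)

# Post-processing: recompute the command from the saved (u, t) pairs
recomputed = [control(us[k], [2.0], ts[k], noise) for k in eachindex(us)]
```

Because the sampled noise is a deterministic function of `t`, `recomputed` equals `logged` exactly, so a loss can be built from post-processed values even when the control is stochastic. The cost is fixing the noise grid in advance; with an adaptive DiffEq solver, the grid points would likely also need to be passed as `tstops` so that steps do not straddle the jumps.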

ChrisRackauckas commented 2 years ago

I'll have to leave this open for now. It's not impossible, but it's hard to do automatically. I think the FunctionCallingCallback and SavingCallback functionality would instead have to become standard parts of the solver interface. The reason is that what we effectively want to do here is to define the adjoint not of du(t)/dp, but of dv(t)/dp, where v(t) = g(u(t),p,t). If you look at the index-1 DAE adjoint derivation in https://arxiv.org/abs/2001.04385, you can see that this g function does exist in the derivation, so if you are defining gradients by hand (https://diffeq.sciml.ai/stable/analysis/sensitivity/#Adjoint-Sensitivity-Analysis-via-adjoint_sensitivities-(Backpropogation)) you can define this g and make this happen. However, when relying on the ChainRules overloads to automatically capture solve, it is not possible to hijack the post-processing, so g=identity is taken as the definition and then plugged into the later steps of AD. This can be seen here:

https://github.com/SciML/DiffEqSensitivity.jl/blob/v6.64.0/src/concrete_solve.jl#L217-L264

The way to interpret that code is that it is really just g=identity; handling all of the cases of the pullback makes it look complicated, but that's really all it is. So what we would need is for the SavingCallback or FunctionCallingCallback to be not just some arbitrary callback, but something that hooks into the system and changes the g used by the adjoint definition to the output the user is requesting. Given how the callbacks tie into the system, that's not possible right now.
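Written out as a sketch, for a loss $L$ that depends on the outputs $v_i = v(t_i)$ at saved times $t_i$, with cotangents $\bar v_i = \partial L / \partial v_i$ and all partial derivatives of $g$ evaluated at $(u(t_i), p, t_i)$, the change described above is just the chain rule:

$$
v(t) = g\big(u(t), p, t\big), \qquad
\frac{dv}{dp}(t) = \frac{\partial g}{\partial u}\,\frac{du}{dp}(t) + \frac{\partial g}{\partial p},
$$

$$
\bar u_i = \Big(\frac{\partial g}{\partial u}\Big)^{\!\top} \bar v_i, \qquad
\frac{dL}{dp} = \sum_i \left[ \Big(\frac{du}{dp}(t_i)\Big)^{\!\top} \bar u_i + \Big(\frac{\partial g}{\partial p}\Big)^{\!\top} \bar v_i \right].
$$

The first term in the sum is what the existing adjoint machinery already computes when it is seeded with $\bar u_i$ at $t_i$; the second is a direct term. With $g = \mathrm{identity}$, $\bar u_i = \bar v_i$ and the direct term vanishes, which is the current behaviour.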

What would need to happen is that something akin to these callbacks be added as standard keyword-argument functions, which would then be implemented by all of the integrators and specialized in the adjoint definition. This is something I want to do, but hopefully this explains why it might take a while: it will require a much larger change to the DiffEq system. That said, it would also fix another issue with those callbacks, namely that they don't tie into the interpolation system, so it will be a nice thing once completed, but it's still a ways off.

JinraeKim commented 2 years ago

Thank you, @ChrisRackauckas. Great explanation!

It seems to be a much bigger problem than I expected.