Open JinraeKim opened 2 years ago
A different example (works well):
using SimulationLogger
using DifferentialEquations
using DiffEqSensitivity
using Zygote
# Parameter with respect to which the gradient is taken.
p = 2.0

"""
    main(p)

Build an `ODEProblem` for the in-place dynamics `dx .= -p * x` (decorated with
SimulationLogger's `@Loggable` / `@log`) and return the problem's parameter
`prob.p`. Since the return value is `p` itself, `gradient(main, p)` gives `(1.0,)`.
"""
function main(p)
    @Loggable function dynamics!(dx, x, p, t)
        @log x
        dx .= -p * x
    end
    # Logging hook for a callback: re-invokes `dynamics!` with the extra
    # `__LOG_INDICATOR__()` argument. That method presumably only exists when
    # `@Loggable` is applied and returns the `@log`-ged values — confirm against
    # SimulationLogger. Defined here but not called in this example.
    log_func(x, t, integrator::DiffEqBase.DEIntegrator; kwargs...) =
        dynamics!(zero.(x), copy(x), integrator.p, t, __LOG_INDICATOR__(); kwargs...)
    x = [1, 2, 3.0]
    tspan = (0.0, 1.0)
    prob = ODEProblem(dynamics!, x, tspan, p)
    # Returning `prob.p` differentiates fine; going through
    # `init(prob, Tsit5())` and reading `integrator.p` does not.
    return prob.p
end

@run main(p)
@show gradient(main, p)
- result
```julia
julia> include("test/log_func.jl")
gradient(main, p) = (1.0,)
(1.0,)
I suspected that the error above might be caused by my custom macros (e.g., `@log`).
So I rewrote the test code without them, and it gave me a different error.
using SimulationLogger
# using DifferentialEquations
# using DiffEqSensitivity
using OrdinaryDiffEq
using Zygote
using DiffEqBase
p = 2.0

# Same problem as above but without SimulationLogger's macros. Builds an
# integrator with `init` and returns its parameter; under Zygote, differentiating
# through the `init` call is what fails (see the stack trace that follows).
function main(p)
    # In-place right-hand side of the exponential decay dx/dt = -p * x.
    function dynamics!(dx, x, p, t)
        dx .= -p * x
    end
    u0 = [1, 2, 3.0]
    prob = ODEProblem(dynamics!, u0, (0.0, 1.0), p)
    integrator = init(prob, Tsit5())
    @show integrator.p
    return integrator.p
end

@run main(p)
@show gradient(main, p)
- result
```julia
julia> includet("test/log_func.jl")
integrator.p = 2.0
ERROR: MethodError: Cannot `convert` an object of type Zygote.CompileError to an object of type Exception
Closest candidates are:
convert(::Type{T}, ::T) where T at /Applications/Julia-1.7.app/Contents/Resources/julia/share/julia/base/essentials.jl:218
Stacktrace:
[1] includet(mod::Module, file::String)
@ Revise ~/.julia/packages/Revise/WHZdV/src/packagedef.jl:1018
[2] includet(file::String)
@ Revise ~/.julia/packages/Revise/WHZdV/src/packagedef.jl:1023
[3] top-level scope
@ REPL[4]:1
caused by: MethodError: Cannot `convert` an object of type Zygote.CompileError to an object of type Exception
Closest candidates are:
convert(::Type{T}, ::T) where T at /Applications/Julia-1.7.app/Contents/Resources/julia/share/julia/base/essentials.jl:218
Stacktrace:
[1] Revise.ReviseEvalException(loc::String, exc::Zygote.CompileError, stacktrace::Vector{Any})
@ Revise ~/.julia/packages/Revise/WHZdV/src/types.jl:236
[2] parse_source!(mod_exprs_sigs::OrderedCollections.OrderedDict{Module, OrderedCollections.OrderedDict{Revise.RelocatableExpr, Union{Nothing, Vector{Any}}}}, src::String, filename::String, mod::Module; mode::Symbol)
@ Revise ~/.julia/packages/Revise/WHZdV/src/parsing.jl:55
[3] parse_source!(mod_exprs_sigs::OrderedCollections.OrderedDict{Module, OrderedCollections.OrderedDict{Revise.RelocatableExpr, Union{Nothing, Vector{Any}}}}, filename::String, mod::Module; kwargs::Base.Pairs{Symbol, Symbol, Tuple{Symbol}, NamedTuple{(:mode,), Tuple{Symbol}}})
@ Revise ~/.julia/packages/Revise/WHZdV/src/parsing.jl:27
[4] #parse_source#11
@ ~/.julia/packages/Revise/WHZdV/src/parsing.jl:10 [inlined]
[5] track(mod::Module, file::String; mode::Symbol, kwargs::Base.Pairs{Symbol, Bool, Tuple{Symbol}, NamedTuple{(:skip_include,), Tuple{Bool}}})
@ Revise ~/.julia/packages/Revise/WHZdV/src/packagedef.jl:901
[6] includet(mod::Module, file::String)
@ Revise ~/.julia/packages/Revise/WHZdV/src/packagedef.jl:1001
[7] includet(file::String)
@ Revise ~/.julia/packages/Revise/WHZdV/src/packagedef.jl:1023
[8] top-level scope
@ REPL[4]:1
caused by: Compiling Tuple{typeof(OrdinaryDiffEq.ode_determine_initdt), Vector{Float64}, Float64, Float64, Float64, Float64, Float64, typeof(DiffEqBase.ODE_DEFAULT_NORM), ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, true, Float64, ODEFunction{true, var"#dynamics!#102", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, OrdinaryDiffEq.ODEIntegrator{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, true, Vector{Float64}, Nothing, Float64, Float64, Float64, Float64, Float64, Float64, Vector{Vector{Float64}}, ODESolution{Float64, 2, Vector{Vector{Float64}}, Nothing, Nothing, Vector{Float64}, Vector{Vector{Vector{Float64}}}, ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, true, Float64, ODEFunction{true, var"#dynamics!#102", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, OrdinaryDiffEq.InterpolationData{ODEFunction{true, var"#dynamics!#102", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Vector{Vector{Float64}}, Vector{Float64}, Vector{Vector{Vector{Float64}}}, OrdinaryDiffEq.Tsit5Cache{Vector{Float64}, Vector{Float64}, Vector{Float64}, OrdinaryDiffEq.Tsit5ConstantCache{Float64, Float64}, typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}}, 
DiffEqBase.DEStats}, ODEFunction{true, var"#dynamics!#102", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, OrdinaryDiffEq.Tsit5Cache{Vector{Float64}, Vector{Float64}, Vector{Float64}, OrdinaryDiffEq.Tsit5ConstantCache{Float64, Float64}, typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, OrdinaryDiffEq.DEOptions{Float64, Float64, Float64, Float64, PIController{Rational{Int64}}, typeof(DiffEqBase.ODE_DEFAULT_NORM), typeof(LinearAlgebra.opnorm), Nothing, CallbackSet{Tuple{}, Tuple{}}, typeof(DiffEqBase.ODE_DEFAULT_ISOUTOFDOMAIN), typeof(DiffEqBase.ODE_DEFAULT_PROG_MESSAGE), typeof(DiffEqBase.ODE_DEFAULT_UNSTABLE_CHECK), DataStructures.BinaryHeap{Float64, DataStructures.FasterForward}, DataStructures.BinaryHeap{Float64, DataStructures.FasterForward}, Nothing, Nothing, Int64, Tuple{}, Tuple{}, Tuple{}}, Vector{Float64}, Float64, Nothing, OrdinaryDiffEq.DefaultInit}}: promotion of types Static.False and Int64 failed to change any arguments
Stacktrace:
[1] error(::String, ::String, ::String)
@ Base ./error.jl:42
[2] sametype_error(input::Tuple{Static.False, Int64})
@ Base ./promotion.jl:374
[3] not_sametype(x::Tuple{Static.False, Int64}, y::Tuple{Static.False, Int64})
@ Base ./promotion.jl:368
[4] promote
@ ./promotion.jl:351 [inlined]
[5] ==(x::Static.False, y::Int64)
@ Base ./promotion.jl:418
[6] hash(x::Static.False, h::UInt64)
@ Base ./float.jl:581
[7] hash(A::Vector{Any}, h::UInt64)
@ Base ./abstractarray.jl:2970
[8] hash(x::Expr, h::UInt64)
@ Base ./hashing.jl:93
[9] hash(A::Vector{Any}, h::UInt64)
@ Base ./abstractarray.jl:2970
[10] hash
@ ./hashing.jl:93 [inlined]
[11] hash
@ ./hashing.jl:20 [inlined]
[12] hashindex
@ ./dict.jl:169 [inlined]
[13] ht_keyindex(h::Dict{Any, Any}, key::Expr)
@ Base ./dict.jl:284
[14] haskey
@ ./dict.jl:552 [inlined]
[15] #12
@ ~/.julia/packages/IRTools/isLV2/src/ir/wrap.jl:112 [inlined]
[16] prewalk(f::IRTools.Inner.Wrap.var"#12#14"{Core.CodeInfo, Dict{Any, Any}}, x::Expr)
@ MacroTools ~/.julia/packages/MacroTools/gME9C/src/utils.jl:134
[17] (::IRTools.Inner.Wrap.var"#rename#13"{Core.CodeInfo, Dict{Any, Any}})(ex::Expr)
@ IRTools.Inner.Wrap ~/.julia/packages/IRTools/isLV2/src/ir/wrap.jl:111
[18] IRTools.Inner.IR(ci::Core.CodeInfo, nargs::Int64; meta::IRTools.Inner.Meta)
@ IRTools.Inner.Wrap ~/.julia/packages/IRTools/isLV2/src/ir/wrap.jl:135
[19] #IR#15
@ ~/.julia/packages/IRTools/isLV2/src/ir/wrap.jl:142 [inlined]
[20] IR
@ ~/.julia/packages/IRTools/isLV2/src/ir/wrap.jl:142 [inlined]
[21] _generate_pullback_via_decomposition(T::Type)
@ Zygote ~/.julia/packages/Zygote/bJn8I/src/compiler/emit.jl:101
[22] #s3063#1218
@ ~/.julia/packages/Zygote/bJn8I/src/compiler/interface2.jl:28 [inlined]
[23] var"#s3063#1218"(::Any, ctx::Any, f::Any, args::Any)
@ Zygote ./none:0
[24] (::Core.GeneratedFunctionStub)(::Any, ::Vararg{Any})
@ Core ./boot.jl:580
[25] _pullback
@ ~/.julia/packages/OrdinaryDiffEq/WD8cC/src/integrators/integrator_interface.jl:329 [inlined]
[26] _pullback(ctx::Zygote.Context, f::typeof(auto_dt_reset!), args::OrdinaryDiffEq.ODEIntegrator{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, true, Vector{Float64}, Nothing, Float64, Float64, Float64, Float64, Float64, Float64, Vector{Vector{Float64}}, ODESolution{Float64, 2, Vector{Vector{Float64}}, Nothing, Nothing, Vector{Float64}, Vector{Vector{Vector{Float64}}}, ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, true, Float64, ODEFunction{true, var"#dynamics!#102", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, OrdinaryDiffEq.InterpolationData{ODEFunction{true, var"#dynamics!#102", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Vector{Vector{Float64}}, Vector{Float64}, Vector{Vector{Vector{Float64}}}, OrdinaryDiffEq.Tsit5Cache{Vector{Float64}, Vector{Float64}, Vector{Float64}, OrdinaryDiffEq.Tsit5ConstantCache{Float64, Float64}, typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}}, DiffEqBase.DEStats}, ODEFunction{true, var"#dynamics!#102", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, OrdinaryDiffEq.Tsit5Cache{Vector{Float64}, Vector{Float64}, Vector{Float64}, OrdinaryDiffEq.Tsit5ConstantCache{Float64, Float64}, typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, OrdinaryDiffEq.DEOptions{Float64, 
Float64, Float64, Float64, PIController{Rational{Int64}}, typeof(DiffEqBase.ODE_DEFAULT_NORM), typeof(LinearAlgebra.opnorm), Nothing, CallbackSet{Tuple{}, Tuple{}}, typeof(DiffEqBase.ODE_DEFAULT_ISOUTOFDOMAIN), typeof(DiffEqBase.ODE_DEFAULT_PROG_MESSAGE), typeof(DiffEqBase.ODE_DEFAULT_UNSTABLE_CHECK), DataStructures.BinaryHeap{Float64, DataStructures.FasterForward}, DataStructures.BinaryHeap{Float64, DataStructures.FasterForward}, Nothing, Nothing, Int64, Tuple{}, Tuple{}, Tuple{}}, Vector{Float64}, Float64, Nothing, OrdinaryDiffEq.DefaultInit})
@ Zygote ~/.julia/packages/Zygote/bJn8I/src/compiler/interface2.jl:0
[27] _pullback
@ ~/.julia/packages/OrdinaryDiffEq/WD8cC/src/solve.jl:504 [inlined]
[28] _pullback(ctx::Zygote.Context, f::typeof(OrdinaryDiffEq.handle_dt!), args::OrdinaryDiffEq.ODEIntegrator{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, true, Vector{Float64}, Nothing, Float64, Float64, Float64, Float64, Float64, Float64, Vector{Vector{Float64}}, ODESolution{Float64, 2, Vector{Vector{Float64}}, Nothing, Nothing, Vector{Float64}, Vector{Vector{Vector{Float64}}}, ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, true, Float64, ODEFunction{true, var"#dynamics!#102", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, OrdinaryDiffEq.InterpolationData{ODEFunction{true, var"#dynamics!#102", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Vector{Vector{Float64}}, Vector{Float64}, Vector{Vector{Vector{Float64}}}, OrdinaryDiffEq.Tsit5Cache{Vector{Float64}, Vector{Float64}, Vector{Float64}, OrdinaryDiffEq.Tsit5ConstantCache{Float64, Float64}, typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}}, DiffEqBase.DEStats}, ODEFunction{true, var"#dynamics!#102", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, OrdinaryDiffEq.Tsit5Cache{Vector{Float64}, Vector{Float64}, Vector{Float64}, OrdinaryDiffEq.Tsit5ConstantCache{Float64, Float64}, typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, 
OrdinaryDiffEq.DEOptions{Float64, Float64, Float64, Float64, PIController{Rational{Int64}}, typeof(DiffEqBase.ODE_DEFAULT_NORM), typeof(LinearAlgebra.opnorm), Nothing, CallbackSet{Tuple{}, Tuple{}}, typeof(DiffEqBase.ODE_DEFAULT_ISOUTOFDOMAIN), typeof(DiffEqBase.ODE_DEFAULT_PROG_MESSAGE), typeof(DiffEqBase.ODE_DEFAULT_UNSTABLE_CHECK), DataStructures.BinaryHeap{Float64, DataStructures.FasterForward}, DataStructures.BinaryHeap{Float64, DataStructures.FasterForward}, Nothing, Nothing, Int64, Tuple{}, Tuple{}, Tuple{}}, Vector{Float64}, Float64, Nothing, OrdinaryDiffEq.DefaultInit})
@ Zygote ~/.julia/packages/Zygote/bJn8I/src/compiler/interface2.jl:0
[29] _pullback
@ ~/.julia/packages/OrdinaryDiffEq/WD8cC/src/solve.jl:466 [inlined]
[30] _pullback(::Zygote.Context, ::OrdinaryDiffEq.var"##__init#501", ::Tuple{}, ::Tuple{}, ::Tuple{}, ::Nothing, ::Bool, ::Bool, ::Bool, ::Nothing, ::Nothing, ::Bool, ::Bool, ::Float64, ::Nothing, ::Float64, ::Bool, ::Bool, ::Rational{Int64}, ::Nothing, ::Nothing, ::Rational{Int64}, ::Int64, ::Int64, ::Int64, ::Nothing, ::Nothing, ::Rational{Int64}, ::Nothing, ::Bool, ::Int64, ::Int64, ::typeof(DiffEqBase.ODE_DEFAULT_NORM), ::typeof(LinearAlgebra.opnorm), ::typeof(DiffEqBase.ODE_DEFAULT_ISOUTOFDOMAIN), ::typeof(DiffEqBase.ODE_DEFAULT_UNSTABLE_CHECK), ::Bool, ::Bool, ::Bool, ::Bool, ::Bool, ::Bool, ::Bool, ::Int64, ::String, ::typeof(DiffEqBase.ODE_DEFAULT_PROG_MESSAGE), ::Nothing, ::Bool, ::Bool, ::Bool, ::Bool, ::OrdinaryDiffEq.DefaultInit, ::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, ::typeof(SciMLBase.__init), ::ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, true, Float64, ODEFunction{true, var"#dynamics!#102", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, ::Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, ::Tuple{}, ::Tuple{}, ::Tuple{}, ::Type{Val{true}})
@ Zygote ~/.julia/packages/Zygote/bJn8I/src/compiler/interface2.jl:0
[31] _pullback
@ ~/.julia/packages/OrdinaryDiffEq/WD8cC/src/solve.jl:67 [inlined]
[32] _pullback(::Zygote.Context, ::typeof(SciMLBase.__init), ::ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, true, Float64, ODEFunction{true, var"#dynamics!#102", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, ::Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, ::Tuple{}, ::Tuple{}, ::Tuple{}, ::Type{Val{true}})
@ Zygote ~/.julia/packages/Zygote/bJn8I/src/compiler/interface2.jl:0
[33] _pullback (repeats 4 times)
@ ~/.julia/packages/OrdinaryDiffEq/WD8cC/src/solve.jl:67 [inlined]
[34] _apply
@ ./boot.jl:814 [inlined]
[35] adjoint
@ ~/.julia/packages/Zygote/bJn8I/src/lib/lib.jl:200 [inlined]
[36] _pullback
@ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:65 [inlined]
[37] _pullback
@ ~/.julia/dev/DiffEqBase/src/solve.jl:29 [inlined]
[38] _pullback(::Zygote.Context, ::DiffEqBase.var"##init_call#244", ::Bool, ::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, ::typeof(DiffEqBase.init_call), ::ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, true, Float64, ODEFunction{true, var"#dynamics!#102", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, ::Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False})
@ Zygote ~/.julia/packages/Zygote/bJn8I/src/compiler/interface2.jl:0
[39] _apply(::Function, ::Vararg{Any})
@ Core ./boot.jl:814
[40] adjoint
@ ~/.julia/packages/Zygote/bJn8I/src/lib/lib.jl:200 [inlined]
[41] _pullback
@ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:65 [inlined]
[42] _pullback
@ ~/.julia/dev/DiffEqBase/src/solve.jl:15 [inlined]
[43] _pullback(::Zygote.Context, ::typeof(DiffEqBase.init_call), ::ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, true, Float64, ODEFunction{true, var"#dynamics!#102", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, ::Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False})
@ Zygote ~/.julia/packages/Zygote/bJn8I/src/compiler/interface2.jl:0
[44] _apply(::Function, ::Vararg{Any})
@ Core ./boot.jl:814
[45] adjoint
@ ~/.julia/packages/Zygote/bJn8I/src/lib/lib.jl:200 [inlined]
[46] _pullback
@ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:65 [inlined]
[47] _pullback
@ ~/.julia/dev/DiffEqBase/src/solve.jl:40 [inlined]
[48] _pullback(::Zygote.Context, ::DiffEqBase.var"##init#233", ::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, ::typeof(init), ::ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, true, Float64, ODEFunction{true, var"#dynamics!#102", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, ::Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False})
@ Zygote ~/.julia/packages/Zygote/bJn8I/src/compiler/interface2.jl:0
[49] _apply(::Function, ::Vararg{Any})
@ Core ./boot.jl:814
[50] adjoint
@ ~/.julia/packages/Zygote/bJn8I/src/lib/lib.jl:200 [inlined]
[51] _pullback
@ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:65 [inlined]
[52] _pullback
@ ~/.julia/dev/DiffEqBase/src/solve.jl:33 [inlined]
[53] _pullback(::Zygote.Context, ::typeof(init), ::ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, true, Float64, ODEFunction{true, var"#dynamics!#102", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, ::Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False})
@ Zygote ~/.julia/packages/Zygote/bJn8I/src/compiler/interface2.jl:0
[54] _pullback
@ ~/.julia/dev/SimulationLogger/test/log_func.jl:22 [inlined]
[55] _pullback(ctx::Zygote.Context, f::typeof(main), args::Float64)
@ Zygote ~/.julia/packages/Zygote/bJn8I/src/compiler/interface2.jl:0
[56] _pullback(f::Function, args::Float64)
@ Zygote ~/.julia/packages/Zygote/bJn8I/src/compiler/interface.jl:34
[57] pullback(f::Function, args::Float64)
@ Zygote ~/.julia/packages/Zygote/bJn8I/src/compiler/interface.jl:40
[58] gradient(f::Function, args::Float64)
@ Zygote ~/.julia/packages/Zygote/bJn8I/src/compiler/interface.jl:75
[59] macro expansion
@ show.jl:1047 [inlined]
[60] top-level scope
@ ~/.julia/dev/SimulationLogger/test/log_func.jl:30
[61] eval
@ ./boot.jl:373 [inlined]
[62] parse_source!(mod_exprs_sigs::OrderedCollections.OrderedDict{Module, OrderedCollections.OrderedDict{Revise.RelocatableExpr, Union{Nothing, Vector{Any}}}}, src::String, filename::String, mod::Module; mode::Symbol)
@ Revise ~/.julia/packages/Revise/WHZdV/src/parsing.jl:50
[63] parse_source!(mod_exprs_sigs::OrderedCollections.OrderedDict{Module, OrderedCollections.OrderedDict{Revise.RelocatableExpr, Union{Nothing, Vector{Any}}}}, filename::String, mod::Module; kwargs::Base.Pairs{Symbol, Symbol, Tuple{Symbol}, NamedTuple{(:mode,), Tuple{Symbol}}})
@ Revise ~/.julia/packages/Revise/WHZdV/src/parsing.jl:27
[64] #parse_source#11
@ ~/.julia/packages/Revise/WHZdV/src/parsing.jl:10 [inlined]
[65] track(mod::Module, file::String; mode::Symbol, kwargs::Base.Pairs{Symbol, Bool, Tuple{Symbol}, NamedTuple{(:skip_include,), Tuple{Bool}}})
@ Revise ~/.julia/packages/Revise/WHZdV/src/packagedef.jl:901
[66] includet(mod::Module, file::String)
@ Revise ~/.julia/packages/Revise/WHZdV/src/packagedef.jl:1001
[67] includet(file::String)
@ Revise ~/.julia/packages/Revise/WHZdV/src/packagedef.jl:1023
I also tried the following workaround to reach the goal, which failed as well.
Is it impossible to use `SavingCallback` together with DiffEqSensitivity?
using SimulationLogger
# using DifferentialEquations
# using DiffEqSensitivity
using DifferentialEquations
using Zygote
# using DiffEqBase
using DiffEqSensitivity
p = [2.0]

# Workaround attempt: record the state with a `SavingCallback` instead of the
# integrator interface, then sum the last saved state. Under Zygote this fails
# already while the callback is being constructed: its internal binary heap is
# filled with `setindex!`, which Zygote refuses to differentiate.
function main(p)
    # Out-of-place right-hand side: du/dt = -p[1] * u.
    rhs(u, p, t) = -p[1] * u
    # Saving function invoked by the callback; keeps a copy of the current state.
    log_func(u, t, integrator::DiffEqBase.DEIntegrator; kwargs...) = (; x=copy(u))
    prob = ODEProblem(rhs, [1, 2, 3.0], (0.0, 1.0), p)
    saved_values = SavedValues(Float64, NamedTuple)
    cb = SavingCallback(log_func, saved_values; saveat=0:0.01:1.0)
    solve(prob, Tsit5(); callback=cb, p=p)
    final_state = saved_values.saveval[end].x
    return sum(final_state)
end

@run main(p)
@show gradient(main, p)
- result
```julia
julia> include("test/log_func.jl")
ERROR: LoadError: Mutating arrays is not supported -- called setindex!(::Vector{Float64}, _...)
Stacktrace:
[1] error(s::String)
@ Base ./error.jl:33
[2] (::Zygote.var"#443#444"{Vector{Float64}})(#unused#::Nothing)
@ Zygote ~/.julia/packages/Zygote/bJn8I/src/lib/array.jl:71
[3] (::Zygote.var"#2342#back#445"{Zygote.var"#443#444"{Vector{Float64}}})(Δ::Nothing)
@ Zygote ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67
[4] Pullback
@ ~/.julia/packages/DataStructures/vSp4s/src/heaps/arrays_as_heaps.jl:29 [inlined]
[5] (::typeof(∂(percolate_down!)))(Δ::Nothing)
@ Zygote ~/.julia/packages/Zygote/bJn8I/src/compiler/interface2.jl:0
[6] Pullback
@ ~/.julia/packages/DataStructures/vSp4s/src/heaps/arrays_as_heaps.jl:32 [inlined]
[7] (::typeof(∂(percolate_down!)))(Δ::Nothing)
@ Zygote ~/.julia/packages/Zygote/bJn8I/src/compiler/interface2.jl:0
[8] Pullback
@ ~/.julia/packages/DataStructures/vSp4s/src/heaps/arrays_as_heaps.jl:32 [inlined]
[9] (::typeof(∂(percolate_down!)))(Δ::Nothing)
@ Zygote ~/.julia/packages/Zygote/bJn8I/src/compiler/interface2.jl:0
[10] Pullback
@ ~/.julia/packages/DataStructures/vSp4s/src/heaps/arrays_as_heaps.jl:86 [inlined]
[11] (::typeof(∂(heapify!)))(Δ::Nothing)
@ Zygote ~/.julia/packages/Zygote/bJn8I/src/compiler/interface2.jl:0
[12] Pullback
@ ~/.julia/packages/DataStructures/vSp4s/src/heaps/arrays_as_heaps.jl:116 [inlined]
[13] (::typeof(∂(heapify)))(Δ::Nothing)
@ Zygote ~/.julia/packages/Zygote/bJn8I/src/compiler/interface2.jl:0
[14] Pullback
@ ~/.julia/packages/DataStructures/vSp4s/src/heaps/binary_heap.jl:42 [inlined]
[15] (::typeof(∂(DataStructures.BinaryHeap{Float64})))(Δ::Nothing)
@ Zygote ~/.julia/packages/Zygote/bJn8I/src/compiler/interface2.jl:0
[16] Pullback
@ ~/.julia/packages/DataStructures/vSp4s/src/heaps/binary_heap.jl:51 [inlined]
[17] (::typeof(∂(DataStructures.BinaryMinHeap{Float64})))(Δ::Nothing)
@ Zygote ~/.julia/packages/Zygote/bJn8I/src/compiler/interface2.jl:0
[18] Pullback
@ ~/.julia/packages/DataStructures/vSp4s/src/heaps/binary_heap.jl:61 [inlined]
[19] (::typeof(∂(DataStructures.BinaryMinHeap)))(Δ::Nothing)
@ Zygote ~/.julia/packages/Zygote/bJn8I/src/compiler/interface2.jl:0
[20] Pullback
@ ~/.julia/packages/DiffEqCallbacks/YQhGJ/src/saving.jl:121 [inlined]
[21] (::typeof(∂(#SavingCallback#29)))(Δ::Nothing)
@ Zygote ~/.julia/packages/Zygote/bJn8I/src/compiler/interface2.jl:0
[22] Pullback
@ ~/.julia/packages/DiffEqCallbacks/YQhGJ/src/saving.jl:119 [inlined]
[23] (::typeof(∂(SavingCallback##kw)))(Δ::Nothing)
@ Zygote ~/.julia/packages/Zygote/bJn8I/src/compiler/interface2.jl:0
[24] Pullback
@ ~/.julia/dev/SimulationLogger/test/log_func.jl:25 [inlined]
[25] (::typeof(∂(main)))(Δ::Float64)
@ Zygote ~/.julia/packages/Zygote/bJn8I/src/compiler/interface2.jl:0
[26] (::Zygote.var"#57#58"{typeof(∂(main))})(Δ::Float64)
@ Zygote ~/.julia/packages/Zygote/bJn8I/src/compiler/interface.jl:41
[27] gradient(f::Function, args::Vector{Float64})
@ Zygote ~/.julia/packages/Zygote/bJn8I/src/compiler/interface.jl:76
[28] top-level scope
@ show.jl:1047
[29] include(fname::String)
@ Base.MainInclude ./client.jl:451
[30] top-level scope
@ REPL[1]:1
in expression starting at /Users/jinrae/.julia/dev/SimulationLogger/test/log_func.jl:39
Indeed Zygote will not work with the integrator interface. Instead one should use callbacks which would support anything of the integrator interface. This is an upstream issue with Zygote and not fixable with the current AD systems, and even if it was it would not be compatible with all adjoint methods. I'll leave this issue open as a marker.
Indeed Zygote will not work with the integrator interface. Instead one should use callbacks which would support anything of the integrator interface. This is an upstream issue with Zygote and not fixable with the current AD systems, and even if it was it would not be compatible with all adjoint methods. I'll leave this issue open as a marker.
Got it. Thank you, @ChrisRackauckas. Then my strategy should be using callbacks.
Is `SavingCallback` not compatible with any adjoint method (see the example above)?

SavingCallback and FunctionCallingCallback are going to be rough to work with in the differentiation system because they explicitly try not to save the outputs directly. I guess the question is: what are you trying to do with them, and why can't you do a solve and then modify the outputs? Is it a memory thing?
SavingCallback and FunctionCallingCallback are going to be rough to work with in the differentiation system because they explicitly try not to save the outputs directly. I guess the question is: what are you trying to do with them, and why can't you do a solve and then modify the outputs? Is it a memory thing?
OK, I need to clarify what I want to do. Sorry for the confusion.
When solving ODEs, I often need to record data that are functions of the DE parameter `p` and the state `u`.
For example, a control law produces a control command that is a function of time `t`, state `u`, and parameter `p` (namely, `cmd = control(u, p, t)`).
With this in mind, I wrote SimulationLogger.jl to define the RHS of an ODE and, at the same time, to indicate which data will be logged (I think this is often convenient for simulating hierarchical dynamical systems).
I've found that DiffEqFlux.jl is a promising package for combining (O)DE simulation with deep learning (especially for simulations involving control systems). So I want SimulationLogger.jl not only to "log data" but also to "differentiate" it. If that is possible, users will be able to construct hierarchical data at each timestamp and easily build a loss function with respect to (possibly parts of) the data.
So I'm trying to find a way to construct hierarchical data and differentiate it via DiffEqFlux.jl.
Because I lack background in how backpropagation is implemented, this may be a bit messy for you. Sorry about that :)
The code below is a sample that reflects what I want.
# Sketch of the desired usage; `solve(...)` and `control_law` are placeholders.

# `@Loggable` marks this right-hand side for logging (SimulationLogger.jl syntax).
@Loggable function dynamics!(du, u, p, t)
    ctrl = control_law(u, p, t)  # e.g., a neural network
    @log ctrl  # log the control input command
    @log u     # log the ODE state
    du .= ctrl
end

function predict()
    data = solve(...)  # with appropriate arguments
    return data
end

# Example objective: the sum of the control input and the state at the terminal
# time (no practical meaning intended).
function loss(p)
    pred = predict()
    # NOTE(review): `predict()` does not receive `p`, so this loss cannot depend
    # on `p` for differentiation — presumably `p` should be forwarded to `solve`.
    # Explicit `sum` calls: `|>` binds more loosely than `+`, so
    # `a |> sum + b |> sum` would parse as `(a |> (sum + b)) |> sum` and throw.
    return sum(pred[end].ctrl) + sum(pred[end].u)
end
One more thing: there was already an attempt to log hierarchical data with convenient macros via postprocessing, e.g., SimulationLogs.jl. It works very well with deterministic ODEs and is probably compatible with AD tools given a carefully written custom logging function. But I found that users have to log "stochastic data" in several cases (e.g., reinforcement learning). That's why I first used the callbacks exported from DiffEq.jl.
I'll have to leave this open for now. It's not impossible, but it's hard to do automatically. I think that the FunctionCallingCallback and SavingCallback functionality would instead have to be made a standard part of the solver interface. The reason is that what we effectively want to do here is to define the adjoint not of $du(t)/dp$, but instead of $dv(t)/dp$, where $v(t) = g(u(t), p, t)$. If you look at the index-1 DAE adjoint derivation in https://arxiv.org/abs/2001.04385, you can see that this $g$ function does exist in the derivation, so if you are defining gradients by hand (https://diffeq.sciml.ai/stable/analysis/sensitivity/#Adjoint-Sensitivity-Analysis-via-adjoint_sensitivities-(Backpropogation)) you can define this $g$ and make this happen. However, when relying on the ChainRules overloads to automatically capture `solve`, it is not possible to hijack the post-processing, so $g = \mathrm{identity}$ is taken as the definition and then plugged into the later steps of AD. This is seen here:
https://github.com/SciML/DiffEqSensitivity.jl/blob/v6.64.0/src/concrete_solve.jl#L217-L264
The way to interpret that is really just $g = \mathrm{identity}$; handling all of the cases of the pullback gets complicated, but that's really all it is. So what we would need is for the `SavingCallback` or `FunctionCallingCallback` to not be some arbitrary callback, but something that hooks into the system and changes the $g$ used by the adjoint definition to instead be the output the user is requesting. Given how the callbacks tie into the system, that's not possible right now.
What would need to happen is that something akin to these callbacks would need to be added as standard keyword argument functions which then get implemented by all of the integrators and then specialized in the adjoint definition. This is something that I want to do, but hopefully this explains why it might take a little bit and will require a much larger change to the DiffEq system. That said, it would fix another issue with those callbacks which is that they don't tie into the interpolation system, so it will be a nice thing when completed, but it's a bit of a ways off.
Thank you, @ChrisRackauckas. Great explanation!
It seems to be a much bigger problem than I expected.
Hi, my goal is described in SciML/DiffEqFlux.jl#661, and I tried to narrow the whole problem down to a small one to determine whether it's possible to backpropagate through the integrator interface.
The following example is in this commit, on branch `iss11` of SimulationLogger.jl. I found that `gradient` from Zygote.jl does not work with the integrator constructed by `init`, while we can take the gradient if the output is changed to `prob.p` (where `prob |> typeof <: ODEProblem`). What would be the best workaround for this case? Or, if this is a new feature, I would like to request it for more flexible usage of DiffEqFlux.jl.
p = 2.0 function main(p) @Loggable function dynamics!(dx, x, p, t) @log x dx .= -p*x end
if hasmethod(dynamics!, Tuple{Any, Any, Any, Any, __LOG_INDICATOR__})
end @run main(p) @show gradient(main, p)