SciML / DifferentialEquations.jl

Multi-language suite for high-performance solvers of differential equations and scientific machine learning (SciML) components. Ordinary differential equations (ODEs), stochastic differential equations (SDEs), delay differential equations (DDEs), differential-algebraic equations (DAEs), and more in Julia.
https://docs.sciml.ai/DiffEqDocs/stable/

Zygote AD of symbol-indexing into ODESolution #746

Open bgctw opened 3 years ago

bgctw commented 3 years ago

ModelingToolkit allows indexing into an ODESolution via a symbol. However, this currently causes problems with Optimization when using Zygote gradients.

I tried working on the issue, but I need to learn more and need guidance with AD and DifferentialEquations. There is a related Discourse topic and an issue at ModelingToolkit.

The following example demonstrates the issue.

using ModelingToolkit, OrdinaryDiffEq
using DiffEqBase

@parameters α β δ γ
@variables t x(t) y(t) dx(t)
D = Differential(t)
eqs = [
  dx ~ α*x - β*x*y,  # testing observed variables
  D(x) ~ dx,
  D(y) ~ -δ*y + γ*x*y
]
@named lv = ODESystem(eqs)
syss = structural_simplify(lv) 
parms = [α => 1.5, β => 1.0, δ => 3.0, γ => 1.0]
x0 = [x => 1.0, y => 1.0]
tsteps = 0.0:0.1:10.0
prob = ODEProblem(syss, x0, extrema(tsteps), parms, jac = true)
soltrue = solve(prob,  Tsit5(), saveat = tsteps);
popt0 = [1.1]

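using ChainRulesCore
using SciMLBase # provides issymbollike and sym_to_index
function ChainRulesCore.rrule(::typeof(getindex), VA::ODESolution, sym) 
  function ODESolution_getindex_pullback(Δ)
    @show Δ
    @show length(VA)
    @show VA
    @show VA.u
    # convert symbol to state index
    i = SciMLBase.issymbollike(sym) ? SciMLBase.sym_to_index(sym, VA) : sym
    @show i
    # similar to VectorOfArray: Δ[j] goes to component i of u[j], zero elsewhere
    # TODO: care for observed (i === nothing)
    Δ′ = [[k == i ? Δ[j] : zero(x[k]) for k in eachindex(x)] for (j, x) in enumerate(VA.u)]
    # one tangent per argument: getindex itself, VA, and sym
    (NoTangent(), Δ′, NoTangent())
  end  
  # return the pullback function itself; Δ is not defined at this point
  VA[sym], ODESolution_getindex_pullback
end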

f1(p) = soltrue[x][1] * p[1] # note the indexing by [x]
f1(popt0)
using Zygote
gr = Zygote.gradient(f1, popt0) # calls the failing rule for VectorOfArrays instead of above rule
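The cotangent this rule has to produce can be checked without ModelingToolkit or Zygote. Below is a minimal Base-Julia sketch; `ToySolution`, `getsym`, and `getsym_pullback` are hypothetical stand-ins, not the ODESolution or ChainRulesCore API. Because `sol[sym]` returns one value per saved time point, the incoming Δ is itself a time series, and Δ[j] belongs in component i of u[j].

```julia
# Hypothetical stand-in for an ODESolution: u[j] is the state vector at time j,
# and syms names the state components.
struct ToySolution
    u::Vector{Vector{Float64}}
    syms::Vector{Symbol}
end

# Symbol -> integer index: pure bookkeeping, nothing to differentiate here.
symindex(sol::ToySolution, sym::Symbol) = findfirst(==(sym), sol.syms)

# Forward pass: sol[sym] returns the time series of one state component.
getsym(sol::ToySolution, sym::Symbol) = [uj[symindex(sol, sym)] for uj in sol.u]

# Pullback: the cotangent for u has the same shape as u; Δ[j] goes into
# component i of u[j] and every other entry is zero.
function getsym_pullback(sol::ToySolution, sym::Symbol, Δ::AbstractVector)
    i = symindex(sol, sym)
    return [[k == i ? Δ[j] : zero(uj[k]) for k in eachindex(uj)]
            for (j, uj) in enumerate(sol.u)]
end

sol = ToySolution([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], [:x, :y])
getsym(sol, :y)                              # [2.0, 4.0, 6.0]
getsym_pullback(sol, :y, [1.0, 2.0, 3.0])    # [[0.0, 1.0], [0.0, 2.0], [0.0, 3.0]]
```

Comparing the state index i against the time index j would mix up the two axes; the inner comprehension runs over state components, the outer one over time points.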
ChrisRackauckas commented 3 years ago

@YingboMa @DhairyaLGandhi , I think we will need a pretty specific overload here?

DhairyaLGandhi commented 3 years ago

literal_getindex with symbol possibly?

ChrisRackauckas commented 3 years ago

No, this is the relevant code: https://github.com/SciML/SciMLBase.jl/blob/master/src/solutions/solution_interface.jl#L37-L53

If it's a known symbol, then it's essentially @nograd: just a translation to an index.

If it's not a known symbol, a function is called that essentially fakes indexing. I assume that part would have to be differentiated?
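To make the two cases concrete, here is a sketch in plain Julia. It assumes a hypothetical `lookup` function, a `syms` list of states, and a `Dict` of observed functions of `(u, p, t)`; it mimics the described behavior and is not SciMLBase's actual code. The `findfirst` translation is the part that could be `@nograd`; in the observed branch the value depends on `u` and `p` through a function, which is what a gradient would have to go through.

```julia
# Hypothetical sketch of symbol lookup on a single time slice.
function lookup(u, p, t, syms, observed, sym::Symbol)
    i = findfirst(==(sym), syms)
    if i !== nothing
        return u[i]                     # known state: plain integer indexing
    else
        return observed[sym](u, p, t)   # observed variable: calls a function
    end
end

# dx = α*x - β*x*y from the example, assuming u = [x, y] and p = [α, β, δ, γ]
obs_dx(u, p, t) = p[1] * u[1] - p[2] * u[1] * u[2]

syms = [:x, :y]
observed = Dict(:dx => obs_dx)
u, p = [1.0, 2.0], [1.5, 1.0, 3.0, 1.0]

lookup(u, p, 0.0, syms, observed, :y)    # 2.0: gradient w.r.t. u is one-hot
lookup(u, p, 0.0, syms, observed, :dx)   # -0.5: depends on both u and p
```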

DhairyaLGandhi commented 3 years ago

Okay, if it's calling into a function, then I understand needing to differentiate it. Thanks for the link!

ChrisRackauckas commented 3 years ago

I think we only need to differentiate it half of the time though? It might be fixed by a few more @nograds on symbol handling stuff.

DhairyaLGandhi commented 3 years ago

Right, so we would basically want to teach Zygote which symbols it needs to ignore and only differentiate when needed.

lamorton commented 3 years ago

Now that RecursiveArrayTools#151 has landed, the traceback is different, but the problem persists. Instead of going into the Zygote rule from RecursiveArrayTools, the indexing is now handled (incorrectly) by Zygote itself.

julia> gr = Zygote.gradient(f1, popt0) 
ERROR: ArgumentError: invalid index: x(t) of type Num
Stacktrace:
  [1] to_index(i::Num)
    @ Base ./indices.jl:300
  [2] to_index(A::Matrix{Float64}, i::Num)
    @ Base ./indices.jl:277
  [3] to_indices
    @ ./indices.jl:333 [inlined]
  [4] to_indices
    @ ./indices.jl:325 [inlined]
  [5] view
    @ ./subarray.jl:176 [inlined]
  [6] (::Zygote.var"#408#410"{2, Float64, ODESolution{Float64, 2, Vector{Vector{Float64}}, …}, Tuple{Num}})(dy::Zygote.OneElement{Float64, 1, Tuple{Int64}, Tuple{Base.OneTo{Int64}}})
    @ Zygote ~/.julia/packages/Zygote/0da6K/src/lib/array.jl:43
  [7] (::Zygote.var"#2248#back#404"{Zygote.var"#408#410"{2, Float64, ODESolution{Float64, 2, Vector{Vector{Float64}}, …}, Tuple{Num}}})(Δ::Zygote.OneElement{Float64, 1, Tuple{Int64}, Tuple{Base.OneTo{Int64}}})
    @ Zygote ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59
  [8] Pullback
    @ ./REPL[38]:1 [inlined]
  [9] (::typeof(∂(f1)))(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/0da6K/src/compiler/interface2.jl:0
 [10] (::Zygote.var"#46#47"{typeof(∂(f1))})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/0da6K/src/compiler/interface.jl:41
 [11] gradient(f::Function, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/0da6K/src/compiler/interface.jl:59
 [12] top-level scope
    @ REPL[41]:1
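The error itself is plain Base behavior. Judging from frames [5] and [6], Zygote's generic getindex adjoint writes Δ into a `view` of a zero array using the original indices. This sketch assumes that and uses a `Symbol` in place of the `Num` index to show that a non-integer index fails in `to_index` the same way.

```julia
# A zero "gradient" array shaped like the parent (a Matrix, as in frame [2]).
dx = zeros(2, 3)

# Base indexing accepts integers, ranges, colons, etc., but not symbols;
# a Symbol stands in here for the symbolic x(t) (a Num).
err = try
    view(dx, :x)
    nothing
catch e
    e
end

err isa ArgumentError   # true
```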
lamorton commented 3 years ago

These may be relevant to the proposed fix: https://github.com/JuliaDiff/ChainRulesCore.jl/issues/239 and https://github.com/FluxML/Zygote.jl/issues/811

ChrisRackauckas commented 3 years ago

So is the fix just to overload https://github.com/SciML/RecursiveArrayTools.jl/pull/151/files#diff-507801ba9a07a606d6519898b2d4da0592747f44c3940056176b4e03f088c5aeR46-R54 ?

lamorton commented 3 years ago

Yup. I also figured out the issue with @bgctw's prototype: you need to explicitly use Base.getindex here:

function ChainRulesCore.rrule(::typeof(Base.getindex), VA::ODESolution, sym) 
lamorton commented 3 years ago

It looks like handling the observed variables is tricky, because they implicitly depend on the equations that relate them to the solution array.

lamorton commented 3 years ago

Here's a partial solution for the states, but not for the observed variables. The array construction is a bit of a kludge.

using SciMLBase # ModelingToolkit doesn't reexport issymbollike
using Zygote, ZygoteRules
ZygoteRules.@adjoint function Base.getindex(VA::ODESolution, sym::Num) 
  function ODESolution_getindex_pullback(Δ)
    # convert symbol to index
    i = SciMLBase.issymbollike(sym) ? SciMLBase.sym_to_index(sym, VA) : sym
    if i === nothing
      # observed variable; TODO: Zygote.pullback through the observed function
      error("gradient of observed symbol is not defined yet")
    else
      # similar to VectorOfArray: return zero for non-matching indices
      Δ′ = [ [i == k ? Δ[j] : zero(x[1]) for k in 1:length(x)] for (x,j) in zip(VA.u, 1:length(VA))]
      (Δ′, nothing)
    end
  end
  VA[sym], ODESolution_getindex_pullback
end
ChrisRackauckas commented 3 years ago

Yeah I think it needs to pass Δ into the pullback of the observed function itself?

lamorton commented 3 years ago

Ahh, that's probably right. I guess I also need a dummy ODESolution just to hang the derivatives for sol.prob.p on?

ChrisRackauckas commented 3 years ago

yeah I think so.
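For a single time slice, passing Δ into the observed function's pullback amounts to the chain rule. Below is a hand-written sketch for the observed variable `dx ~ α*x - β*x*y` from the example, assuming `u = [x, y]` and `p = [α, β, δ, γ]` (the real parameter ordering after `structural_simplify` may differ); `observed_dx` and `observed_dx_pullback` are hypothetical names. `du` would go into slot j of the solution's `u`, while `dp` is the piece that needs somewhere to live, such as the `prob.p` of a dummy ODESolution.

```julia
# Hypothetical observed function: dx = α*x - β*x*y
observed_dx(u, p, t) = p[1] * u[1] - p[2] * u[1] * u[2]

# Hand-written pullback: scale the partial derivatives by the incoming
# cotangent Δ. t is treated as non-differentiable here.
function observed_dx_pullback(u, p, t, Δ)
    du = [Δ * (p[1] - p[2] * u[2]), -Δ * p[2] * u[1]]      # ∂/∂x, ∂/∂y
    dp = [Δ * u[1], -Δ * u[1] * u[2], zero(Δ), zero(Δ)]    # ∂/∂α, ∂/∂β, ∂/∂δ, ∂/∂γ
    return du, dp
end

u, p = [1.0, 2.0], [1.5, 1.0, 3.0, 1.0]
du, dp = observed_dx_pullback(u, p, 0.0, 1.0)   # du = [-0.5, -1.0], dp = [1.0, -2.0, 0.0, 0.0]

# Finite-difference check of the first entry of du
h = 1e-6
fd = (observed_dx(u .+ [h, 0.0], p, 0.0) - observed_dx(u, p, 0.0)) / h
```

In the prototype, `Zygote.pullback(getter, sym, u, p, t)` would derive these partials instead of writing them by hand.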

lamorton commented 3 years ago

Here's a functioning prototype for the simpler case where a single timeslice is chosen as well:

ZygoteRules.@adjoint function Base.getindex(VA::ODESolution, sym::Num,j::Int) 
  function ODESolution_getindex_pullback(Δ)
    # convert symbol to index
    i = SciMLBase.issymbollike(sym) ? SciMLBase.sym_to_index(sym, VA) : sym
    # similar to VectorOfArray: return zero for non-matching indices
    if i === nothing
      getter = SciMLBase.getobserved(VA)
      grz = Zygote.pullback(getter,sym,VA.u[j],VA.prob.p,VA.t[j])[2](Δ)
      du = [k == j ? grz[2] : zero(VA.u[1]) for k in 1:length(VA.u)] 
      dp = grz[3] # pullback for p
      dprob = remake(VA.prob,p=dp)
      T = eltype(eltype(VA.u))
      N = ndims(VA) # array dimension of the solution, not the number of parameters
      Δ′ = ODESolution{T,N,typeof(du),Nothing,Nothing,typeof(VA.t),typeof(VA.k), typeof(dprob),typeof(VA.alg),typeof(VA.interp),typeof(VA.destats)}(du,nothing,nothing,VA.t,VA.k,dprob,VA.alg,VA.interp,VA.dense,0,VA.destats,VA.retcode)
      (Δ′,nothing,nothing)
    else
      Δ′ = [ m == j ? [i == k ? Δ : zero(VA.u[1][1]) for k in 1:length(VA.u[1])] : zero(VA.u[1]) for m in 1:length(VA.u)]
      (Δ′,nothing,nothing)
    end
  end
  VA[sym,j], ODESolution_getindex_pullback
end

I'm stuck on how to treat the derivatives with respect to the parameters in the general case: they pick up an extra dimension because the input Δ is an array over time.
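A possible way to handle the extra dimension: `p` is shared by every time point, so the per-slice parameter cotangents add up to a single vector shaped like `p`. A sketch assuming a hypothetical `observed_dx` for `dx ~ α*x - β*x*y` with `u = [x, y]` and `p = [α, β, δ, γ]`. It covers only the explicit dependence of the observed function on `p`; the dependence of `u` itself on `p` would have to come back through the solver's own adjoint via the `u` cotangents.

```julia
# Hypothetical observed function: dx = α*x - β*x*y
observed_dx(u, p, t) = p[1] * u[1] - p[2] * u[1] * u[2]

# Cotangent of observed_dx with respect to p at one time point.
dp_slice(u, p, Δj) = [Δj * u[1], -Δj * u[1] * u[2], 0.0, 0.0]

# Sum the per-slice contributions instead of stacking them over time.
function observed_series_dp(us, p, Δ)
    dp = zeros(length(p))
    for (uj, Δj) in zip(us, Δ)
        dp .+= dp_slice(uj, p, Δj)
    end
    return dp
end

us = [[1.0, 2.0], [2.0, 1.0], [0.5, 0.5]]   # states at three saved time points
p = [1.5, 1.0, 3.0, 1.0]
Δ = [1.0, 1.0, 2.0]                          # cotangent of sol[dx], one entry per time
dp = observed_series_dp(us, p, Δ)            # [4.0, -4.5, 0.0, 0.0]

# Finite-difference check with respect to α for the scalar sum(Δ .* sol[dx])
loss(p) = sum(Δ .* [observed_dx(uj, p, 0.0) for uj in us])
h = 1e-6
fd_α = (loss(p .+ [h, 0.0, 0.0, 0.0]) - loss(p)) / h
```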

ChrisRackauckas commented 3 years ago

Let's break this problem down into steps. Could you PR what you have and throw an error in the unhandled case? Then it can continue to improve.