bgctw closed this 3 years ago
I just hit a related/same error with saveat.
using ModelingToolkit, OrdinaryDiffEq, Plots
@parameters t τ P
@variables E(t) η(t)
D = Differential(t)
eqs = [η ~ tanh(E), #Algebraic
D(E) ~ P*η - E/τ, #Differential
]
sys = ODESystem(eqs)
sys_simp = structural_simplify(sys)
u0 = [E => 0.3, ]
p = [P => 10.0, τ => 0.1]
tspan = (0.0,1.0)
prob = ODEProblem(sys_simp,u0,tspan,p,jac=true)
solat = solve(prob,Tsit5(),saveat=0.01)
#This works either way:
solat[η]
#This fails with the saveat
plot(solat,vars=[η])
It's weird b/c both solutions have the same type:
sol = solve(prob,Tsit5())
typeof(sol) == typeof(solat)
Error message from the plot line:
ERROR: MethodError: no method matching tanh(::Array{Float64,1})
Closest candidates are:
tanh(::Float16) at math.jl:1144
tanh(::BigFloat) at mpfr.jl:603
tanh(::Missing) at math.jl:1197
...
Stacktrace:
[1] macro expansion at /Users/lmorton/.julia/packages/SymbolicUtils/9iQGH/src/code.jl:283 [inlined]
[2] macro expansion at /Users/lmorton/.julia/packages/RuntimeGeneratedFunctions/3SZ1T/src/RuntimeGeneratedFunctions.jl:124 [inlined]
[3] macro expansion at ./none:0 [inlined]
[4] generated_callfunc at ./none:0 [inlined]
[5] (::RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(Symbol("##arg#814"), Symbol("##arg#815"), :t),ModelingToolkit.var"#_RGF_ModTag",ModelingToolkit.var"#_RGF_ModTag",(0xa4d463e3, 0x63e11be2, 0xa80a5ca2, 0x9939d0b5, 0xbc0a3778)})(::SubArray{Array{Float64,1},1,Array{Array{Float64,1},2},Tuple{Base.Slice{Base.OneTo{Int64}},Int64},true}, ::Array{Float64,1}, ::Float64) at /Users/lmorton/.julia/packages/RuntimeGeneratedFunctions/3SZ1T/src/RuntimeGeneratedFunctions.jl:112
[6] (::ModelingToolkit.var"#135#generated_observed#158"{Bool,ODESystem,Dict{Any,Any}})(::Num, ::SubArray{Array{Float64,1},1,Array{Array{Float64,1},2},Tuple{Base.Slice{Base.OneTo{Int64}},Int64},true}, ::Array{Float64,1}, ::Float64) at /Users/lmorton/.julia/packages/ModelingToolkit/DOKSJ/src/systems/diffeqs/abstractodesystem.jl:213
[7] _broadcast_getindex_evalf at ./broadcast.jl:648 [inlined]
[8] _broadcast_getindex at ./broadcast.jl:621 [inlined]
[9] getindex at ./broadcast.jl:575 [inlined]
[10] copy(::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Tuple{Base.OneTo{Int64}},ModelingToolkit.var"#135#generated_observed#158"{Bool,ODESystem,Dict{Any,Any}},Tuple{Tuple{Num},Array{SubArray{Array{Float64,1},1,Array{Array{Float64,1},2},Tuple{Base.Slice{Base.OneTo{Int64}},Int64},true},1},Tuple{Array{Float64,1}},Array{Float64,1}}}) at ./broadcast.jl:876
[11] materialize at ./broadcast.jl:837 [inlined]
[12] u_n(::Array{Array{Float64,1},1}, ::Num, ::ODESolution{Float64,2,Array{Array{Float64,1},1},Nothing,Nothing,Array{Float64,1},Array{Array{Array{Float64,1},1},1},ODEProblem{Array{Float64,1},Tuple{Float64,Float64},true,Array{Float64,1},ODEFunction{true,ModelingToolkit.var"#f#151"{RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(Symbol("##arg#778"), Symbol("##arg#779"), :t),ModelingToolkit.var"#_RGF_ModTag",ModelingToolkit.var"#_RGF_ModTag",(0xee580426, 0x65e1a321, 0x4af8a7bc, 0xd22398a2, 0x8ce03fe0)},RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(Symbol("##out#780"), Symbol("##arg#778"), Symbol("##arg#779"), :t),ModelingToolkit.var"#_RGF_ModTag",ModelingToolkit.var"#_RGF_ModTag",(0xa36e7a51, 0x79c3d7e7, 0x45eeda34, 0x3e16c63c, 0xd9bd7cec)}},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,ModelingToolkit.var"#_jac#155"{RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(Symbol("##arg#781"), Symbol("##arg#782"), :t),ModelingToolkit.var"#_RGF_ModTag",ModelingToolkit.var"#_RGF_ModTag",(0x0bc4f15d, 0x292087c0, 0xe0782571, 0x4a2d3328, 0x56cd814c)},RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(Symbol("##out#783"), Symbol("##arg#781"), Symbol("##arg#782"), :t),ModelingToolkit.var"#_RGF_ModTag",ModelingToolkit.var"#_RGF_ModTag",(0xd43242d8, 0x38ee3feb, 0x71673919, 0x67d7d6b0, 0x9bc144db)}},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Array{Symbol,1},Symbol,ModelingToolkit.var"#135#generated_observed#158"{Bool,ODESystem,Dict{Any,Any}},Nothing},Base.Iterators.Pairs{Symbol,Bool,Tuple{Symbol},NamedTuple{(:jac,),Tuple{Bool}}},SciMLBase.StandardODEProblem},Tsit5,OrdinaryDiffEq.InterpolationData{ODEFunction{true,ModelingToolkit.var"#f#151"{RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(Symbol("##arg#778"), Symbol("##arg#779"), :t),ModelingToolkit.var"#_RGF_ModTag",ModelingToolkit.var"#_RGF_ModTag",(0xee580426, 0x65e1a321, 0x4af8a7bc, 0xd22398a2, 0x8ce03fe0)},RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(Symbol("##out#780"), 
Symbol("##arg#778"), Symbol("##arg#779"), :t),ModelingToolkit.var"#_RGF_ModTag",ModelingToolkit.var"#_RGF_ModTag",(0xa36e7a51, 0x79c3d7e7, 0x45eeda34, 0x3e16c63c, 0xd9bd7cec)}},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,ModelingToolkit.var"#_jac#155"{RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(Symbol("##arg#781"), Symbol("##arg#782"), :t),ModelingToolkit.var"#_RGF_ModTag",ModelingToolkit.var"#_RGF_ModTag",(0x0bc4f15d, 0x292087c0, 0xe0782571, 0x4a2d3328, 0x56cd814c)},RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(Symbol("##out#783"), Symbol("##arg#781"), Symbol("##arg#782"), :t),ModelingToolkit.var"#_RGF_ModTag",ModelingToolkit.var"#_RGF_ModTag",(0xd43242d8, 0x38ee3feb, 0x71673919, 0x67d7d6b0, 0x9bc144db)}},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Array{Symbol,1},Symbol,ModelingToolkit.var"#135#generated_observed#158"{Bool,ODESystem,Dict{Any,Any}},Nothing},Array{Array{Float64,1},1},Array{Float64,1},Array{Array{Array{Float64,1},1},1},OrdinaryDiffEq.Tsit5Cache{Array{Float64,1},Array{Float64,1},Array{Float64,1},OrdinaryDiffEq.Tsit5ConstantCache{Float64,Float64}}},DiffEqBase.DEStats}, ::Array{Float64,1}, ::Array{Array{Float64,1},1}) at /Users/lmorton/.julia/packages/SciMLBase/XuLdB/src/solutions/solution_interface.jl:569
[13] solplot_vecs_and_labels(::Int64, ::Array{Tuple,1}, ::Array{Array{Float64,1},1}, ::Array{Float64,1}, ::ODESolution{Float64,2,Array{Array{Float64,1},1},Nothing,Nothing,Array{Float64,1},Array{Array{Array{Float64,1},1},1},ODEProblem{Array{Float64,1},Tuple{Float64,Float64},true,Array{Float64,1},ODEFunction{true,ModelingToolkit.var"#f#151"{RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(Symbol("##arg#778"), Symbol("##arg#779"), :t),ModelingToolkit.var"#_RGF_ModTag",ModelingToolkit.var"#_RGF_ModTag",(0xee580426, 0x65e1a321, 0x4af8a7bc, 0xd22398a2, 0x8ce03fe0)},RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(Symbol("##out#780"), Symbol("##arg#778"), Symbol("##arg#779"), :t),ModelingToolkit.var"#_RGF_ModTag",ModelingToolkit.var"#_RGF_ModTag",(0xa36e7a51, 0x79c3d7e7, 0x45eeda34, 0x3e16c63c, 0xd9bd7cec)}},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,ModelingToolkit.var"#_jac#155"{RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(Symbol("##arg#781"), Symbol("##arg#782"), :t),ModelingToolkit.var"#_RGF_ModTag",ModelingToolkit.var"#_RGF_ModTag",(0x0bc4f15d, 0x292087c0, 0xe0782571, 0x4a2d3328, 0x56cd814c)},RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(Symbol("##out#783"), Symbol("##arg#781"), Symbol("##arg#782"), :t),ModelingToolkit.var"#_RGF_ModTag",ModelingToolkit.var"#_RGF_ModTag",(0xd43242d8, 0x38ee3feb, 0x71673919, 0x67d7d6b0, 0x9bc144db)}},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Array{Symbol,1},Symbol,ModelingToolkit.var"#135#generated_observed#158"{Bool,ODESystem,Dict{Any,Any}},Nothing},Base.Iterators.Pairs{Symbol,Bool,Tuple{Symbol},NamedTuple{(:jac,),Tuple{Bool}}},SciMLBase.StandardODEProblem},Tsit5,OrdinaryDiffEq.InterpolationData{ODEFunction{true,ModelingToolkit.var"#f#151"{RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(Symbol("##arg#778"), Symbol("##arg#779"), :t),ModelingToolkit.var"#_RGF_ModTag",ModelingToolkit.var"#_RGF_ModTag",(0xee580426, 0x65e1a321, 0x4af8a7bc, 0xd22398a2, 
0x8ce03fe0)},RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(Symbol("##out#780"), Symbol("##arg#778"), Symbol("##arg#779"), :t),ModelingToolkit.var"#_RGF_ModTag",ModelingToolkit.var"#_RGF_ModTag",(0xa36e7a51, 0x79c3d7e7, 0x45eeda34, 0x3e16c63c, 0xd9bd7cec)}},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,ModelingToolkit.var"#_jac#155"{RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(Symbol("##arg#781"), Symbol("##arg#782"), :t),ModelingToolkit.var"#_RGF_ModTag",ModelingToolkit.var"#_RGF_ModTag",(0x0bc4f15d, 0x292087c0, 0xe0782571, 0x4a2d3328, 0x56cd814c)},RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(Symbol("##out#783"), Symbol("##arg#781"), Symbol("##arg#782"), :t),ModelingToolkit.var"#_RGF_ModTag",ModelingToolkit.var"#_RGF_ModTag",(0xd43242d8, 0x38ee3feb, 0x71673919, 0x67d7d6b0, 0x9bc144db)}},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Array{Symbol,1},Symbol,ModelingToolkit.var"#135#generated_observed#158"{Bool,ODESystem,Dict{Any,Any}},Nothing},Array{Array{Float64,1},1},Array{Float64,1},Array{Array{Array{Float64,1},1},1},OrdinaryDiffEq.Tsit5Cache{Array{Float64,1},Array{Float64,1},Array{Float64,1},OrdinaryDiffEq.Tsit5ConstantCache{Float64,Float64}}},DiffEqBase.DEStats}, ::Bool, ::Nothing, ::Array{String,1}) at /Users/lmorton/.julia/packages/SciMLBase/XuLdB/src/solutions/solution_interface.jl:579
[14] diffeq_to_arrays(::ODESolution{Float64,2,Array{Array{Float64,1},1},Nothing,Nothing,Array{Float64,1},Array{Array{Array{Float64,1},1},1},ODEProblem{Array{Float64,1},Tuple{Float64,Float64},true,Array{Float64,1},ODEFunction{true,ModelingToolkit.var"#f#151"{RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(Symbol("##arg#778"), Symbol("##arg#779"), :t),ModelingToolkit.var"#_RGF_ModTag",ModelingToolkit.var"#_RGF_ModTag",(0xee580426, 0x65e1a321, 0x4af8a7bc, 0xd22398a2, 0x8ce03fe0)},RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(Symbol("##out#780"), Symbol("##arg#778"), Symbol("##arg#779"), :t),ModelingToolkit.var"#_RGF_ModTag",ModelingToolkit.var"#_RGF_ModTag",(0xa36e7a51, 0x79c3d7e7, 0x45eeda34, 0x3e16c63c, 0xd9bd7cec)}},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,ModelingToolkit.var"#_jac#155"{RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(Symbol("##arg#781"), Symbol("##arg#782"), :t),ModelingToolkit.var"#_RGF_ModTag",ModelingToolkit.var"#_RGF_ModTag",(0x0bc4f15d, 0x292087c0, 0xe0782571, 0x4a2d3328, 0x56cd814c)},RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(Symbol("##out#783"), Symbol("##arg#781"), Symbol("##arg#782"), :t),ModelingToolkit.var"#_RGF_ModTag",ModelingToolkit.var"#_RGF_ModTag",(0xd43242d8, 0x38ee3feb, 0x71673919, 0x67d7d6b0, 0x9bc144db)}},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Array{Symbol,1},Symbol,ModelingToolkit.var"#135#generated_observed#158"{Bool,ODESystem,Dict{Any,Any}},Nothing},Base.Iterators.Pairs{Symbol,Bool,Tuple{Symbol},NamedTuple{(:jac,),Tuple{Bool}}},SciMLBase.StandardODEProblem},Tsit5,OrdinaryDiffEq.InterpolationData{ODEFunction{true,ModelingToolkit.var"#f#151"{RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(Symbol("##arg#778"), Symbol("##arg#779"), :t),ModelingToolkit.var"#_RGF_ModTag",ModelingToolkit.var"#_RGF_ModTag",(0xee580426, 0x65e1a321, 0x4af8a7bc, 0xd22398a2, 0x8ce03fe0)},RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(Symbol("##out#780"), Symbol("##arg#778"), Symbol("##arg#779"), 
:t),ModelingToolkit.var"#_RGF_ModTag",ModelingToolkit.var"#_RGF_ModTag",(0xa36e7a51, 0x79c3d7e7, 0x45eeda34, 0x3e16c63c, 0xd9bd7cec)}},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,ModelingToolkit.var"#_jac#155"{RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(Symbol("##arg#781"), Symbol("##arg#782"), :t),ModelingToolkit.var"#_RGF_ModTag",ModelingToolkit.var"#_RGF_ModTag",(0x0bc4f15d, 0x292087c0, 0xe0782571, 0x4a2d3328, 0x56cd814c)},RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(Symbol("##out#783"), Symbol("##arg#781"), Symbol("##arg#782"), :t),ModelingToolkit.var"#_RGF_ModTag",ModelingToolkit.var"#_RGF_ModTag",(0xd43242d8, 0x38ee3feb, 0x71673919, 0x67d7d6b0, 0x9bc144db)}},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Array{Symbol,1},Symbol,ModelingToolkit.var"#135#generated_observed#158"{Bool,ODESystem,Dict{Any,Any}},Nothing},Array{Array{Float64,1},1},Array{Float64,1},Array{Array{Array{Float64,1},1},1},OrdinaryDiffEq.Tsit5Cache{Array{Float64,1},Array{Float64,1},Array{Float64,1},OrdinaryDiffEq.Tsit5ConstantCache{Float64,Float64}}},DiffEqBase.DEStats}, ::Bool, ::Bool, ::Int64, ::Nothing, ::Float64, ::Array{Num,1}, ::Array{Tuple,1}, ::Symbol, ::Array{String,1}) at /Users/lmorton/.julia/packages/SciMLBase/XuLdB/src/solutions/solution_interface.jl:373
[15] macro expansion at /Users/lmorton/.julia/packages/SciMLBase/XuLdB/src/solutions/solution_interface.jl:161 [inlined]
[16] apply_recipe(::AbstractDict{Symbol,Any}, ::SciMLBase.AbstractTimeseriesSolution) at /Users/lmorton/.julia/packages/RecipesBase/92zOw/src/RecipesBase.jl:282
[17] _process_userrecipes!(::Any, ::Any, ::Any) at /Users/lmorton/.julia/packages/RecipesPipeline/CirY4/src/user_recipe.jl:36
[18] recipe_pipeline!(::Any, ::Any, ::Any) at /Users/lmorton/.julia/packages/RecipesPipeline/CirY4/src/RecipesPipeline.jl:70
[19] _plot!(::Plots.Plot, ::Any, ::Any) at /Users/lmorton/.julia/packages/Plots/g581z/src/plot.jl:172
[20] plot(::Any; kw::Any) at /Users/lmorton/.julia/packages/Plots/g581z/src/plot.jl:58
[21] top-level scope at REPL[167]:1
In Julia 1.5.3:
[23fbe1c1] Latexify v0.15.5
[961ee093] ModelingToolkit v5.17.0
[8913a72c] NonlinearSolve v0.3.8
[1dea7af3] OrdinaryDiffEq v5.53.2
[91a5bcdd] Plots v1.15.0
The problem is with the destructuring of the solution vector at a single timeslice. If I add a second state to the system, I get “About to run: (tanh)([0.3, 10.0])” from the debugger. The two elements correspond to the two states.
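To make the failure concrete, here is a plain-Julia sketch (no ModelingToolkit involved) of what seems to go wrong. `obs` is a hypothetical stand-in for the generated observed function for η = tanh(E). The column view below has the same kind of type as the first argument in frame [5] of the stacktrace, a SubArray over a matrix whose elements are state vectors, so `u[1]` is a whole vector rather than a number:

```julia
# Hypothetical stand-in for the generated observed function: η = tanh(E), with E = u[1].
obs(u, p, t) = tanh(u[1])

# Saved states as a vector of state vectors, the shape of `sol.u[start_idx:end_idx]`.
us = [[0.3], [0.2991349301091002]]

# Correct use: one call per time point, each receiving a state vector of Float64.
good = [obs(u, nothing, 0.0) for u in us]

# Broken use: view the vector of vectors as a 1×N matrix and pass a column view.
# The column's first element is a Vector{Float64}, so obs ends up calling tanh on a vector.
M = reshape(us, 1, :)    # Matrix whose elements are Vector{Float64}
col = view(M, :, 1)      # SubArray over that matrix, like frame [5]

broken_err = try
    obs(col, nothing, 0.0)
    nothing
catch err
    err
end

showerror(stdout, broken_err)   # a MethodError for tanh applied to a vector
println()
```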
I stepped through execution for both the working & broken cases. The first time I ran into a difference was at the call to solplot_vecs_and_labels inside diffeq_to_arrays.
In the broken case (using the saveat), the plot_timeseries argument is a simple array:
plot_timeseries: 101-element Array{Array{Float64,1},1}
[0.3]
[0.2991349301091002]
[0.29827713073682965]
[0.29742650234546386]
[0.2965829470988183]
⋮
[0.23974233351232516]
[0.23929456829531304]
[0.23884926734057982]
[0.23840640927145948]
In the case that works (w/o saveat):
>typeof(plot_timeseries)
DiffEqArray{Float64,2,Array{Array{Float64,1},1},Array{Float64,1}}
>plot_timeseries
t: 1000-element Array{Float64,1}:
0.0
0.001001001001001001
⋮
0.998998998998999
1.0
u: 1000-element Array{Array{Float64,1},1}:
[0.3, 10.0]
[0.29991322193784203, 10.001001001001]
⋮
[0.3592183230394975, 10.998998998998996]
[0.35941591290813224, 11.0]
In the broken case, the argument is produced by plot_timeseries = sol.u[start_idx:end_idx] at solution_interface.jl:360. If I overwrite it with plot_timeseries = sol(plott) instead (as at line 333) and then continue, the error doesn't happen. That shows that this is where the error stems from.
It looks like the difference is caused by denseplot being true in the working case and false in the broken case, which leads to two different ways of building plot_timeseries. I found that if I step up one layer to the caller of diffeq_to_arrays in the broken trace, I can make it run by switching to denseplot = true, and I can also break the working execution by setting denseplot = false.
The call to diffeq_to_arrays is happening inside some generated code for the recipe that looks like this:
In apply_recipe(plotattributes, sol) at /Users/lmorton/.julia/packages/RecipesBase/92zOw/src/RecipesBase.jl:275
119 │ │ syms = (SciMLBase.getsyms)(sol)
120 │ │ int_vars = (SciMLBase.interpret_vars)(vars, sol, syms)
121 │ │ strs = (SciMLBase.cleansyms)(syms)
122 │ │ tscale = (get)(plotattributes, :xscale, :identity)
>123 │ │ %123 = (SciMLBase.diffeq_to_arrays)(sol, plot_analytic, denseplot, plotdensity, tspan, axis_safety, vars, int_vars, tscale, strs)
124 │ │ %124 = (Base.indexed_iterate)(%123, 1)
125 │ │ plot_vecs = (getfield)(%124, 1)
126 │ │ @_4 = (getfield)(%124, 2)
127 │ │ %127 = (Base.indexed_iterate)(%123, 2, @_4)
I'm not sure how to backtrack and see how this code is being generated. In any event, I think the problem is not with the value of denseplot, but with the type instability of plot_timeseries between the two cases. My thought is to patch the denseplot == false branch to produce a DiffEqArray and see if that fixes things.
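A sketch of that direction, under the assumption that the fix amounts to handing the observed function a states × times array of scalars, which is the layout a DiffEqArray exposes through 2-D indexing (`A[i, j]` is state i at time j). DiffEqArray itself is not used here, and `obs`, `us`, and `ts` are hypothetical:

```julia
# Hypothetical stand-in for the generated observed function: η = tanh(E), with E = u[1].
obs(u, p, t) = tanh(u[1])

# Saved states in the `sol.u[start_idx:end_idx]` shape (a vector of state vectors).
us = [[0.3], [0.2991349301091002], [0.29827713073682965]]
ts = [0.0, 0.01, 0.02]

# Normalize to a states × times matrix of scalars (1×3 here).
U = reduce(hcat, us)

# Each column view is now a state vector of Float64, so u[1] is a scalar.
ηvals = [obs(view(U, :, j), nothing, ts[j]) for j in axes(U, 2)]
```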
@bgctw's problem is cropping up inside the getindex adjoint of RecursiveArrayTools:4
ZygoteRules.@adjoint function getindex(VA::AbstractVectorOfArray, i)
    function AbstractVectorOfArray_getindex_adjoint(Δ)
        Δ′ = [ (i == j ? Δ : zero(x)) for (x,j) in zip(VA.u, 1:length(VA))]
        (Δ′,nothing)
    end
    VA[i],AbstractVectorOfArray_getindex_adjoint
end
Somewhere upstream, i is getting set to dt(x)::Num. Then i==j evaluates to dt(x)==1, which is still Num-typed, not Boolean.
Two ways to solve this:
1. Change 1:length(VA) to something like eachindex to preserve indexing by symbol-like things. If we went with (1), then because dx(t) == dx(t) doesn't evaluate to a Boolean either, we'd need i===j instead.
2. Constrain the call signature to i<:Integer & make sure that this gets respected upstream.
This might be correct? Or we need to mark the symbolic indexing as @nograd.
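The mechanism can be reproduced without Symbolics. In this sketch, `MockSym` and `MockEq` are hypothetical types whose `==` builds an expression instead of returning a Bool, as Num does; the three helper functions copy the comprehension from the adjoint above and try both options:

```julia
# Hypothetical stand-ins for a symbolic variable and an unevaluated `a == b` expression.
struct MockSym
    name::Symbol
end
struct MockEq
    lhs::Any
    rhs::Any
end

# Like Num, comparing a symbolic variable with an integer builds an expression, not a Bool.
Base.:(==)(a::MockSym, b::Integer) = MockEq(a, b)

dx = MockSym(:dx)
us = [[1.0, 2.0], [3.0, 4.0]]   # two saved time points with two states each
Δ = [10.0, 20.0]                # incoming cotangent

# The pattern from the adjoint: `i == j ? ... : ...` requires a Bool,
# so a symbolic `i` raises "non-boolean (MockEq) used in boolean context".
adjoint_eq(i) = [(i == j ? Δ : zero(x)) for (x, j) in zip(us, 1:length(us))]

# Option (1): `===` always returns a Bool, so this runs, but a symbolic `i`
# never equals an integer `j`, so every entry is zero(x).
adjoint_egal(i) = [(i === j ? Δ : zero(x)) for (x, j) in zip(us, 1:length(us))]

# Option (2): accept only integer indices; a symbolic index is rejected
# with a MethodError before the comprehension runs.
adjoint_int(i::Integer) = [(i == j ? Δ : zero(x)) for (x, j) in zip(us, 1:length(us))]
```

In this sketch, option (1) alone makes the comprehension run, but the cotangent for a symbolic index comes out as all zeros.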
I whittled the MWE down to:
function breaky(sol)
    sol[dx][1]
end
stuff,back = Zygote._pullback(breaky,p0)
back(1)
The stacktrace is:
ERROR: TypeError: non-boolean (Num) used in boolean context
Stacktrace:
[1] (::RecursiveArrayTools.var"#138#140"{Array{Int64,1},Num})(::Tuple{Array{Float64,1},Int64}) at ./array.jl:0
[2] iterate at ./generator.jl:47 [inlined]
[3] collect(::Base.Generator{Base.Iterators.Zip{Tuple{Array{Array{Float64,1},1},UnitRange{Int64}}},RecursiveArrayTools.var"#138#140"{Array{Int64,1},Num}}) at ./array.jl:686
[4] AbstractVectorOfArray_getindex_adjoint at /Users/lmorton/Code/Jdev/RecursiveArrayTools/src/zygote.jl:3 [inlined]
[5] #4583#back at /Users/lmorton/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59 [inlined]
[6] breaky at ./REPL[28]:2 [inlined]
[7] (::typeof(∂(breaky)))(::Int64) at /Users/lmorton/.julia/packages/Zygote/6HN9x/src/compiler/interface2.jl:0
[8] top-level scope at REPL[42]:1
If I replace the index dx above with 1, it works. The problem is not related to saveat. I tried stepping through the execution of sol[dx] to see what it looks like when it's working. The getindex dispatches to solution_interface:
@enter sol1[dx]
In ##thunk#1862() at REPL[41]:1
>1 1 ─ %1 = (getindex)(Main.sol1, Main.dx)
2 └── return %1
About to run: <(getindex)(t: [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1...>
1|debug> s
In getindex(A, sym) at /Users/lmorton/Code/Jdev/SciMLBase/src/solutions/solution_interface.jl:38
38 Base.@propagate_inbounds function Base.getindex(A::AbstractTimeseriesSolution,sym)
>39 if issymbollike(sym)
40 i = sym_to_index(sym,A)
41 else
42 i = sym
43 end
This never happens in the pullback that Zygote produces, so i never gets converted to an integer.
In the ZygoteRules for RecursiveArrayTools, what we need is a zero method that preserves the ability to index by symbol. That way, the indexing that's happening here can be delegated to the appropriate getindex without introducing any logic inside RecursiveArrayTools that needs to be aware of symbolic indexing.
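A minimal sketch of that idea on a toy solution type. `MockSol`, `mock_sym_to_index`, and `getindex_pullback` are hypothetical names, not the RecursiveArrayTools or SciMLBase API, and no Zygote is involved. The point is that `zero(s)` keeps the symbol table, so the pullback can resolve the symbol through the same path `getindex` uses instead of comparing it against integer positions:

```julia
# Hypothetical toy solution: states per saved time plus runtime symbol names.
struct MockSol
    u::Vector{Vector{Float64}}   # state vector at each saved time
    syms::Vector{Symbol}         # state names, stored as runtime data
end

# Hypothetical symbol -> integer resolution, in the spirit of sym_to_index.
mock_sym_to_index(sym::Symbol, s::MockSol) = findfirst(==(sym), s.syms)

# getindex resolves the symbol first, then does plain integer indexing:
# s[:x] is the time series of state :x.
function Base.getindex(s::MockSol, sym::Symbol)
    i = mock_sym_to_index(sym, s)
    return [u[i] for u in s.u]
end

# A zero with the same structure and the same symbol table...
Base.zero(s::MockSol) = MockSol([zero(u) for u in s.u], copy(s.syms))

# ...so the pullback of s[sym] can place Δ through the same symbolic lookup.
function getindex_pullback(s::MockSol, sym::Symbol, Δ::AbstractVector)
    ds = zero(s)
    i = mock_sym_to_index(sym, ds)
    for (k, u) in enumerate(ds.u)
        u[i] = Δ[k]   # cotangent of state `sym` at time point k
    end
    return ds
end
```

For example, with `s = MockSol([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], [:x, :dx])`, `s[:dx]` is `[2.0, 4.0, 6.0]`, and `getindex_pullback(s, :dx, [1.0, 0.0, 0.0])` puts the 1.0 in the `:dx` slot of the first time point.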
Two thoughts:
1. … SciMLSolutions, then specialize zero on that trait.
2. Maybe it would be cleaner to use StructArrays or NamedDims for symbolics.
That will definitely not be cleaner. We do not want type information on this because you can have millions of symbols: this would cause segfaults.
Sorry, I'm not sure I follow. Maybe I didn't explain myself very clearly. My thought was that we could substitute something like StructArray (or maybe AxisArray would be more performant?) for Array{Array{T,1},1} as the type of sol.u. That way, sol.u would carry its own information about how to do the symbolic indexing, and we could greatly simplify solution_interface.jl, and maybe RecursiveArrayTools as well.
So what you're saying is that these packages would be slow if sol.u had millions of symbols to index over (ex: making a PDE as a flattened array of equations)?
Yes, trying to move that to type information will segfault Julia. It's runtime information for a reason. LabelledArrays.jl is a good idea for small equations but will not scale.
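To illustrate the distinction in plain Julia (this does not reproduce the segfault): a NamedTuple stands in for any container that keeps its labels in its type, while a `Dict{Symbol,Int}` keeps them as runtime data whose type does not grow with the number of names. The names `nt`, `syms`, and `index` are illustrative only:

```julia
# Labels in the type: the field names (:E, :η) are part of typeof(nt), so every
# distinct set of names is a distinct type the compiler has to handle.
nt = (E = 0.3, η = tanh(0.3))

# Labels as runtime data: one million names, but the container type is just
# Dict{Symbol, Int}, regardless of how many names it holds.
N = 10^6
syms = [Symbol("v", k) for k in 1:N]
index = Dict(s => k for (k, s) in enumerate(syms))

index[:v202105]   # runtime lookup of a state's integer position
```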
Ahh, thanks, I understand the problem now. Just defining that many variables causes a crash.
How about AxisArrays then? It's slightly less convenient to use sol[:x] vs sol.x but it performs fine. (Perhaps one could hack getproperty to make both work?) AxisArrays would also make sense for PDEs where one would rather name the dimensions than the elements.
Again, the whole purpose of AxisArrays is to specialize on small symbol numbers directly into type information:
https://github.com/JuliaArrays/AxisArrays.jl/blob/master/src/core.jl#L5-L62
which will segfault Julia on any large model. I think you're mixing implementation with interface. The interface can look similar to AxisArrays, but the implementation cannot.
I see. So having a large number of dimensions would be problematic. What I had in mind was slightly different: packing many symbolic indexes along a single dimension. Like this:
using AxisArrays, Random
N = 10^6 #Number of variables
M=5 #length of the solution in time
irange = range(1,length=N)
srange = map(num->Symbol("v$num"),irange);
data = Random.rand(N,M);
drange = range(1,length=M)
sarr = AxisArray(data,srange,drange)
sarr[:v202105]
It's a little clunky, but it works. I find it hard to imagine needing ~10^N dimensions with large N, but maybe I'm just not thinking big enough.
I think you're mixing implementation with interface.
Yeah, I guess you're right. My instinct is to want them to match, and it also seemed like a quick way to get a more complete array interface. But, if it won't work, it won't work.
The other thing is that the array cannot be fully populated. In many cases, we only solve 100 equations for the symbolic system of 10,000 equations, and use the observed function to rebuild the un-calculated values on demand only when the user asks for them.
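A rough sketch of rebuilding an observed quantity on demand, assuming the observed function has the `(u, p, t)` form visible in frame [5] of the first stacktrace. `ObservedSeries` and `dE` are hypothetical: `dE` reproduces the right-hand side P*η - E/τ from the first example with η = tanh(E) and p = (P, τ). Each element is computed only when it is indexed, and it can depend on the whole state vector and the parameters, not just one stored element:

```julia
# Hypothetical lazy series: stores saved states, parameters and times, and
# evaluates the observed function only when an element is requested.
struct ObservedSeries{F,U,P,T} <: AbstractVector{Float64}
    f::F   # observed function with the signature (u, p, t) -> value
    u::U   # saved state vectors, one per time point
    p::P   # parameters
    t::T   # saved time points
end

Base.size(o::ObservedSeries) = (length(o.u),)
Base.getindex(o::ObservedSeries, k::Int) = o.f(o.u[k], o.p, o.t[k])

# Hypothetical observed quantity: D(E) = P*η - E/τ with η = tanh(E), p = (P, τ).
dE(u, p, t) = p[1] * tanh(u[1]) - u[1] / p[2]

us = [[0.3], [0.2991349301091002], [0.29827713073682965]]
ts = [0.0, 0.01, 0.02]
series = ObservedSeries(dE, us, (10.0, 0.1), ts)

series[2]   # computed now, from us[2], the parameters and ts[2]
```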
Oh, right. And MappedArrays won't help either, because it makes the restrictive assumption of 1-to-1 mapping of elements.
I found a workaround, using SciMLBase.observed to produce the observed values.
function loss(p)
    solp = solve(prob, Rodas5(), p=p, saveat = tsteps)
    dxval = SciMLBase.observed(solp,dx,:)
    lossval = sum(abs2,dxval .- obsdt)
    return lossval, solp
end
This suggests that all we need to do is prevent the ZygoteRule for AbstractVectorOfArray from applying to getindex for one of these subtypes:
julia> supertypes(ODESolution)
(ODESolution, SciMLBase.AbstractODESolution, SciMLBase.AbstractTimeseriesSolution, AbstractDiffEqArray, AbstractVectorOfArray, AbstractArray, Any)
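One way Julia's dispatch allows that, sketched with a hypothetical mirror of the hierarchy above (the `Mock*` types and `which_rule` are illustrative, and this is not necessarily how SciML/RecursiveArrayTools.jl#151 implements it): a method defined for the more specific solution supertype wins over the generic method defined at the vector-of-arrays level, so the generic rule never sees solution objects.

```julia
# Hypothetical mirror of part of the supertype chain printed above.
abstract type MockVectorOfArray end
abstract type MockTimeseriesSolution <: MockVectorOfArray end

struct MockODESolution <: MockTimeseriesSolution
    u::Vector{Vector{Float64}}
end
struct MockPlainVectorOfArray <: MockVectorOfArray
    u::Vector{Vector{Float64}}
end

# Generic rule at the vector-of-arrays level (assumes integer-like indices).
which_rule(::MockVectorOfArray, i) = :generic_integer_rule

# More specific method for solutions: dispatch prefers it for any solution subtype,
# so symbolic indices on solutions never reach the generic rule.
which_rule(::MockTimeseriesSolution, i) = :solution_rule
```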
We can close this once SciML/RecursiveArrayTools.jl#151 lands.
It landed ✈️
BTW, I found that this fix doesn't help with Zygote, but it does allow ForwardDiff to work, which allows sciml_train to succeed with AD.
Interesting. Could we get an issue open on that?
I think this issue demonstrates it well. It still fails with current RecursiveArrayTools.
Accessing the ODESolution for an observed variable sometimes fails when the solution has been generated by solve(..., saveat = tsteps). This happens for plotting, but more severely it impairs optimizations that need the gradient of the loss function. (I think, but I started learning about DiffEqFlux only recently.)
The following example demonstrates the behavior. It was generated using Julia Version 1.6.0, OS: Linux (x86_64-pc-linux-gnu), with:
[2b5f629d] DiffEqBase v6.58.0
[aae7a2af] DiffEqFlux v1.36.0
[587475ba] Flux v0.12.1
[a75be94c] GalacticOptim v1.1.0
[961ee093] ModelingToolkit v5.14.3
[0c5d862f] Symbolics v0.1.18
[e88e6eb3] Zygote v0.6.9
It throws the following error: