SciML / DiffEqFlux.jl

Pre-built implicit layer architectures with O(1) backprop, GPUs, and stiff+non-stiff DE solvers, demonstrating scientific machine learning (SciML) and physics-informed machine learning methods
https://docs.sciml.ai/DiffEqFlux/stable
MIT License

neural_ode_sciml example fails when Dense layer replaced by GRU #432

Closed. John-Boik closed this issue 2 years ago.

John-Boik commented 4 years ago

As a first step toward GRU-ODE or ODE-LSTM implementations, I'd like to replace the Dense layer in the neural_ode_sciml example with a GRU layer. However, doing so raises the error LoadError: DimensionMismatch("array could not be broadcast to match destination"). I don't understand where exactly the problem occurs or how to fix it. Any ideas?

This issue is loosely related to Training of UDEs with recurrent networks #391 and Flux.destructure doesn't preserve RNN state #1329. See also ODE-LSTM layer #422.

The code is as follows. The main difference from the original example is that the final Dense layer is commented out and replaced by a GRU layer:

module TestDiffeq3b

using Revise
using Infiltrator
using Formatting

import DiffEqFlux
import OrdinaryDiffEq 
import Flux
import Optim
import Plots

u0 = Float32[2.0; 0.0]    
datasize = 30
tspan = (0.0f0, 1.5f0)
tsteps = range(tspan[1], tspan[2], length = datasize)

function trueODEfunc(du, u, p, t)    
    true_A = [-0.1 2.0; -2.0 -0.1]    
    du .= ((u.^3)'true_A)'    
end
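For readers comparing this with the original example: for a column vector u, the double transpose in trueODEfunc computes du/dt = true_A' * u.^3, that is, the transpose of true_A applied to the elementwise cube of u. A small plain-Julia sketch checking that reading, with values chosen to match true_A and u0 above:

```julia
using LinearAlgebra  # stdlib; matrix products and adjoints are available even without this line

true_A = [-0.1 2.0; -2.0 -0.1]
u = [2.0, 0.0]

# Form used in trueODEfunc: row vector (u.^3)' times true_A, transposed back to a column
lhs = vec(((u.^3)' * true_A)')

# Equivalent column form: transpose of true_A applied to the elementwise cube of u
rhs = true_A' * (u.^3)

println(lhs ≈ rhs)  # true
```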

prob_trueode = OrdinaryDiffEq.ODEProblem(trueODEfunc, u0, tspan)
ode_data = Array(OrdinaryDiffEq.solve(prob_trueode, OrdinaryDiffEq.Tsit5(), saveat = tsteps))    

dudt2 = Flux.Chain(
    x -> x.^3,
    Flux.Dense(2, 50, tanh),
    #Flux.Dense(50, 2)
    Flux.GRU(50, 2)
    )

p, re = Flux.destructure(dudt2)  
neural_ode_f(u, p, t) = re(p)(u)
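One difference between the two layers that is likely relevant here: Dense is stateless, whereas Flux's GRU is a recurrent layer that keeps a hidden state and updates it on every call. A solver such as Tsit5 evaluates neural_ode_f several times per step, and the adjoint pass evaluates it again backwards in time, all under the assumption that the right-hand side depends only on (u, p, t). Whether re(p) resets that hidden state on each call depends on how Flux.destructure treats it, which is the subject of the linked Flux issue #1329. The sketch below uses a hypothetical StatefulRHS type (not part of Flux) only to illustrate how hidden state breaks that assumption:

```julia
# Hypothetical stand-in for a recurrent layer; NOT a Flux type.
# Like a GRU, it carries a hidden state that changes on every call.
mutable struct StatefulRHS
    h::Vector{Float64}   # hidden state carried between calls
end

function (m::StatefulRHS)(u::Vector{Float64})
    m.h = tanh.(u .+ m.h)   # update the hidden state from the input
    return m.h              # output depends on u AND on the call history
end

# Stateless right-hand side, analogous to the Dense-only chain: output depends only on u
pure_rhs(u::Vector{Float64}) = tanh.(u)

u = [2.0, 0.0]
m = StatefulRHS(zeros(2))

first_call  = m(u)
second_call = m(u)   # same input, different output

println(first_call == second_call)    # false: not a function of u alone
println(pure_rhs(u) == pure_rhs(u))   # true
```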

prob = OrdinaryDiffEq.ODEProblem(neural_ode_f, u0, tspan, p)

function predict_neuralode(p)
    tmp_prob = OrdinaryDiffEq.remake(prob,p=p)
    res = Array(OrdinaryDiffEq.solve(tmp_prob, OrdinaryDiffEq.Tsit5(), saveat=tsteps))
    return res
end

function loss_neuralode(p)
    pred = predict_neuralode(p)  # (2,30)
    loss = sum(abs2, ode_data .- pred)  # scalar
    return loss, pred
end

callback = function (p, l, pred; doplot = true)
    display(l)
    # plot current prediction against data
    plt = Plots.scatter(tsteps, ode_data[1,:], label = "data")
    Plots.scatter!(plt, tsteps, pred[1,:], label = "prediction")
    if doplot
        display(Plots.plot(plt))
    end
    return false
end

result_neuralode = DiffEqFlux.sciml_train(
    loss_neuralode, 
    p,
    Flux.ADAM(0.05), 
    cb = callback,
    maxiters = 300
    )

result_neuralode2 = DiffEqFlux.sciml_train(
    loss_neuralode,
    result_neuralode.minimizer,
    Optim.LBFGS(),
    cb = callback,
    allow_f_increases = false
    )

end    # ------------------------------- module -----------------------------------

The error message is:


ERROR: LoadError: DimensionMismatch("array could not be broadcast to match destination")
Stacktrace:
 [1] check_broadcast_shape at ./broadcast.jl:520 [inlined]
 [2] check_broadcast_axes at ./broadcast.jl:523 [inlined]
 [3] instantiate at ./broadcast.jl:269 [inlined]
 [4] materialize! at ./broadcast.jl:848 [inlined]
 [5] materialize!(::SubArray{Float32,1,Array{Float32,1},Tuple{UnitRange{Int64}},true}, ::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,typeof(identity),Tuple{Array{Float32,1}}}) at ./broadcast.jl:845
 [6] _vecjacobian!(::SubArray{Float32,1,Array{Float32,1},Tuple{UnitRange{Int64}},true}, ::Array{Float32,1}, ::SubArray{Float32,1,Array{Float32,1},Tuple{UnitRange{Int64}},true}, ::Array{Float32,1}, ::Float32, ::DiffEqSensitivity.ODEInterpolatingAdjointSensitivityFunction{DiffEqSensitivity.AdjointDiffCache{Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Base.OneTo{Int64},UnitRange{Int64},LinearAlgebra.UniformScaling{Bool}},DiffEqSensitivity.InterpolatingAdjoint{0,true,Val{:central},Bool,Bool},Array{Float32,1},DiffEqBase.ODESolution{Float32,2,Array{Array{Float32,1},1},Nothing,Nothing,Array{Float32,1},Array{Array{Array{Float32,1},1},1},DiffEqBase.ODEProblem{Array{Float32,1},Tuple{Float32,Float32},false,Array{Float32,1},DiffEqBase.ODEFunction{false,typeof(Main.TestDiffeq3b.neural_ode_f),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},DiffEqBase.StandardODEProblem},OrdinaryDiffEq.Tsit5,OrdinaryDiffEq.InterpolationData{DiffEqBase.ODEFunction{false,typeof(Main.TestDiffeq3b.neural_ode_f),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Array{Array{Float32,1},1},Array{Float32,1},Array{Array{Array{Float32,1},1},1},OrdinaryDiffEq.Tsit5ConstantCache{Float32,Float32}},DiffEqBase.DEStats},Nothing,DiffEqBase.ODEProblem{Array{Float32,1},Tuple{Float32,Float32},false,Array{Float32,1},DiffEqBase.ODEFunction{false,typeof(Main.TestDiffeq3b.neural_ode_f),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},DiffEqBase.StandardODEProblem},DiffEqBase.ODEFunction{false,typeof(Main.TestDiffeq3b.neural_ode_f),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Not
hing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing}}, ::DiffEqSensitivity.ZygoteVJP, ::SubArray{Float32,1,Array{Float32,1},Tuple{UnitRange{Int64}},true}, ::Nothing) at /home/jboik/.julia/packages/DiffEqSensitivity/WiCRA/src/local_sensitivity/derivative_wrappers.jl:296
 [7] _vecjacobian! at /home/jboik/.julia/packages/DiffEqSensitivity/WiCRA/src/local_sensitivity/derivative_wrappers.jl:193 [inlined]
 [8] #vecjacobian!#20 at /home/jboik/.julia/packages/DiffEqSensitivity/WiCRA/src/local_sensitivity/derivative_wrappers.jl:147 [inlined]
 [9] (::DiffEqSensitivity.ODEInterpolatingAdjointSensitivityFunction{DiffEqSensitivity.AdjointDiffCache{Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Base.OneTo{Int64},UnitRange{Int64},LinearAlgebra.UniformScaling{Bool}},DiffEqSensitivity.InterpolatingAdjoint{0,true,Val{:central},Bool,Bool},Array{Float32,1},DiffEqBase.ODESolution{Float32,2,Array{Array{Float32,1},1},Nothing,Nothing,Array{Float32,1},Array{Array{Array{Float32,1},1},1},DiffEqBase.ODEProblem{Array{Float32,1},Tuple{Float32,Float32},false,Array{Float32,1},DiffEqBase.ODEFunction{false,typeof(Main.TestDiffeq3b.neural_ode_f),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},DiffEqBase.StandardODEProblem},OrdinaryDiffEq.Tsit5,OrdinaryDiffEq.InterpolationData{DiffEqBase.ODEFunction{false,typeof(Main.TestDiffeq3b.neural_ode_f),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Array{Array{Float32,1},1},Array{Float32,1},Array{Array{Array{Float32,1},1},1},OrdinaryDiffEq.Tsit5ConstantCache{Float32,Float32}},DiffEqBase.DEStats},Nothing,DiffEqBase.ODEProblem{Array{Float32,1},Tuple{Float32,Float32},false,Array{Float32,1},DiffEqBase.ODEFunction{false,typeof(Main.TestDiffeq3b.neural_ode_f),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},DiffEqBase.StandardODEProblem},DiffEqBase.ODEFunction{false,typeof(Main.TestDiffeq3b.neural_ode_f),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing}})(::Array{Float32,1}, ::Array{Float32,1}, ::Array{Float32,1}, ::Float32) at 
/home/jboik/.julia/packages/DiffEqSensitivity/WiCRA/src/local_sensitivity/interpolating_adjoint.jl:145
 [10] ODEFunction at /home/jboik/.julia/packages/DiffEqBase/gLFRA/src/diffeqfunction.jl:248 [inlined]
 [11] initialize!(::OrdinaryDiffEq.ODEIntegrator{OrdinaryDiffEq.Tsit5,true,Array{Float32,1},Nothing,Float32,Array{Float32,1},Float32,Float32,Float32,Array{Array{Float32,1},1},DiffEqBase.ODESolution{Float32,2,Array{Array{Float32,1},1},Nothing,Nothing,Array{Float32,1},Array{Array{Array{Float32,1},1},1},DiffEqBase.ODEProblem{Array{Float32,1},Tuple{Float32,Float32},true,Array{Float32,1},DiffEqBase.ODEFunction{true,DiffEqSensitivity.ODEInterpolatingAdjointSensitivityFunction{DiffEqSensitivity.AdjointDiffCache{Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Base.OneTo{Int64},UnitRange{Int64},LinearAlgebra.UniformScaling{Bool}},DiffEqSensitivity.InterpolatingAdjoint{0,true,Val{:central},Bool,Bool},Array{Float32,1},DiffEqBase.ODESolution{Float32,2,Array{Array{Float32,1},1},Nothing,Nothing,Array{Float32,1},Array{Array{Array{Float32,1},1},1},DiffEqBase.ODEProblem{Array{Float32,1},Tuple{Float32,Float32},false,Array{Float32,1},DiffEqBase.ODEFunction{false,typeof(Main.TestDiffeq3b.neural_ode_f),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},DiffEqBase.StandardODEProblem},OrdinaryDiffEq.Tsit5,OrdinaryDiffEq.InterpolationData{DiffEqBase.ODEFunction{false,typeof(Main.TestDiffeq3b.neural_ode_f),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Array{Array{Float32,1},1},Array{Float32,1},Array{Array{Array{Float32,1},1},1},OrdinaryDiffEq.Tsit5ConstantCache{Float32,Float32}},DiffEqBase.DEStats},Nothing,DiffEqBase.ODEProblem{Array{Float32,1},Tuple{Float32,Float32},false,Array{Float32,1},DiffEqBase.ODEFunction{false,typeof(Main.TestDiffeq3b.neural_ode_f),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pai
rs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},DiffEqBase.StandardODEProblem},DiffEqBase.ODEFunction{false,typeof(Main.TestDiffeq3b.neural_ode_f),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing}},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Symbol,DiffEqBase.CallbackSet{Tuple{},Tuple{DiffEqBase.DiscreteCallback{DiffEqCallbacks.var"#33#38"{Base.RefValue{Union{Nothing, Float32}}},DiffEqCallbacks.var"#34#39"{DiffEqSensitivity.var"#94#96"{Base.RefValue{Int64},Array{Float32,1}},DiffEqSensitivity.var"#95#97"{DiffEqSensitivity.var"#df#134"{Array{Float32,2},Array{Float32,1},Colon},DiffEqSensitivity.InterpolatingAdjoint{0,true,Val{:central},Bool,Bool},Base.OneTo{Int64},UnitRange{Int64},LinearAlgebra.UniformScaling{Bool},Bool,Nothing,Nothing,Nothing,Nothing,Bool,Array{Float32,1},Array{Float32,1},Array{Float32,1},Base.RefValue{Int64},Int64,LinearAlgebra.UniformScaling{Bool}},Base.RefValue{Union{Nothing, Float32}}},DiffEqCallbacks.var"#35#40"{Bool,DiffEqCallbacks.var"#37#42"{Bool},DiffEqSensitivity.var"#94#96"{Base.RefValue{Int64},Array{Float32,1}},Base.RefValue{Union{Nothing, Float32}},DiffEqCallbacks.var"#34#39"{DiffEqSensitivity.var"#94#96"{Base.RefValue{Int64},Array{Float32,1}},DiffEqSensitivity.var"#95#97"{DiffEqSensitivity.var"#df#134"{Array{Float32,2},Array{Float32,1},Colon},DiffEqSensitivity.InterpolatingAdjoint{0,true,Val{:central},Bool,Bool},Base.OneTo{Int64},UnitRange{Int64},LinearAlgebra.UniformScaling{Bool},Bool,Nothing,Nothing,Nothing,Nothing,Bool,Array{Float32,1},Array{Float32,1},Array{Float32,1},Base.RefValue{Int64},Int64,LinearAlgebra.UniformScaling{Bool}},Base.RefValue{Union{Nothing, Float32}}}}}}},Tuple{Symbol},NamedTuple{(:callback,),Tuple{DiffEqBase.CallbackSet{Tuple{},Tuple{DiffEqBase.DiscreteCallback{DiffEqCallbacks.var"#33#38"{Base.RefValue{Union{Nothing, 
Float32}}},DiffEqCallbacks.var"#34#39"{DiffEqSensitivity.var"#94#96"{Base.RefValue{Int64},Array{Float32,1}},DiffEqSensitivity.var"#95#97"{DiffEqSensitivity.var"#df#134"{Array{Float32,2},Array{Float32,1},Colon},DiffEqSensitivity.InterpolatingAdjoint{0,true,Val{:central},Bool,Bool},Base.OneTo{Int64},UnitRange{Int64},LinearAlgebra.UniformScaling{Bool},Bool,Nothing,Nothing,Nothing,Nothing,Bool,Array{Float32,1},Array{Float32,1},Array{Float32,1},Base.RefValue{Int64},Int64,LinearAlgebra.UniformScaling{Bool}},Base.RefValue{Union{Nothing, Float32}}},DiffEqCallbacks.var"#35#40"{Bool,DiffEqCallbacks.var"#37#42"{Bool},DiffEqSensitivity.var"#94#96"{Base.RefValue{Int64},Array{Float32,1}},Base.RefValue{Union{Nothing, Float32}},DiffEqCallbacks.var"#34#39"{DiffEqSensitivity.var"#94#96"{Base.RefValue{Int64},Array{Float32,1}},DiffEqSensitivity.var"#95#97"{DiffEqSensitivity.var"#df#134"{Array{Float32,2},Array{Float32,1},Colon},DiffEqSensitivity.InterpolatingAdjoint{0,true,Val{:central},Bool,Bool},Base.OneTo{Int64},UnitRange{Int64},LinearAlgebra.UniformScaling{Bool},Bool,Nothing,Nothing,Nothing,Nothing,Bool,Array{Float32,1},Array{Float32,1},Array{Float32,1},Base.RefValue{Int64},Int64,LinearAlgebra.UniformScaling{Bool}},Base.RefValue{Union{Nothing, 
Float32}}}}}}}}}},DiffEqBase.StandardODEProblem},OrdinaryDiffEq.Tsit5,OrdinaryDiffEq.InterpolationData{DiffEqBase.ODEFunction{true,DiffEqSensitivity.ODEInterpolatingAdjointSensitivityFunction{DiffEqSensitivity.AdjointDiffCache{Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Base.OneTo{Int64},UnitRange{Int64},LinearAlgebra.UniformScaling{Bool}},DiffEqSensitivity.InterpolatingAdjoint{0,true,Val{:central},Bool,Bool},Array{Float32,1},DiffEqBase.ODESolution{Float32,2,Array{Array{Float32,1},1},Nothing,Nothing,Array{Float32,1},Array{Array{Array{Float32,1},1},1},DiffEqBase.ODEProblem{Array{Float32,1},Tuple{Float32,Float32},false,Array{Float32,1},DiffEqBase.ODEFunction{false,typeof(Main.TestDiffeq3b.neural_ode_f),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},DiffEqBase.StandardODEProblem},OrdinaryDiffEq.Tsit5,OrdinaryDiffEq.InterpolationData{DiffEqBase.ODEFunction{false,typeof(Main.TestDiffeq3b.neural_ode_f),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Array{Array{Float32,1},1},Array{Float32,1},Array{Array{Array{Float32,1},1},1},OrdinaryDiffEq.Tsit5ConstantCache{Float32,Float32}},DiffEqBase.DEStats},Nothing,DiffEqBase.ODEProblem{Array{Float32,1},Tuple{Float32,Float32},false,Array{Float32,1},DiffEqBase.ODEFunction{false,typeof(Main.TestDiffeq3b.neural_ode_f),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},DiffEqBase.StandardODEProblem},DiffEqBase.ODEFunction{false,typeof(Main.TestDiffeq3b.neural_ode_f),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing}},
LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Array{Array{Float32,1},1},Array{Float32,1},Array{Array{Array{Float32,1},1},1},OrdinaryDiffEq.Tsit5Cache{Array{Float32,1},Array{Float32,1},Array{Float32,1},OrdinaryDiffEq.Tsit5ConstantCache{Float32,Float32}}},DiffEqBase.DEStats},DiffEqBase.ODEFunction{true,DiffEqSensitivity.ODEInterpolatingAdjointSensitivityFunction{DiffEqSensitivity.AdjointDiffCache{Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Base.OneTo{Int64},UnitRange{Int64},LinearAlgebra.UniformScaling{Bool}},DiffEqSensitivity.InterpolatingAdjoint{0,true,Val{:central},Bool,Bool},Array{Float32,1},DiffEqBase.ODESolution{Float32,2,Array{Array{Float32,1},1},Nothing,Nothing,Array{Float32,1},Array{Array{Array{Float32,1},1},1},DiffEqBase.ODEProblem{Array{Float32,1},Tuple{Float32,Float32},false,Array{Float32,1},DiffEqBase.ODEFunction{false,typeof(Main.TestDiffeq3b.neural_ode_f),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},DiffEqBase.StandardODEProblem},OrdinaryDiffEq.Tsit5,OrdinaryDiffEq.InterpolationData{DiffEqBase.ODEFunction{false,typeof(Main.TestDiffeq3b.neural_ode_f),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Array{Array{Float32,1},1},Array{Float32,1},Array{Array{Array{Float32,1},1},1},OrdinaryDiffEq.Tsit5ConstantCache{Float32,Float32}},DiffEqBase.DEStats},Nothing,DiffEqBase.ODEProblem{Array{Float32,1},Tuple{Float32,Float32},false,Array{Float32,1},DiffEqBase.ODEFunction{false,typeof(Main.TestDiffeq3b.neural_ode_f),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{}
,NamedTuple{(),Tuple{}}},DiffEqBase.StandardODEProblem},DiffEqBase.ODEFunction{false,typeof(Main.TestDiffeq3b.neural_ode_f),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing}},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},OrdinaryDiffEq.Tsit5Cache{Array{Float32,1},Array{Float32,1},Array{Float32,1},OrdinaryDiffEq.Tsit5ConstantCache{Float32,Float32}},OrdinaryDiffEq.DEOptions{Float64,Float64,Float32,Float32,typeof(DiffEqBase.ODE_DEFAULT_NORM),typeof(LinearAlgebra.opnorm),DiffEqBase.CallbackSet{Tuple{},Tuple{DiffEqBase.DiscreteCallback{DiffEqCallbacks.var"#33#38"{Base.RefValue{Union{Nothing, Float32}}},DiffEqCallbacks.var"#34#39"{DiffEqSensitivity.var"#94#96"{Base.RefValue{Int64},Array{Float32,1}},DiffEqSensitivity.var"#95#97"{DiffEqSensitivity.var"#df#134"{Array{Float32,2},Array{Float32,1},Colon},DiffEqSensitivity.InterpolatingAdjoint{0,true,Val{:central},Bool,Bool},Base.OneTo{Int64},UnitRange{Int64},LinearAlgebra.UniformScaling{Bool},Bool,Nothing,Nothing,Nothing,Nothing,Bool,Array{Float32,1},Array{Float32,1},Array{Float32,1},Base.RefValue{Int64},Int64,LinearAlgebra.UniformScaling{Bool}},Base.RefValue{Union{Nothing, Float32}}},DiffEqCallbacks.var"#35#40"{Bool,DiffEqCallbacks.var"#37#42"{Bool},DiffEqSensitivity.var"#94#96"{Base.RefValue{Int64},Array{Float32,1}},Base.RefValue{Union{Nothing, 
Float32}},DiffEqCallbacks.var"#34#39"{DiffEqSensitivity.var"#94#96"{Base.RefValue{Int64},Array{Float32,1}},DiffEqSensitivity.var"#95#97"{DiffEqSensitivity.var"#df#134"{Array{Float32,2},Array{Float32,1},Colon},DiffEqSensitivity.InterpolatingAdjoint{0,true,Val{:central},Bool,Bool},Base.OneTo{Int64},UnitRange{Int64},LinearAlgebra.UniformScaling{Bool},Bool,Nothing,Nothing,Nothing,Nothing,Bool,Array{Float32,1},Array{Float32,1},Array{Float32,1},Base.RefValue{Int64},Int64,LinearAlgebra.UniformScaling{Bool}},Base.RefValue{Union{Nothing, Float32}}}}}}},typeof(DiffEqBase.ODE_DEFAULT_ISOUTOFDOMAIN),typeof(DiffEqBase.ODE_DEFAULT_PROG_MESSAGE),typeof(DiffEqBase.ODE_DEFAULT_UNSTABLE_CHECK),DataStructures.BinaryHeap{Float32,Base.Order.ForwardOrdering},DataStructures.BinaryHeap{Float32,Base.Order.ForwardOrdering},Nothing,Nothing,Int64,Array{Float32,1},Array{Float32,1},Tuple{}},Array{Float32,1},Float32,Nothing,OrdinaryDiffEq.DefaultInit}, ::OrdinaryDiffEq.Tsit5Cache{Array{Float32,1},Array{Float32,1},Array{Float32,1},OrdinaryDiffEq.Tsit5ConstantCache{Float32,Float32}}) at /home/jboik/.julia/packages/OrdinaryDiffEq/HO8vN/src/perform_step/low_order_rk_perform_step.jl:623
 [12] __init(::DiffEqBase.ODEProblem{Array{Float32,1},Tuple{Float32,Float32},true,Array{Float32,1},DiffEqBase.ODEFunction{true,DiffEqSensitivity.ODEInterpolatingAdjointSensitivityFunction{DiffEqSensitivity.AdjointDiffCache{Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Base.OneTo{Int64},UnitRange{Int64},LinearAlgebra.UniformScaling{Bool}},DiffEqSensitivity.InterpolatingAdjoint{0,true,Val{:central},Bool,Bool},Array{Float32,1},DiffEqBase.ODESolution{Float32,2,Array{Array{Float32,1},1},Nothing,Nothing,Array{Float32,1},Array{Array{Array{Float32,1},1},1},DiffEqBase.ODEProblem{Array{Float32,1},Tuple{Float32,Float32},false,Array{Float32,1},DiffEqBase.ODEFunction{false,typeof(Main.TestDiffeq3b.neural_ode_f),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},DiffEqBase.StandardODEProblem},OrdinaryDiffEq.Tsit5,OrdinaryDiffEq.InterpolationData{DiffEqBase.ODEFunction{false,typeof(Main.TestDiffeq3b.neural_ode_f),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Array{Array{Float32,1},1},Array{Float32,1},Array{Array{Array{Float32,1},1},1},OrdinaryDiffEq.Tsit5ConstantCache{Float32,Float32}},DiffEqBase.DEStats},Nothing,DiffEqBase.ODEProblem{Array{Float32,1},Tuple{Float32,Float32},false,Array{Float32,1},DiffEqBase.ODEFunction{false,typeof(Main.TestDiffeq3b.neural_ode_f),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},DiffEqBase.StandardODEProblem},DiffEqBase.ODEFunction{false,typeof(Main.TestDiffeq3b.neural_ode_f),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing}},Line
arAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Symbol,DiffEqBase.CallbackSet{Tuple{},Tuple{DiffEqBase.DiscreteCallback{DiffEqCallbacks.var"#33#38"{Base.RefValue{Union{Nothing, Float32}}},DiffEqCallbacks.var"#34#39"{DiffEqSensitivity.var"#94#96"{Base.RefValue{Int64},Array{Float32,1}},DiffEqSensitivity.var"#95#97"{DiffEqSensitivity.var"#df#134"{Array{Float32,2},Array{Float32,1},Colon},DiffEqSensitivity.InterpolatingAdjoint{0,true,Val{:central},Bool,Bool},Base.OneTo{Int64},UnitRange{Int64},LinearAlgebra.UniformScaling{Bool},Bool,Nothing,Nothing,Nothing,Nothing,Bool,Array{Float32,1},Array{Float32,1},Array{Float32,1},Base.RefValue{Int64},Int64,LinearAlgebra.UniformScaling{Bool}},Base.RefValue{Union{Nothing, Float32}}},DiffEqCallbacks.var"#35#40"{Bool,DiffEqCallbacks.var"#37#42"{Bool},DiffEqSensitivity.var"#94#96"{Base.RefValue{Int64},Array{Float32,1}},Base.RefValue{Union{Nothing, Float32}},DiffEqCallbacks.var"#34#39"{DiffEqSensitivity.var"#94#96"{Base.RefValue{Int64},Array{Float32,1}},DiffEqSensitivity.var"#95#97"{DiffEqSensitivity.var"#df#134"{Array{Float32,2},Array{Float32,1},Colon},DiffEqSensitivity.InterpolatingAdjoint{0,true,Val{:central},Bool,Bool},Base.OneTo{Int64},UnitRange{Int64},LinearAlgebra.UniformScaling{Bool},Bool,Nothing,Nothing,Nothing,Nothing,Bool,Array{Float32,1},Array{Float32,1},Array{Float32,1},Base.RefValue{Int64},Int64,LinearAlgebra.UniformScaling{Bool}},Base.RefValue{Union{Nothing, Float32}}}}}}},Tuple{Symbol},NamedTuple{(:callback,),Tuple{DiffEqBase.CallbackSet{Tuple{},Tuple{DiffEqBase.DiscreteCallback{DiffEqCallbacks.var"#33#38"{Base.RefValue{Union{Nothing, 
Float32}}},DiffEqCallbacks.var"#34#39"{DiffEqSensitivity.var"#94#96"{Base.RefValue{Int64},Array{Float32,1}},DiffEqSensitivity.var"#95#97"{DiffEqSensitivity.var"#df#134"{Array{Float32,2},Array{Float32,1},Colon},DiffEqSensitivity.InterpolatingAdjoint{0,true,Val{:central},Bool,Bool},Base.OneTo{Int64},UnitRange{Int64},LinearAlgebra.UniformScaling{Bool},Bool,Nothing,Nothing,Nothing,Nothing,Bool,Array{Float32,1},Array{Float32,1},Array{Float32,1},Base.RefValue{Int64},Int64,LinearAlgebra.UniformScaling{Bool}},Base.RefValue{Union{Nothing, Float32}}},DiffEqCallbacks.var"#35#40"{Bool,DiffEqCallbacks.var"#37#42"{Bool},DiffEqSensitivity.var"#94#96"{Base.RefValue{Int64},Array{Float32,1}},Base.RefValue{Union{Nothing, Float32}},DiffEqCallbacks.var"#34#39"{DiffEqSensitivity.var"#94#96"{Base.RefValue{Int64},Array{Float32,1}},DiffEqSensitivity.var"#95#97"{DiffEqSensitivity.var"#df#134"{Array{Float32,2},Array{Float32,1},Colon},DiffEqSensitivity.InterpolatingAdjoint{0,true,Val{:central},Bool,Bool},Base.OneTo{Int64},UnitRange{Int64},LinearAlgebra.UniformScaling{Bool},Bool,Nothing,Nothing,Nothing,Nothing,Bool,Array{Float32,1},Array{Float32,1},Array{Float32,1},Base.RefValue{Int64},Int64,LinearAlgebra.UniformScaling{Bool}},Base.RefValue{Union{Nothing, Float32}}}}}}}}}},DiffEqBase.StandardODEProblem}, ::OrdinaryDiffEq.Tsit5, ::Tuple{}, ::Tuple{}, ::Tuple{}, ::Type{Val{true}}; saveat::Array{Float32,1}, tstops::Array{Float32,1}, d_discontinuities::Tuple{}, save_idxs::Nothing, save_everystep::Bool, save_on::Bool, save_start::Bool, save_end::Bool, callback::DiffEqBase.CallbackSet{Tuple{},Tuple{DiffEqBase.DiscreteCallback{DiffEqCallbacks.var"#33#38"{Base.RefValue{Union{Nothing, 
Float32}}},DiffEqCallbacks.var"#34#39"{DiffEqSensitivity.var"#94#96"{Base.RefValue{Int64},Array{Float32,1}},DiffEqSensitivity.var"#95#97"{DiffEqSensitivity.var"#df#134"{Array{Float32,2},Array{Float32,1},Colon},DiffEqSensitivity.InterpolatingAdjoint{0,true,Val{:central},Bool,Bool},Base.OneTo{Int64},UnitRange{Int64},LinearAlgebra.UniformScaling{Bool},Bool,Nothing,Nothing,Nothing,Nothing,Bool,Array{Float32,1},Array{Float32,1},Array{Float32,1},Base.RefValue{Int64},Int64,LinearAlgebra.UniformScaling{Bool}},Base.RefValue{Union{Nothing, Float32}}},DiffEqCallbacks.var"#35#40"{Bool,DiffEqCallbacks.var"#37#42"{Bool},DiffEqSensitivity.var"#94#96"{Base.RefValue{Int64},Array{Float32,1}},Base.RefValue{Union{Nothing, Float32}},DiffEqCallbacks.var"#34#39"{DiffEqSensitivity.var"#94#96"{Base.RefValue{Int64},Array{Float32,1}},DiffEqSensitivity.var"#95#97"{DiffEqSensitivity.var"#df#134"{Array{Float32,2},Array{Float32,1},Colon},DiffEqSensitivity.InterpolatingAdjoint{0,true,Val{:central},Bool,Bool},Base.OneTo{Int64},UnitRange{Int64},LinearAlgebra.UniformScaling{Bool},Bool,Nothing,Nothing,Nothing,Nothing,Bool,Array{Float32,1},Array{Float32,1},Array{Float32,1},Base.RefValue{Int64},Int64,LinearAlgebra.UniformScaling{Bool}},Base.RefValue{Union{Nothing, Float32}}}}}}}, dense::Bool, calck::Bool, dt::Float32, dtmin::Nothing, dtmax::Float32, force_dtmin::Bool, adaptive::Bool, gamma::Rational{Int64}, abstol::Float64, reltol::Float64, qmin::Rational{Int64}, qmax::Int64, qsteady_min::Int64, qsteady_max::Int64, qoldinit::Rational{Int64}, fullnormalize::Bool, failfactor::Int64, beta1::Nothing, beta2::Nothing, maxiters::Int64, internalnorm::typeof(DiffEqBase.ODE_DEFAULT_NORM), internalopnorm::typeof(LinearAlgebra.opnorm), isoutofdomain::typeof(DiffEqBase.ODE_DEFAULT_ISOUTOFDOMAIN), unstable_check::typeof(DiffEqBase.ODE_DEFAULT_UNSTABLE_CHECK), verbose::Bool, timeseries_errors::Bool, dense_errors::Bool, advance_to_tstop::Bool, stop_at_next_tstop::Bool, initialize_save::Bool, progress::Bool, 
progress_steps::Int64, progress_name::String, progress_message::typeof(DiffEqBase.ODE_DEFAULT_PROG_MESSAGE), userdata::Nothing, allow_extrapolation::Bool, initialize_integrator::Bool, alias_u0::Bool, alias_du0::Bool, initializealg::OrdinaryDiffEq.DefaultInit, kwargs::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}) at /home/jboik/.julia/packages/OrdinaryDiffEq/HO8vN/src/solve.jl:428
 [13] #__solve#391 at /home/jboik/.julia/packages/OrdinaryDiffEq/HO8vN/src/solve.jl:4 [inlined]
 [14] solve_call(::DiffEqBase.ODEProblem{Array{Float32,1},Tuple{Float32,Float32},true,Array{Float32,1},DiffEqBase.ODEFunction{true,DiffEqSensitivity.ODEInterpolatingAdjointSensitivityFunction{DiffEqSensitivity.AdjointDiffCache{Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Base.OneTo{Int64},UnitRange{Int64},LinearAlgebra.UniformScaling{Bool}},DiffEqSensitivity.InterpolatingAdjoint{0,true,Val{:central},Bool,Bool},Array{Float32,1},DiffEqBase.ODESolution{Float32,2,Array{Array{Float32,1},1},Nothing,Nothing,Array{Float32,1},Array{Array{Array{Float32,1},1},1},DiffEqBase.ODEProblem{Array{Float32,1},Tuple{Float32,Float32},false,Array{Float32,1},DiffEqBase.ODEFunction{false,typeof(Main.TestDiffeq3b.neural_ode_f),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},DiffEqBase.StandardODEProblem},OrdinaryDiffEq.Tsit5,OrdinaryDiffEq.InterpolationData{DiffEqBase.ODEFunction{false,typeof(Main.TestDiffeq3b.neural_ode_f),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Array{Array{Float32,1},1},Array{Float32,1},Array{Array{Array{Float32,1},1},1},OrdinaryDiffEq.Tsit5ConstantCache{Float32,Float32}},DiffEqBase.DEStats},Nothing,DiffEqBase.ODEProblem{Array{Float32,1},Tuple{Float32,Float32},false,Array{Float32,1},DiffEqBase.ODEFunction{false,typeof(Main.TestDiffeq3b.neural_ode_f),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},DiffEqBase.StandardODEProblem},DiffEqBase.ODEFunction{false,typeof(Main.TestDiffeq3b.neural_ode_f),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing}},
LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Symbol,DiffEqBase.CallbackSet{Tuple{},Tuple{DiffEqBase.DiscreteCallback{DiffEqCallbacks.var"#33#38"{Base.RefValue{Union{Nothing, Float32}}},DiffEqCallbacks.var"#34#39"{DiffEqSensitivity.var"#94#96"{Base.RefValue{Int64},Array{Float32,1}},DiffEqSensitivity.var"#95#97"{DiffEqSensitivity.var"#df#134"{Array{Float32,2},Array{Float32,1},Colon},DiffEqSensitivity.InterpolatingAdjoint{0,true,Val{:central},Bool,Bool},Base.OneTo{Int64},UnitRange{Int64},LinearAlgebra.UniformScaling{Bool},Bool,Nothing,Nothing,Nothing,Nothing,Bool,Array{Float32,1},Array{Float32,1},Array{Float32,1},Base.RefValue{Int64},Int64,LinearAlgebra.UniformScaling{Bool}},Base.RefValue{Union{Nothing, Float32}}},DiffEqCallbacks.var"#35#40"{Bool,DiffEqCallbacks.var"#37#42"{Bool},DiffEqSensitivity.var"#94#96"{Base.RefValue{Int64},Array{Float32,1}},Base.RefValue{Union{Nothing, Float32}},DiffEqCallbacks.var"#34#39"{DiffEqSensitivity.var"#94#96"{Base.RefValue{Int64},Array{Float32,1}},DiffEqSensitivity.var"#95#97"{DiffEqSensitivity.var"#df#134"{Array{Float32,2},Array{Float32,1},Colon},DiffEqSensitivity.InterpolatingAdjoint{0,true,Val{:central},Bool,Bool},Base.OneTo{Int64},UnitRange{Int64},LinearAlgebra.UniformScaling{Bool},Bool,Nothing,Nothing,Nothing,Nothing,Bool,Array{Float32,1},Array{Float32,1},Array{Float32,1},Base.RefValue{Int64},Int64,LinearAlgebra.UniformScaling{Bool}},Base.RefValue{Union{Nothing, Float32}}}}}}},Tuple{Symbol},NamedTuple{(:callback,),Tuple{DiffEqBase.CallbackSet{Tuple{},Tuple{DiffEqBase.DiscreteCallback{DiffEqCallbacks.var"#33#38"{Base.RefValue{Union{Nothing, 
Float32}}},DiffEqCallbacks.var"#34#39"{DiffEqSensitivity.var"#94#96"{Base.RefValue{Int64},Array{Float32,1}},DiffEqSensitivity.var"#95#97"{DiffEqSensitivity.var"#df#134"{Array{Float32,2},Array{Float32,1},Colon},DiffEqSensitivity.InterpolatingAdjoint{0,true,Val{:central},Bool,Bool},Base.OneTo{Int64},UnitRange{Int64},LinearAlgebra.UniformScaling{Bool},Bool,Nothing,Nothing,Nothing,Nothing,Bool,Array{Float32,1},Array{Float32,1},Array{Float32,1},Base.RefValue{Int64},Int64,LinearAlgebra.UniformScaling{Bool}},Base.RefValue{Union{Nothing, Float32}}},DiffEqCallbacks.var"#35#40"{Bool,DiffEqCallbacks.var"#37#42"{Bool},DiffEqSensitivity.var"#94#96"{Base.RefValue{Int64},Array{Float32,1}},Base.RefValue{Union{Nothing, Float32}},DiffEqCallbacks.var"#34#39"{DiffEqSensitivity.var"#94#96"{Base.RefValue{Int64},Array{Float32,1}},DiffEqSensitivity.var"#95#97"{DiffEqSensitivity.var"#df#134"{Array{Float32,2},Array{Float32,1},Colon},DiffEqSensitivity.InterpolatingAdjoint{0,true,Val{:central},Bool,Bool},Base.OneTo{Int64},UnitRange{Int64},LinearAlgebra.UniformScaling{Bool},Bool,Nothing,Nothing,Nothing,Nothing,Bool,Array{Float32,1},Array{Float32,1},Array{Float32,1},Base.RefValue{Int64},Int64,LinearAlgebra.UniformScaling{Bool}},Base.RefValue{Union{Nothing, Float32}}}}}}}}}},DiffEqBase.StandardODEProblem}, ::OrdinaryDiffEq.Tsit5; merge_callbacks::Bool, kwargs::Base.Iterators.Pairs{Symbol,Any,NTuple{6,Symbol},NamedTuple{(:save_everystep, :save_start, :saveat, :tstops, :abstol, :reltol),Tuple{Bool,Bool,Array{Float32,1},Array{Float32,1},Float64,Float64}}}) at /home/jboik/.julia/packages/DiffEqBase/gLFRA/src/solve.jl:65
 [15] #solve_up#458 at /home/jboik/.julia/packages/DiffEqBase/gLFRA/src/solve.jl:86 [inlined]
 [16] #solve#457 at /home/jboik/.julia/packages/DiffEqBase/gLFRA/src/solve.jl:74 [inlined]
 [17] _adjoint_sensitivities(::DiffEqBase.ODESolution{Float32,2,Array{Array{Float32,1},1},Nothing,Nothing,Array{Float32,1},Array{Array{Array{Float32,1},1},1},DiffEqBase.ODEProblem{Array{Float32,1},Tuple{Float32,Float32},false,Array{Float32,1},DiffEqBase.ODEFunction{false,typeof(Main.TestDiffeq3b.neural_ode_f),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},DiffEqBase.StandardODEProblem},OrdinaryDiffEq.Tsit5,OrdinaryDiffEq.InterpolationData{DiffEqBase.ODEFunction{false,typeof(Main.TestDiffeq3b.neural_ode_f),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Array{Array{Float32,1},1},Array{Float32,1},Array{Array{Array{Float32,1},1},1},OrdinaryDiffEq.Tsit5ConstantCache{Float32,Float32}},DiffEqBase.DEStats}, ::DiffEqSensitivity.InterpolatingAdjoint{0,true,Val{:central},Bool,Bool}, ::OrdinaryDiffEq.Tsit5, ::DiffEqSensitivity.var"#df#134"{Array{Float32,2},Array{Float32,1},Colon}, ::Array{Float32,1}, ::Nothing; abstol::Float64, reltol::Float64, checkpoints::Array{Float32,1}, kwargs::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}) at /home/jboik/.julia/packages/DiffEqSensitivity/WiCRA/src/local_sensitivity/sensitivity_interface.jl:22
 [18] _adjoint_sensitivities(::DiffEqBase.ODESolution{Float32,2,Array{Array{Float32,1},1},Nothing,Nothing,Array{Float32,1},Array{Array{Array{Float32,1},1},1},DiffEqBase.ODEProblem{Array{Float32,1},Tuple{Float32,Float32},false,Array{Float32,1},DiffEqBase.ODEFunction{false,typeof(Main.TestDiffeq3b.neural_ode_f),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},DiffEqBase.StandardODEProblem},OrdinaryDiffEq.Tsit5,OrdinaryDiffEq.InterpolationData{DiffEqBase.ODEFunction{false,typeof(Main.TestDiffeq3b.neural_ode_f),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Array{Array{Float32,1},1},Array{Float32,1},Array{Array{Array{Float32,1},1},1},OrdinaryDiffEq.Tsit5ConstantCache{Float32,Float32}},DiffEqBase.DEStats}, ::DiffEqSensitivity.InterpolatingAdjoint{0,true,Val{:central},Bool,Bool}, ::OrdinaryDiffEq.Tsit5, ::Function, ::Array{Float32,1}, ::Nothing) at /home/jboik/.julia/packages/DiffEqSensitivity/WiCRA/src/local_sensitivity/sensitivity_interface.jl:13 (repeats 2 times)
 [19] adjoint_sensitivities(::DiffEqBase.ODESolution{Float32,2,Array{Array{Float32,1},1},Nothing,Nothing,Array{Float32,1},Array{Array{Array{Float32,1},1},1},DiffEqBase.ODEProblem{Array{Float32,1},Tuple{Float32,Float32},false,Array{Float32,1},DiffEqBase.ODEFunction{false,typeof(Main.TestDiffeq3b.neural_ode_f),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},DiffEqBase.StandardODEProblem},OrdinaryDiffEq.Tsit5,OrdinaryDiffEq.InterpolationData{DiffEqBase.ODEFunction{false,typeof(Main.TestDiffeq3b.neural_ode_f),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Array{Array{Float32,1},1},Array{Float32,1},Array{Array{Array{Float32,1},1},1},OrdinaryDiffEq.Tsit5ConstantCache{Float32,Float32}},DiffEqBase.DEStats}, ::OrdinaryDiffEq.Tsit5, ::Vararg{Any,N} where N; sensealg::DiffEqSensitivity.InterpolatingAdjoint{0,true,Val{:central},Bool,Bool}, kwargs::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}) at /home/jboik/.julia/packages/DiffEqSensitivity/WiCRA/src/local_sensitivity/sensitivity_interface.jl:6
 [20] (::DiffEqSensitivity.var"#adjoint_sensitivity_backpass#133"{OrdinaryDiffEq.Tsit5,DiffEqSensitivity.InterpolatingAdjoint{0,true,Val{:central},Bool,Bool},Array{Float32,1},Array{Float32,1},Tuple{},Colon})(::Array{Float32,2}) at /home/jboik/.julia/packages/DiffEqSensitivity/WiCRA/src/local_sensitivity/concrete_solve.jl:144
 [21] #673#back at /home/jboik/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:55 [inlined]
 [22] #145 at /home/jboik/.julia/packages/Zygote/chgvX/src/lib/lib.jl:175 [inlined]
 [23] (::Zygote.var"#1681#back#147"{Zygote.var"#145#146"{DiffEqBase.var"#673#back#471"{DiffEqSensitivity.var"#adjoint_sensitivity_backpass#133"{OrdinaryDiffEq.Tsit5,DiffEqSensitivity.InterpolatingAdjoint{0,true,Val{:central},Bool,Bool},Array{Float32,1},Array{Float32,1},Tuple{},Colon}},Tuple{NTuple{6,Nothing},Tuple{Nothing}}}})(::Array{Float32,2}) at /home/jboik/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:49
 [24] #solve#457 at /home/jboik/.julia/packages/DiffEqBase/gLFRA/src/solve.jl:74 [inlined]
 [25] (::typeof(∂(#solve#457)))(::Array{Float32,2}) at /home/jboik/.julia/packages/Zygote/chgvX/src/compiler/interface2.jl:0
 [26] (::Zygote.var"#145#146"{typeof(∂(#solve#457)),Tuple{NTuple{6,Nothing},Tuple{Nothing}}})(::Array{Float32,2}) at /home/jboik/.julia/packages/Zygote/chgvX/src/lib/lib.jl:175
 [27] (::Zygote.var"#1681#back#147"{Zygote.var"#145#146"{typeof(∂(#solve#457)),Tuple{NTuple{6,Nothing},Tuple{Nothing}}}})(::Array{Float32,2}) at /home/jboik/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:49
 [28] (::typeof(∂(solve##kw)))(::Array{Float32,2}) at /home/jboik/.julia/packages/Zygote/chgvX/src/compiler/interface2.jl:0
 [29] predict_neuralode at /home/jboik/Devel/Ideai_Ju/Ideai/examples/Load_01/load_fake/TestDiffEq3b.jl:70 [inlined]
 [30] (::typeof(∂(predict_neuralode)))(::Array{Float32,2}) at /home/jboik/.julia/packages/Zygote/chgvX/src/compiler/interface2.jl:0
 [31] loss_neuralode at /home/jboik/Devel/Ideai_Ju/Ideai/examples/Load_01/load_fake/TestDiffEq3b.jl:76 [inlined]
 [32] #145 at /home/jboik/.julia/packages/Zygote/chgvX/src/lib/lib.jl:175 [inlined]
 [33] (::Zygote.var"#1681#back#147"{Zygote.var"#145#146"{typeof(∂(loss_neuralode)),Tuple{Tuple{Nothing},Tuple{}}}})(::Tuple{Float32,Nothing}) at /home/jboik/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:49
 [34] #74 at /home/jboik/.julia/packages/DiffEqFlux/8UHw5/src/train.jl:120 [inlined]
 [35] (::typeof(∂(λ)))(::Float32) at /home/jboik/.julia/packages/Zygote/chgvX/src/compiler/interface2.jl:0
 [36] (::Zygote.var"#54#55"{Zygote.Params,Zygote.Context,typeof(∂(λ))})(::Float32) at /home/jboik/.julia/packages/Zygote/chgvX/src/compiler/interface.jl:177
 [37] gradient(::Function, ::Zygote.Params) at /home/jboik/.julia/packages/Zygote/chgvX/src/compiler/interface.jl:54
 [38] macro expansion at /home/jboik/.julia/packages/DiffEqFlux/8UHw5/src/train.jl:119 [inlined]
 [39] macro expansion at /home/jboik/.julia/packages/ProgressLogging/BBN0b/src/ProgressLogging.jl:328 [inlined]
 [40] (::DiffEqFlux.var"#73#78"{Main.TestDiffeq3b.var"#3#5",Int64,Bool,Bool,typeof(Main.TestDiffeq3b.loss_neuralode),Array{Float32,1},Zygote.Params})() at /home/jboik/.julia/packages/DiffEqFlux/8UHw5/src/train.jl:64
 [41] with_logstate(::Function, ::Any) at ./logging.jl:408
 [42] with_logger at ./logging.jl:514 [inlined]
 [43] maybe_with_logger(::DiffEqFlux.var"#73#78"{Main.TestDiffeq3b.var"#3#5",Int64,Bool,Bool,typeof(Main.TestDiffeq3b.loss_neuralode),Array{Float32,1},Zygote.Params}, ::LoggingExtras.TeeLogger{Tuple{LoggingExtras.EarlyFilteredLogger{TerminalLoggers.TerminalLogger,DiffEqFlux.var"#68#70"},LoggingExtras.EarlyFilteredLogger{Logging.ConsoleLogger,DiffEqFlux.var"#69#71"}}}) at /home/jboik/.julia/packages/DiffEqFlux/8UHw5/src/train.jl:39
 [44] sciml_train(::Function, ::Array{Float32,1}, ::Flux.Optimise.ADAM, ::Base.Iterators.Cycle{Tuple{DiffEqFlux.NullData}}; cb::Function, maxiters::Int64, progress::Bool, save_best::Bool) at /home/jboik/.julia/packages/DiffEqFlux/8UHw5/src/train.jl:63
 [45] top-level scope at /home/jboik/Devel/Ideai_Ju/Ideai/examples/Load_01/load_fake/TestDiffEq3b.jl:93
 [46] include(::String) at ./client.jl:457
 [47] top-level scope at REPL[2]:1
in expression starting at /home/jboik/Devel/Ideai_Ju/Ideai/examples/Load_01/load_fake/TestDiffEq3b.jl:93
ChrisRackauckas commented 4 years ago

Doesn't a GRU have state though, so the model wouldn't be well-defined?

ChrisRackauckas commented 4 years ago

I think @avik-pal and @DhairyaLGandhi have mentioned something about destructure giving different outputs when a layer has state, which is a bit weird. Could one of you give some input on that? I think that would be something to fix on the Flux side: even if it were a breaking change, making the outputs of that function uniform and documented would fix issues like this.

John-Boik commented 4 years ago

If it helps, learning-long-term-irregular-ts shows (starting on line 566) code for GRU-ODE written in Python.

ChrisRackauckas commented 4 years ago

Note that this method is only going to be compatible with adaptive=false, because otherwise the state makes the ODE undefined. I think all you need is to turn off adaptivity and use whatever that different destructure is.
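A toy illustration in plain Julia of why a stateful layer makes the right-hand side ill-defined. No Flux is involved; `StatefulRHS` and `euler` are hypothetical stand-ins for a recurrent cell and a fixed-step solver. The output depends on how many times the function has been called, not only on `u`. With a fixed step the call pattern is fixed, so results are at least reproducible; an adaptive solver also evaluates the RHS for error estimates and rejected steps, so the number of state updates would depend on the tolerances.

```julia
# Hypothetical stateful right-hand side: each call advances a hidden
# counter, mimicking a recurrent cell that updates its state when evaluated.
mutable struct StatefulRHS
    h::Float64   # hidden state, mutated on every call
end

function (f::StatefulRHS)(u)
    f.h += 1.0        # side effect: the state changes on each evaluation
    return -u + f.h   # so the output depends on the call count, not just u
end

f = StatefulRHS(0.0)
a = f(1.0)   # 0.0
b = f(1.0)   # 1.0: same input, different output

# Fixed-step explicit Euler evaluates the RHS exactly once per step, so
# the sequence of state updates is fixed by dt and the number of steps.
function euler(rhs, u0, dt, nsteps)
    u = u0
    for _ in 1:nsteps
        u += dt * rhs(u)
    end
    return u
end

u1 = euler(StatefulRHS(0.0), 1.0, 0.1, 10)
u2 = euler(StatefulRHS(0.0), 1.0, 0.1, 10)   # fresh state: identical result
```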

John-Boik commented 4 years ago

@avik-pal and @DhairyaLGandhi, note that the solve function runs properly and produces anticipated output. The DimensionMismatch error occurs later when gradients are taken. Also, the same error occurs when using adaptive=false: OrdinaryDiffEq.solve(tmp_prob, OrdinaryDiffEq.Tsit5(), saveat=tsteps, adaptive=false, dt=.5), instead of the solve() in the code above.

avik-pal commented 4 years ago

The destructure issue Chris mentioned above should not lead to dimension mismatch error. It just makes the GRU work without any recurrence, as the state is overwritten every time we do re(p). (@DhairyaLGandhi do you know how to fix this?)
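A plain-Julia sketch of the effect described above, assuming the state is stored in the flat parameter vector. `ToyCell` and `toy_destructure` are hypothetical and are not Flux's implementation: every `re(p)` builds a fresh cell from the same stored state, so calling the rebuilt cell never carries state over to the next `re(p)`.

```julia
# Hypothetical minimal recurrent cell: weights W and a hidden state h.
mutable struct ToyCell
    W::Vector{Float64}
    h::Vector{Float64}
end

# Calling the cell updates its hidden state in place and returns a copy of it.
function (c::ToyCell)(x)
    c.h .= tanh.(c.W .* x .+ c.h)
    return copy(c.h)
end

# Flatten every array (including the state) into one vector, and return a
# closure that rebuilds a brand-new cell from such a vector.
function toy_destructure(c::ToyCell)
    n = length(c.W)
    p = vcat(c.W, c.h)
    re(q) = ToyCell(q[1:n], q[n+1:end])   # slicing copies, so p is never mutated
    return p, re
end

cell = ToyCell([0.5, -0.5], zeros(2))
p, re = toy_destructure(cell)

y1 = re(p)(1.0)   # starts from the stored state
y2 = re(p)(1.0)   # rebuilt from the same p: the state starts over

# Keeping a single rebuilt object does retain its state between calls.
m = re(p)
z1 = m(1.0)
z2 = m(1.0)
```

The comparison `y1 == y2` holds while `z1 != z2`, which is the "GRU works without any recurrence" behavior in miniature.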

The exact source of the error you encounter seems to be the sensitivity algorithm. A quick fix would be:

import DiffEqSensitivity
res = Array(OrdinaryDiffEq.solve(tmp_prob, OrdinaryDiffEq.Tsit5(), saveat=tsteps, sensealg=DiffEqSensitivity.InterpolatingAdjoint(autojacvec=false)))
DhairyaLGandhi commented 4 years ago

We don't close over a number of arguments in destructure, which may be necessary for our restructure case as well. Adding those back to our cache, which can be passed around to restructure, could do it.

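import Zygote          # for Zygote.Buffer
using Functors: fmap   # tree walk over the model's fields

function destructure(m; cache = IdDict())
  xs = Zygote.Buffer([])
  fmap(m) do x
    if x isa AbstractArray
      push!(xs, x)     # arrays (weights, biases, states) become parameters
    else
      cache[x] = x     # non-array leaves are remembered for restructuring
    end
    return x
  end
  return vcat(vec.(copy(xs))...), p -> _restructure(m, p, cache = cache)
end

function _restructure(m, xs; cache = IdDict())
  i = 0
  fmap(m) do x
    x isa AbstractArray || return cache[x]   # restore cached non-array leaves
    x = reshape(xs[i.+(1:length(x))], size(x))
    i += length(x)
    return x
  end
end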

This is currently untested. @avik-pal, would something like this solve the specific issue you're talking about?
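Since the patch depends on Zygote and Functors, here is a dependency-free sketch of the same idea, assuming NamedTuples stand in for Flux layers. `walk`, `flat_destructure`, and `flat_restructure` are hypothetical names (with `walk` playing the role of `fmap`): arrays are flattened into one vector, and every other leaf (activation functions, flags) is kept in an `IdDict` cache and handed back unchanged on restructuring.

```julia
# Hypothetical stand-in for Functors.fmap: recurse into Tuples and
# NamedTuples and apply `f` to every other leaf.
walk(f, x::Tuple) = map(y -> walk(f, y), x)
walk(f, x::NamedTuple) = map(y -> walk(f, y), x)
walk(f, x) = f(x)

function flat_destructure(m)
    xs = AbstractArray[]
    cache = IdDict{Any,Any}()
    walk(m) do x
        if x isa AbstractArray
            push!(xs, x)     # arrays become entries of the flat vector
        else
            cache[x] = x     # non-array leaves are remembered as-is
        end
        return x
    end
    p = reduce(vcat, vec.(xs))
    return p, q -> flat_restructure(m, q, cache)
end

function flat_restructure(m, q, cache)
    i = 0
    walk(m) do x
        x isa AbstractArray || return cache[x]   # non-arrays come back from the cache
        y = reshape(q[i .+ (1:length(x))], size(x))
        i += length(x)
        return y
    end
end

model = (layer1 = (W = [1.0 2.0; 3.0 4.0], b = [0.0, 0.0], act = tanh),
         cell   = (state = [0.5, 0.5], is_dhdt = true))
p, re = flat_destructure(model)
m2 = re(2 .* p)
```

Here `m2.layer1.act` is still `tanh` and `m2.cell.is_dhdt` is still `true`, while every array has been rebuilt from the new vector.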

John-Boik commented 4 years ago

Thanks. I can verify that the following works with Flux.GRU. I used autojacvec=true rather than false, and it seems to run a bit faster that way.

module TestDiffeq3bb

using Revise
using Infiltrator
using Formatting

import DiffEqFlux
import OrdinaryDiffEq 
import DiffEqSensitivity
import Flux
import Optim
import Plots
import Zygote
import Functors

u0 = Float32[2.0; 0.0]    
datasize = 30
tspan = (0.0f0, 1.5f0)
tsteps = range(tspan[1], tspan[2], length = datasize)
iter = 0

function trueODEfunc(du, u, p, t)    
    true_A = Float32[-0.1 2.0; -2.0 -0.1]
    du .= ((u.^3)'true_A)' 
end

prob_trueode = OrdinaryDiffEq.ODEProblem(trueODEfunc, u0, tspan)
ode_data = Array(OrdinaryDiffEq.solve(prob_trueode, OrdinaryDiffEq.Tsit5(), saveat = tsteps)) 

dudt2 = Flux.Chain(
    x -> x.^3,
    Flux.Dense(2, 20, tanh),
    Flux.GRU(20, 20),
    Flux.Dense(20, 2),
    )

function destructure(m; cache = IdDict())
  xs = Zygote.Buffer([])
  Functors.fmap(m) do x
    if x isa AbstractArray
      push!(xs, x)
    else
      cache[x] = x
    end
    return x
  end
  return vcat(vec.(copy(xs))...), p -> _restructure(m, p, cache = cache)
end

function _restructure(m, xs; cache = IdDict())
  i = 0
  Functors.fmap(m) do x
    x isa AbstractArray || return cache[x]
    x = reshape(xs[i .+ (1:length(x))], size(x))
    i += length(x)
    return x
  end
end

p, re = destructure(dudt2)

function neural_ode_f(u, p, t)
    return re(p)(u)
end

prob = OrdinaryDiffEq.ODEProblem(neural_ode_f, u0, tspan, p)

function predict_neuralode(p)
    tmp_prob = OrdinaryDiffEq.remake(prob,p=p)
    res = Array(OrdinaryDiffEq.solve(tmp_prob, OrdinaryDiffEq.Tsit5(), saveat=tsteps, sensealg=DiffEqSensitivity.InterpolatingAdjoint(autojacvec=true)))
    return res
end

function loss_neuralode(p)
    pred = predict_neuralode(p)
    loss = sum(abs2, ode_data .- pred)
    return loss, pred
end

callback = function (p, l, pred; doplot = true)
    global iter
    iter += 1
    @show iter, l
    # plot current prediction against data
    plt = Plots.scatter(tsteps, ode_data[1,:], label = "data", title=string(iter))
    Plots.scatter!(plt, tsteps, pred[1,:], label = "prediction")
    if doplot
        display(Plots.plot(plt))
    end
    return false
end

result_neuralode = DiffEqFlux.sciml_train(
    loss_neuralode, 
    p,
    Flux.ADAM(0.05), 
    cb = callback,
    maxiters = 60
    )

end    # ------------------------------- module -----------------------------------
ChrisRackauckas commented 4 years ago

The fixed restructure/destructure works:

import DiffEqFlux
import OrdinaryDiffEq
import Flux
import Optim
import Plots
import Zygote

u0 = Float32[2.0; 0.0]
datasize = 30
tspan = (0.0f0, 1.5f0)
tsteps = range(tspan[1], tspan[2], length = datasize)

function trueODEfunc(du, u, p, t)
    true_A = [-0.1 2.0; -2.0 -0.1]
    du .= ((u.^3)'true_A)'
end

prob_trueode = OrdinaryDiffEq.ODEProblem(trueODEfunc, u0, tspan)
ode_data = Array(OrdinaryDiffEq.solve(prob_trueode, OrdinaryDiffEq.Tsit5(), saveat = tsteps))

dudt2 = Flux.Chain(
    x -> x.^3,
    Flux.Dense(2, 50, tanh),
    #Flux.Dense(50, 2)
    Flux.GRU(50, 2)
    )

function destructure(m; cache = IdDict())
  xs = Zygote.Buffer([])
  Flux.fmap(m) do x
    if x isa AbstractArray
      push!(xs, x)
    else
      cache[x] = x
    end
    return x
  end
  return vcat(vec.(copy(xs))...), p -> _restructure(m, p, cache = cache)
end

function _restructure(m, xs; cache = IdDict())
  i = 0
  Flux.fmap(m) do x
    x isa AbstractArray || return cache[x]
    x = reshape(xs[i.+(1:length(x))], size(x))
    i += length(x)
    return x
  end
end

p, re = destructure(dudt2)
neural_ode_f(u, p, t) = re(p)(u)

prob = OrdinaryDiffEq.ODEProblem(neural_ode_f, u0, tspan, p)

function predict_neuralode(p)
    tmp_prob = OrdinaryDiffEq.remake(prob,p=p)
    res = Array(OrdinaryDiffEq.solve(tmp_prob, OrdinaryDiffEq.Tsit5(), saveat=tsteps, dt=0.01, adaptive=false))
    return res
end

function loss_neuralode(p)
    pred = predict_neuralode(p)  # (2,30)
    loss = sum(abs2, ode_data .- pred)  # scalar
    return loss, pred
end

callback = function (p, l, pred; doplot = true)
    display(l)
    # plot current prediction against data
    plt = Plots.scatter(tsteps, ode_data[1,:], label = "data")
    Plots.scatter!(plt, tsteps, pred[1,:], label = "prediction")
    if doplot
        display(Plots.plot(plt))
    end
    return false
end

result_neuralode = DiffEqFlux.sciml_train(
    loss_neuralode,
    p,
    Flux.ADAM(0.05),
    cb = callback,
    maxiters = 3000
    )

result_neuralode2 = DiffEqFlux.sciml_train(
    loss_neuralode,
    result_neuralode.minimizer,
    Flux.ADAM(0.05),
    cb = callback,
    maxiters = 1000,
    )

The method isn't very good, but it does what you asked for.

John-Boik commented 4 years ago

Excellent @ChrisRackauckas. I see that sensealg=DiffEqSensitivity.InterpolatingAdjoint() is not needed, and dt=0.01, adaptive=false can be used.

The method is a means to an end (eventually, GRU-ODE), but it does work reasonably well as is if Flux.GRU(50, 2) is replaced by Flux.GRU(50, 50), Flux.Dense(50, 2). At 200 iterations of just SGD, the error is about 0.18. If I switch in my custom GRU (below, as per the GRU-ODE paper), and reduce the hidden layer size to 20 from 50, the error is about 0.02 at 200 iterations. Both are less than the error of about 0.48 achieved with a chain of Flux.Dense(2, 50, tanh), Flux.Dense(50, 2).
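For reference, the gate equations as we read them from the GRU-ODE formulation, written in the notation of the custom cell below: the `update_*` fields correspond to the update gate $z$, `reset_*` to the reset gate $r$, and `out_*` to the candidate $g$, with `fx` in place of $\tanh$. This is a summary under that reading; please check it against the paper.

```latex
\begin{aligned}
z &= \sigma\left(W_z x + U_z h + b_z\right) \\
r &= \sigma\left(W_r x + U_r h + b_r\right) \\
g &= \tanh\left(W_g x + U_g (r \odot h) + b_g\right) \\
h_{\text{new}} &= (1 - z) \odot g + z \odot h \qquad \text{(discrete GRU, \texttt{is\_dhdt = false})} \\
\frac{dh}{dt} &= (1 - z) \odot (g - h) \qquad \text{(GRU-ODE, \texttt{is\_dhdt = true})}
\end{aligned}
```

The continuous form follows from the discrete one: $h_{\text{new}} - h = (1 - z) \odot g + z \odot h - h = (1 - z) \odot (g - h)$.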

The custom GRU is as follows. In this problem, I use Flux2.GRU2(20, true, x -> x ) when defining the chain.

module Flux2

import Flux
using Infiltrator

mutable struct GRUCell2{W,U,B,L,TF,F}
    update_W::W
    update_U::U
    update_b::B

    reset_W::W
    reset_U::U
    reset_b::B

    out_W::W
    out_U::U
    out_b::B

    H::L
    is_dhdt::TF
    fx::F
end

GRUCell2(L, is_output_dhdt, fx; init = Flux.glorot_uniform) =
    GRUCell2(
        init(L, L), 
        init(L, L),
        init(L,1), 

        init(L, L), 
        init(L, L),
        init(L,1), 

        init(L, L), 
        init(L, L),
        init(L,1), 

        zeros(Float32, (L,1)),
        is_output_dhdt,
        fx
        )

function (m::GRUCell2)(H, X)
    # update gate z (note: all three gates read the stored field m.H, not the argument H)
    update_gate = Flux.sigmoid.(
            (m.update_W * X)
            .+ (m.update_U * m.H)
            .+ m.update_b)

    # reset gate r
    reset_gate = Flux.sigmoid.(
            (m.reset_W * X)
            .+ (m.reset_U * m.H)
            .+ m.reset_b)

    # candidate state g, using the activation fx in place of tanh
    output_gate = m.fx.(
            (m.out_W * X)
            .+ (m.out_U * (reset_gate .* m.H))
            .+ m.out_b)

    if m.is_dhdt == true
        # output is dh/dt (GRU-ODE form)
        output = (Float32(1) .- update_gate) .* (output_gate .- H)
    else
        # standard GRU output
        output = ((Float32(1) .- update_gate) .* output_gate) .+ (update_gate .* H)
    end

    # Recur stores the first value as the new state and returns the second
    H = output
    return H, H
end

Flux.hidden(m::GRUCell2) = m.H
Flux.@functor GRUCell2

Base.show(io::IO, l::GRUCell2) =
    print(io, "GRUCell2(", size(l.update_W, 2), ", ", size(l.update_W, 1), ")")

GRU2(a...; ka...) = Flux.Recur(GRUCell2(a...; ka...))

end  # -----------------------------  module
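For comparison, a minimal sketch of a GRU-ODE right-hand side that takes the hidden state `h` as an explicit argument instead of a mutable field, so it is a pure function of `(h, x, ps)`. Note that in `GRUCell2` above the gates read the stored field `m.H` rather than the argument `H`, so inside an ODE they do not respond to the current state; if the GRU-ODE formulation is the intent, the gates would presumably use the current state, as this sketch does. `GRUODEParams`, `gru_ode_rhs`, `integrate_euler`, and the deterministic toy weights are hypothetical; a real model would use Flux parameters and an ODE solver.

```julia
sigmoid(x) = 1 / (1 + exp(-x))

# Parameters of a hypothetical GRU-ODE cell (hidden size L, input size D).
struct GRUODEParams
    Wz::Matrix{Float64}
    Uz::Matrix{Float64}
    bz::Vector{Float64}
    Wr::Matrix{Float64}
    Ur::Matrix{Float64}
    br::Vector{Float64}
    Wh::Matrix{Float64}
    Uh::Matrix{Float64}
    bh::Vector{Float64}
end

# dh/dt = (1 - z) .* (g - h), with every gate computed from the current h.
function gru_ode_rhs(h, x, ps::GRUODEParams)
    z = sigmoid.(ps.Wz * x .+ ps.Uz * h .+ ps.bz)       # update gate
    r = sigmoid.(ps.Wr * x .+ ps.Ur * h .+ ps.br)       # reset gate
    g = tanh.(ps.Wh * x .+ ps.Uh * (r .* h) .+ ps.bh)   # candidate state
    return (1 .- z) .* (g .- h)
end

# Fixed-step explicit Euler; h0 is not mutated.
function integrate_euler(h0, x, ps, dt, nsteps)
    h = copy(h0)
    for _ in 1:nsteps
        h = h .+ dt .* gru_ode_rhs(h, x, ps)
    end
    return h
end

L, D = 3, 2
mk(m, n, s) = [s * sin(i + 2j) for i in 1:m, j in 1:n]   # deterministic toy weights
ps = GRUODEParams(mk(L, D, 0.3), mk(L, L, 0.2), zeros(L),
                  mk(L, D, -0.3), mk(L, L, 0.1), zeros(L),
                  mk(L, D, 0.5), mk(L, L, -0.2), zeros(L))
h0 = zeros(L)
x = [1.0, -1.0]
h1 = integrate_euler(h0, x, ps, 0.01, 100)
```

Each Euler step is a convex combination of `h` and `g` (since `dt * (1 - z)` lies in (0, 1)), so starting from zero the state stays inside (-1, 1), and repeating the call gives the same result because nothing is hidden in mutable state.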
ChrisRackauckas commented 4 years ago

Cool, yeah. The other thing to try is sensealg=ReverseDiffAdjoint(). Direct reverse-mode AD through the solver might be better with a fixed time step, since it does not carry the adjoint error that continuous adjoints can have, which is more of an issue when the reverse solve is not adaptive.
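To illustrate the discretize-then-optimize idea behind direct reverse-mode AD, here is a hand-written sketch in plain Julia (no DiffEqSensitivity; `euler_traj`, `loss`, and `loss_grad` are hypothetical). Backpropagating through each step of a fixed-step Euler solve of du/dt = -a*u^3 gives the gradient of the discrete loss itself, so it agrees with a finite-difference check up to roundoff. A continuous adjoint instead solves a separate backward ODE, and discretizing that backward solve can add its own error.

```julia
# Fixed-step explicit Euler for du/dt = -a*u^3, storing the whole trajectory.
function euler_traj(a, u0, dt, N)
    us = Vector{Float64}(undef, N + 1)
    us[1] = u0
    for k in 1:N
        us[k+1] = us[k] + dt * (-a * us[k]^3)
    end
    return us
end

loss(a; u0 = 2.0, dt = 0.01, N = 150, target = 0.5) =
    (euler_traj(a, u0, dt, N)[end] - target)^2

# Reverse pass (discrete adjoint): sweep the stored trajectory backwards,
# applying the chain rule to each Euler step.
function loss_grad(a; u0 = 2.0, dt = 0.01, N = 150, target = 0.5)
    us = euler_traj(a, u0, dt, N)
    λ = 2 * (us[end] - target)         # dL/du_N
    g = 0.0
    for k in N:-1:1
        uk = us[k]
        g += λ * (-dt * uk^3)          # contribution of ∂u_{k+1}/∂a
        λ *= 1 - 3 * dt * a * uk^2     # propagate through ∂u_{k+1}/∂u_k
    end
    return g
end

a = 1.0
ε = 1e-6
fd = (loss(a + ε) - loss(a - ε)) / (2ε)   # central finite difference
ad = loss_grad(a)                          # reverse-mode gradient
```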

ChrisRackauckas commented 4 years ago

It would be good to turn this into a tutorial when all is said and done. @DhairyaLGandhi could you add that restructure/destructure patch to Flux and then tag a release? @John-Boik would you be willing to contribute a tutorial?

ChrisRackauckas commented 4 years ago

Or @mkg33 might be able to help out here.

John-Boik commented 4 years ago

Sure, I would be happy to help if I can.

mkg33 commented 4 years ago

Of course, I'll add it to my tasks.

sungjuGit commented 3 years ago

Has this "fix" been released to FluxML yet?

sungjuGit commented 3 years ago

Regarding "ODE-LSTM implementations": @John-Boik, are you or others still working on an ODE-LSTM implementation in FluxML?

John-Boik commented 3 years ago

I'm working on similar models, which also use restructure/destructure. The fix to restructure/destructure has been released, and both functions are working fine as far as I know.

ChrisRackauckas commented 3 years ago

I don't think @DhairyaLGandhi has merged the fix yet: https://github.com/FluxML/Flux.jl/pull/1353

It's still a bad model though.

ChrisRackauckas commented 2 years ago

This was fixed by https://github.com/FluxML/Flux.jl/pull/1901, and one can now use Lux which makes the state explicit. Cheers!
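A plain-Julia sketch of the explicit-state calling convention Lux uses, where a layer is called as `model(x, ps, st)` and returns `(y, st)`. `toy_rnn` is a hypothetical stand-in, not Lux code. Because the state is an argument rather than a hidden mutable field, the output is a pure function of its inputs, which is what an ODE right-hand side needs.

```julia
# Hypothetical explicit-state recurrent cell: (x, ps, st) -> (y, st_new).
function toy_rnn(x, ps, st)
    h = tanh.(ps.W * x .+ ps.U * st.h .+ ps.b)
    return h, (h = h,)   # output and the new state; st itself is untouched
end

ps  = (W = [0.5 -0.2; 0.1 0.3], U = [0.2 0.0; 0.0 0.2], b = [0.0, 0.1])
st0 = (h = zeros(2),)
x   = [1.0, 2.0]

y1, st1  = toy_rnn(x, ps, st0)
y2, st2  = toy_rnn(x, ps, st1)   # state threaded explicitly by the caller
y1b, _   = toy_rnn(x, ps, st0)   # same inputs and state give the same output
```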