SciML / DiffEqFlux.jl

Pre-built implicit layer architectures with O(1) backprop, GPUs, and stiff+non-stiff DE solvers, demonstrating scientific machine learning (SciML) and physics-informed machine learning methods
https://docs.sciml.ai/DiffEqFlux/stable
MIT License

Broadcasting problem in Zygote.gradient with FastChain() #171

Closed: cems2 closed this issue 4 years ago

cems2 commented 4 years ago

In a NeuralODE problem I am testing two trivial models

model1 = FastDense(2,2)
model2 = FastChain(model1)

I believe these are nominally identical models. Yet when called in batch mode, the Zygote gradient of my loss function fails for the second model but not the first. The failure ONLY occurs when broadcasting over a batch.

code to reproduce:

using Flux 
using OrdinaryDiffEq
using DiffEqFlux

model1 = FastDense(2,2)
model2 = FastChain(model1)

#initial conditions
u0 = Float32[1.0,1.0]
minibatch = [(u0[:,:],)]  # note this is a 2x1 matrix; the matrix shape matters for the error

# test both models
function do_it()
   for model in (model1,model2)
      println("\nmodel ",model)
      p = initial_params(model)

      #set up ODE Problem
      f(x,p,t) = model(x,p) # the Juno error traceback highlights this line in red
      prob = ODEProblem(f,minibatch[1][1],(0.0f0,1.0f0),p)  # u0 is a placeholder.
      pred(u,p) = concrete_solve(prob,Tsit5(),u,p, saveat=0.01)

      #validate prediction working
      pred(u0,p)
      pred(minibatch[1][1],p)

      # set up loss function
      loss_vec_batch(p,mb) =  sum(abs, pred(mb,p).- mb) # vectorized over batch
      loss_vec(p) = loss_vec_batch(p,minibatch[1]...)
      loss = loss_vec # rename

      #validate Loss Function working
      loss_vec_batch(p,minibatch[1]...) == loss_vec(p)

      # set up the gradient parameters
      ps = Flux.params(p)

      #use zygote on vector case
      gsv = Flux.Zygote.gradient(ps) do
         x = loss(p)
         first(x)
      end

      println("\nMODEL RAN CLEANLY.  ",gsv.grads)

   end
end

do_it()

Analysis: The results of running this code are pasted below. model1 runs cleanly, while model2 throws an error in Zygote.gradient.

This is a minimal working example (MWE): two things are required to trigger the error.

  1. use the FastChain() wrapper as shown in model2
  2. Feed the data (initial conditions) into the NeuralODE in batch orientation

By batch orientation I mean: if you only want to run one initial condition at a time in concrete_solve, then the initial condition u0 can be a simple vector. But if you want to run multiple initial conditions at the same time, then the initial condition is a 2D matrix (a list of initial condition vectors, one per column).
In this example I use the matrix format, as though it were a batch, but with only one initial condition. The error is the same if I append additional initial conditions to this matrix; I left it at one to keep things simple. The conversion from a vector to a matrix is u0[:,:], and to add more initial conditions in batch orientation: minibatch = [([u0 u1 u2],)]
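For readers unfamiliar with the two layouts, here is a minimal sketch of the shapes involved, using only Base Julia. The vectors u1 and u2 are hypothetical extra initial conditions that do not appear in the MWE.

```julia
u0 = Float32[1.0, 1.0]
u1 = Float32[2.0, 0.5]   # hypothetical second initial condition
u2 = Float32[0.0, 3.0]   # hypothetical third initial condition

single = u0[:, :]        # vector -> matrix with one column
batch  = [u0 u1 u2]      # horizontal concatenation: one column per initial condition

println(size(u0))        # 1-D: a plain vector
println(size(single))    # 2-D with a single column
println(size(batch))     # 2-D with three columns
```

The point is that `single` holds the same numbers as `u0` but has two dimensions, which is the only difference between the passing and failing inputs described here.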

This error does not happen if the loss function loops over the initial conditions, solving them one at a time as simple vectors. Thus this is a broadcasting problem.
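The looping variant mentioned above might look roughly like the following sketch. It is not the code from my project: pred_stub is a hypothetical stand-in for pred (a single affine map rather than an ODE solve) so that the snippet runs with Base Julia alone, and loss_one and loss_loop are illustrative names.

```julia
# Hypothetical stand-in for `pred`: the first 4 entries of p form a 2x2
# weight matrix (column-major), the last 2 entries form a bias vector.
pred_stub(u::AbstractVector, p) = reshape(p[1:4], 2, 2) * u .+ p[5:6]

# Loss for a single initial condition passed as a plain vector.
loss_one(p, u::AbstractVector) = sum(abs, pred_stub(u, p) .- u)

# Loop over the columns of the batch matrix instead of passing the matrix whole.
loss_loop(p, mb::AbstractMatrix) = sum(loss_one(p, mb[:, i]) for i in 1:size(mb, 2))

p  = Float32[1, 0, 0, 1, 0, 0]   # identity weights, zero bias
u0 = Float32[1.0, 1.0]
mb = hcat(u0, 2f0 .* u0)         # 2x2 batch holding two initial conditions

println(loss_loop(p, mb))
```

In the real problem each call inside the loop would go through concrete_solve with a vector u, which is the case that differentiates without error.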

Regression: In the code you can see that each component of the process is validated. That is, the model, the ODE solver, and the loss functions all work fine with this batch-oriented input. The error happens in the gradient step, and only for the second model, not the first.

Output

julia> do_it()

model FastDense{typeof(identity),DiffEqFlux.var"#initial_params#48"{typeof(Flux.glorot_uniform),typeof(Flux.zeros),Int64,Int64}}(2, 2, identity, DiffEqFlux.var"#initial_params#48"{typeof(Flux.glorot_uniform),typeof(Flux.zeros),Int64,Int64}(Flux.glorot_uniform, Flux.zeros, 2, 2))

MODEL RAN CLEANLY.  IdDict{Any,Any}(Float32[0.9061968, 1.151401, -0.16663091, -0.63883686, 0.0, 0.0] => Float32[114.181854, 49.90483, 107.62989, 46.514927, 90.29219, 37.373295])

model FastChain{Tuple{FastDense{typeof(identity),DiffEqFlux.var"#initial_params#48"{typeof(Flux.glorot_uniform),typeof(Flux.zeros),Int64,Int64}}}}((FastDense{typeof(identity),DiffEqFlux.var"#initial_params#48"{typeof(Flux.glorot_uniform),typeof(Flux.zeros),Int64,Int64}}(2, 2, identity, DiffEqFlux.var"#initial_params#48"{typeof(Flux.glorot_uniform),typeof(Flux.zeros),Int64,Int64}(Flux.glorot_uniform, Flux.zeros, 2, 2)),))
ERROR: DimensionMismatch("cannot broadcast array to have fewer dimensions")
Stacktrace:
 [1] check_broadcast_shape(::Tuple{}, ::Tuple{Base.OneTo{Int64}}) at ./broadcast.jl:506
 [2] check_broadcast_shape(::Tuple{Base.OneTo{Int64}}, ::Tuple{Base.OneTo{Int64},Base.OneTo{Int64}}) at ./broadcast.jl:509
 [3] check_broadcast_axes at /Users/cems/.julia/packages/DiffEqBase/LNYfU/src/diffeqfastbc.jl:19 [inlined]
 [4] check_broadcast_axes at ./broadcast.jl:515 [inlined]
 [5] instantiate at ./broadcast.jl:259 [inlined]
 [6] materialize! at ./broadcast.jl:822 [inlined]
 [7] (::Zygote.var"#1003#1005"{Array{Float32,1},Tuple{UnitRange{Int64}}})(::Array{Float32,2}) at /Users/cems/.julia/packages/Zygote/tJj2w/src/lib/array.jl:43
 [8] (::Zygote.var"#2658#back#999"{Zygote.var"#1003#1005"{Array{Float32,1},Tuple{UnitRange{Int64}}}})(::Array{Float32,2}) at /Users/cems/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:49
 [9] (::typeof(∂(applychain)))(::Base.ReshapedArray{Float32,2,SubArray{Float32,1,Array{Float32,1},Tuple{UnitRange{Int64}},true},Tuple{}}) at /Users/cems/.julia/packages/DiffEqFlux/RLKTh/src/fast_layers.jl:20
 [10] FastChain at /Users/cems/.julia/packages/DiffEqFlux/RLKTh/src/fast_layers.jl:21 [inlined]
 [11] (::typeof(∂(λ)))(::Base.ReshapedArray{Float32,2,SubArray{Float32,1,Array{Float32,1},Tuple{UnitRange{Int64}},true},Tuple{}}) at /Users/cems/.julia/packages/Zygote/tJj2w/src/compiler/interface2.jl:0
 [12] f at /Users/cems/Documents/rusty science fair 2019/Zygote_batch_MWE_sub1_error.jl:23 [inlined]
 [13] (::typeof(∂(λ)))(::Base.ReshapedArray{Float32,2,SubArray{Float32,1,Array{Float32,1},Tuple{UnitRange{Int64}},true},Tuple{}}) at /Users/cems/.julia/packages/Zygote/tJj2w/src/compiler/interface2.jl:0
 [14] (::DiffEqBase.var"#613#back#542"{typeof(∂(λ))})(::Base.ReshapedArray{Float32,2,SubArray{Float32,1,Array{Float32,1},Tuple{UnitRange{Int64}},true},Tuple{}}) at /Users/cems/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:49
 [15] #3 at /Users/cems/.julia/packages/DiffEqSensitivity/ZX2U1/src/derivative_wrappers.jl:112 [inlined]
 [16] (::typeof(∂(λ)))(::SubArray{Float32,1,Array{Float32,1},Tuple{UnitRange{Int64}},true}) at /Users/cems/.julia/packages/Zygote/tJj2w/src/compiler/interface2.jl:0
 [17] (::Zygote.var"#36#37"{typeof(∂(λ))})(::SubArray{Float32,1,Array{Float32,1},Tuple{UnitRange{Int64}},true}) at /Users/cems/.julia/packages/Zygote/tJj2w/src/compiler/interface.jl:38
 [18] #vecjacobian!#1(::SubArray{Float32,1,Array{Float32,1},Tuple{UnitRange{Int64}},true}, ::Nothing, ::typeof(DiffEqSensitivity.vecjacobian!), ::SubArray{Float32,1,Array{Float32,1},Tuple{UnitRange{Int64}},true}, ::SubArray{Float32,1,Array{Float32,1},Tuple{UnitRange{Int64}},true}, ::Array{Float32,1}, ::Float32, ::DiffEqSensitivity.ODEInterpolatingAdjointSensitivityFunction{DiffEqSensitivity.AdjointDiffCache{Nothing,Nothing,Nothing,Nothing,Nothing,Array{Float32,1},Nothing,Nothing,Nothing,Array{Float32,2},Nothing},DiffEqSensitivity.InterpolatingAdjoint{0,true,Val{:central}},Array{Float32,2},ODESolution{Float32,3,Array{Array{Float32,2},1},Nothing,Nothing,Array{Float32,1},Array{Array{Array{Float32,2},1},1},ODEProblem{Array{Float32,2},Tuple{Float32,Float32},false,Array{Float32,1},ODEFunction{false,var"#f#36"{FastChain{Tuple{FastDense{typeof(identity),DiffEqFlux.var"#initial_params#48"{typeof(Flux.glorot_uniform),typeof(Flux.zeros),Int64,Int64}}}}},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},DiffEqBase.StandardODEProblem},Tsit5,OrdinaryDiffEq.InterpolationData{ODEFunction{false,var"#f#36"{FastChain{Tuple{FastDense{typeof(identity),DiffEqFlux.var"#initial_params#48"{typeof(Flux.glorot_uniform),typeof(Flux.zeros),Int64,Int64}}}}},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Array{Array{Float32,2},1},Array{Float32,1},Array{Array{Array{Float32,2},1},1},OrdinaryDiffEq.Tsit5ConstantCache{Float32,Float32}},DiffEqBase.DEStats},Nothing,Nothing}) at /Users/cems/.julia/packages/DiffEqSensitivity/ZX2U1/src/derivative_wrappers.jl:114
 [19] (::DiffEqSensitivity.var"#kw##vecjacobian!")(::NamedTuple{(:dgrad,),Tuple{SubArray{Float32,1,Array{Float32,1},Tuple{UnitRange{Int64}},true}}}, ::typeof(DiffEqSensitivity.vecjacobian!), ::SubArray{Float32,1,Array{Float32,1},Tuple{UnitRange{Int64}},true}, ::SubArray{Float32,1,Array{Float32,1},Tuple{UnitRange{Int64}},true}, ::Array{Float32,1}, ::Float32, ::DiffEqSensitivity.ODEInterpolatingAdjointSensitivityFunction{DiffEqSensitivity.AdjointDiffCache{Nothing,Nothing,Nothing,Nothing,Nothing,Array{Float32,1},Nothing,Nothing,Nothing,Array{Float32,2},Nothing},DiffEqSensitivity.InterpolatingAdjoint{0,true,Val{:central}},Array{Float32,2},ODESolution{Float32,3,Array{Array{Float32,2},1},Nothing,Nothing,Array{Float32,1},Array{Array{Array{Float32,2},1},1},ODEProblem{Array{Float32,2},Tuple{Float32,Float32},false,Array{Float32,1},ODEFunction{false,var"#f#36"{FastChain{Tuple{FastDense{typeof(identity),DiffEqFlux.var"#initial_params#48"{typeof(Flux.glorot_uniform),typeof(Flux.zeros),Int64,Int64}}}}},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},DiffEqBase.StandardODEProblem},Tsit5,OrdinaryDiffEq.InterpolationData{ODEFunction{false,var"#f#36"{FastChain{Tuple{FastDense{typeof(identity),DiffEqFlux.var"#initial_params#48"{typeof(Flux.glorot_uniform),typeof(Flux.zeros),Int64,Int64}}}}},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Array{Array{Float32,2},1},Array{Float32,1},Array{Array{Array{Float32,2},1},1},OrdinaryDiffEq.Tsit5ConstantCache{Float32,Float32}},DiffEqBase.DEStats},Nothing,Nothing}) at ./none:0
 [20] (::DiffEqSensitivity.ODEInterpolatingAdjointSensitivityFunction{DiffEqSensitivity.AdjointDiffCache{Nothing,Nothing,Nothing,Nothing,Nothing,Array{Float32,1},Nothing,Nothing,Nothing,Array{Float32,2},Nothing},DiffEqSensitivity.InterpolatingAdjoint{0,true,Val{:central}},Array{Float32,2},ODESolution{Float32,3,Array{Array{Float32,2},1},Nothing,Nothing,Array{Float32,1},Array{Array{Array{Float32,2},1},1},ODEProblem{Array{Float32,2},Tuple{Float32,Float32},false,Array{Float32,1},ODEFunction{false,var"#f#36"{FastChain{Tuple{FastDense{typeof(identity),DiffEqFlux.var"#initial_params#48"{typeof(Flux.glorot_uniform),typeof(Flux.zeros),Int64,Int64}}}}},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},DiffEqBase.StandardODEProblem},Tsit5,OrdinaryDiffEq.InterpolationData{ODEFunction{false,var"#f#36"{FastChain{Tuple{FastDense{typeof(identity),DiffEqFlux.var"#initial_params#48"{typeof(Flux.glorot_uniform),typeof(Flux.zeros),Int64,Int64}}}}},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Array{Array{Float32,2},1},Array{Float32,1},Array{Array{Array{Float32,2},1},1},OrdinaryDiffEq.Tsit5ConstantCache{Float32,Float32}},DiffEqBase.DEStats},Nothing,Nothing})(::Array{Float32,1}, ::Array{Float32,1}, ::Array{Float32,1}, ::Float32) at /Users/cems/.julia/packages/DiffEqSensitivity/ZX2U1/src/local_sensitivity/interpolating_adjoint.jl:79
 [21] (::ODEFunction{true,DiffEqSensitivity.ODEInterpolatingAdjointSensitivityFunction{DiffEqSensitivity.AdjointDiffCache{Nothing,Nothing,Nothing,Nothing,Nothing,Array{Float32,1},Nothing,Nothing,Nothing,Array{Float32,2},Nothing},DiffEqSensitivity.InterpolatingAdjoint{0,true,Val{:central}},Array{Float32,2},ODESolution{Float32,3,Array{Array{Float32,2},1},Nothing,Nothing,Array{Float32,1},Array{Array{Array{Float32,2},1},1},ODEProblem{Array{Float32,2},Tuple{Float32,Float32},false,Array{Float32,1},ODEFunction{false,var"#f#36"{FastChain{Tuple{FastDense{typeof(identity),DiffEqFlux.var"#initial_params#48"{typeof(Flux.glorot_uniform),typeof(Flux.zeros),Int64,Int64}}}}},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},DiffEqBase.StandardODEProblem},Tsit5,OrdinaryDiffEq.InterpolationData{ODEFunction{false,var"#f#36"{FastChain{Tuple{FastDense{typeof(identity),DiffEqFlux.var"#initial_params#48"{typeof(Flux.glorot_uniform),typeof(Flux.zeros),Int64,Int64}}}}},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Array{Array{Float32,2},1},Array{Float32,1},Array{Array{Array{Float32,2},1},1},OrdinaryDiffEq.Tsit5ConstantCache{Float32,Float32}},DiffEqBase.DEStats},Nothing,Nothing},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing})(::Array{Float32,1}, ::Vararg{Any,N} where N) at /Users/cems/.julia/packages/DiffEqBase/LNYfU/src/diffeqfunction.jl:229
 [22] initialize!(::OrdinaryDiffEq.ODEIntegrator{Tsit5,true,Array{Float32,1},Float32,Array{Float32,1},Float32,Float32,Float32,Array{Array{Float32,1},1},ODESolution{Float32,2,Array{Array{Float32,1},1},Nothing,Nothing,Array{Float32,1},Array{Array{Array{Float32,1},1},1},ODEProblem{Array{Float32,1},Tuple{Float32,Float32},true,Array{Float32,1},ODEFunction{true,DiffEqSensitivity.ODEInterpolatingAdjointSensitivityFunction{DiffEqSensitivity.AdjointDiffCache{Nothing,Nothing,Nothing,Nothing,Nothing,Array{Float32,1},Nothing,Nothing,Nothing,Array{Float32,2},Nothing},DiffEqSensitivity.InterpolatingAdjoint{0,true,Val{:central}},Array{Float32,2},ODESolution{Float32,3,Array{Array{Float32,2},1},Nothing,Nothing,Array{Float32,1},Array{Array{Array{Float32,2},1},1},ODEProblem{Array{Float32,2},Tuple{Float32,Float32},false,Array{Float32,1},ODEFunction{false,var"#f#36"{FastChain{Tuple{FastDense{typeof(identity),DiffEqFlux.var"#initial_params#48"{typeof(Flux.glorot_uniform),typeof(Flux.zeros),Int64,Int64}}}}},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},DiffEqBase.StandardODEProblem},Tsit5,OrdinaryDiffEq.InterpolationData{ODEFunction{false,var"#f#36"{FastChain{Tuple{FastDense{typeof(identity),DiffEqFlux.var"#initial_params#48"{typeof(Flux.glorot_uniform),typeof(Flux.zeros),Int64,Int64}}}}},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Array{Array{Float32,2},1},Array{Float32,1},Array{Array{Array{Float32,2},1},1},OrdinaryDiffEq.Tsit5ConstantCache{Float32,Float32}},DiffEqBase.DEStats},Nothing,Nothing},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Symbol,CallbackSet{Tuple{},Tuple{DiscreteCallback{DiffEqCallbacks.var"#33#36"{Base.RefValue{Union{Nothing, 
Float32}}},DiffEqCallbacks.var"#34#37"{DiffEqSensitivity.var"#40#42"{Base.RefValue{Int64},StepRangeLen{Float64,Base.TwicePrecision{Float64},Base.TwicePrecision{Float64}}},DiffEqSensitivity.var"#41#43"{DiffEqSensitivity.var"#df#60"{Array{Float32,3},Array{Float32,2}},Bool,StepRangeLen{Float64,Base.TwicePrecision{Float64},Base.TwicePrecision{Float64}},Array{Float32,2},Base.RefValue{Int64},Int64},Base.RefValue{Union{Nothing, Float32}}},DiffEqCallbacks.var"#35#38"{typeof(DiffEqBase.INITIALIZE_DEFAULT),Bool,DiffEqSensitivity.var"#40#42"{Base.RefValue{Int64},StepRangeLen{Float64,Base.TwicePrecision{Float64},Base.TwicePrecision{Float64}}},Base.RefValue{Union{Nothing, Float32}},DiffEqCallbacks.var"#34#37"{DiffEqSensitivity.var"#40#42"{Base.RefValue{Int64},StepRangeLen{Float64,Base.TwicePrecision{Float64},Base.TwicePrecision{Float64}}},DiffEqSensitivity.var"#41#43"{DiffEqSensitivity.var"#df#60"{Array{Float32,3},Array{Float32,2}},Bool,StepRangeLen{Float64,Base.TwicePrecision{Float64},Base.TwicePrecision{Float64}},Array{Float32,2},Base.RefValue{Int64},Int64},Base.RefValue{Union{Nothing, Float32}}}}}}},Tuple{Symbol},NamedTuple{(:callback,),Tuple{CallbackSet{Tuple{},Tuple{DiscreteCallback{DiffEqCallbacks.var"#33#36"{Base.RefValue{Union{Nothing, Float32}}},DiffEqCallbacks.var"#34#37"{DiffEqSensitivity.var"#40#42"{Base.RefValue{Int64},StepRangeLen{Float64,Base.TwicePrecision{Float64},Base.TwicePrecision{Float64}}},DiffEqSensitivity.var"#41#43"{DiffEqSensitivity.var"#df#60"{Array{Float32,3},Array{Float32,2}},Bool,StepRangeLen{Float64,Base.TwicePrecision{Float64},Base.TwicePrecision{Float64}},Array{Float32,2},Base.RefValue{Int64},Int64},Base.RefValue{Union{Nothing, Float32}}},DiffEqCallbacks.var"#35#38"{typeof(DiffEqBase.INITIALIZE_DEFAULT),Bool,DiffEqSensitivity.var"#40#42"{Base.RefValue{Int64},StepRangeLen{Float64,Base.TwicePrecision{Float64},Base.TwicePrecision{Float64}}},Base.RefValue{Union{Nothing, 
Float32}},DiffEqCallbacks.var"#34#37"{DiffEqSensitivity.var"#40#42"{Base.RefValue{Int64},StepRangeLen{Float64,Base.TwicePrecision{Float64},Base.TwicePrecision{Float64}}},DiffEqSensitivity.var"#41#43"{DiffEqSensitivity.var"#df#60"{Array{Float32,3},Array{Float32,2}},Bool,StepRangeLen{Float64,Base.TwicePrecision{Float64},Base.TwicePrecision{Float64}},Array{Float32,2},Base.RefValue{Int64},Int64},Base.RefValue{Union{Nothing, Float32}}}}}}}}}},DiffEqBase.StandardODEProblem},Tsit5,OrdinaryDiffEq.InterpolationData{ODEFunction{true,DiffEqSensitivity.ODEInterpolatingAdjointSensitivityFunction{DiffEqSensitivity.AdjointDiffCache{Nothing,Nothing,Nothing,Nothing,Nothing,Array{Float32,1},Nothing,Nothing,Nothing,Array{Float32,2},Nothing},DiffEqSensitivity.InterpolatingAdjoint{0,true,Val{:central}},Array{Float32,2},ODESolution{Float32,3,Array{Array{Float32,2},1},Nothing,Nothing,Array{Float32,1},Array{Array{Array{Float32,2},1},1},ODEProblem{Array{Float32,2},Tuple{Float32,Float32},false,Array{Float32,1},ODEFunction{false,var"#f#36"{FastChain{Tuple{FastDense{typeof(identity),DiffEqFlux.var"#initial_params#48"{typeof(Flux.glorot_uniform),typeof(Flux.zeros),Int64,Int64}}}}},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},DiffEqBase.StandardODEProblem},Tsit5,OrdinaryDiffEq.InterpolationData{ODEFunction{false,var"#f#36"{FastChain{Tuple{FastDense{typeof(identity),DiffEqFlux.var"#initial_params#48"{typeof(Flux.glorot_uniform),typeof(Flux.zeros),Int64,Int64}}}}},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Array{Array{Float32,2},1},Array{Float32,1},Array{Array{Array{Float32,2},1},1},OrdinaryDiffEq.Tsit5ConstantCache{Float32,Float32}},DiffEqBase.DEStats},Nothing,Nothing},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothi
ng,Nothing,Nothing},Array{Array{Float32,1},1},Array{Float32,1},Array{Array{Array{Float32,1},1},1},OrdinaryDiffEq.Tsit5Cache{Array{Float32,1},Array{Float32,1},Array{Float32,1},OrdinaryDiffEq.Tsit5ConstantCache{Float32,Float32}}},DiffEqBase.DEStats},ODEFunction{true,DiffEqSensitivity.ODEInterpolatingAdjointSensitivityFunction{DiffEqSensitivity.AdjointDiffCache{Nothing,Nothing,Nothing,Nothing,Nothing,Array{Float32,1},Nothing,Nothing,Nothing,Array{Float32,2},Nothing},DiffEqSensitivity.InterpolatingAdjoint{0,true,Val{:central}},Array{Float32,2},ODESolution{Float32,3,Array{Array{Float32,2},1},Nothing,Nothing,Array{Float32,1},Array{Array{Array{Float32,2},1},1},ODEProblem{Array{Float32,2},Tuple{Float32,Float32},false,Array{Float32,1},ODEFunction{false,var"#f#36"{FastChain{Tuple{FastDense{typeof(identity),DiffEqFlux.var"#initial_params#48"{typeof(Flux.glorot_uniform),typeof(Flux.zeros),Int64,Int64}}}}},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},DiffEqBase.StandardODEProblem},Tsit5,OrdinaryDiffEq.InterpolationData{ODEFunction{false,var"#f#36"{FastChain{Tuple{FastDense{typeof(identity),DiffEqFlux.var"#initial_params#48"{typeof(Flux.glorot_uniform),typeof(Flux.zeros),Int64,Int64}}}}},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Array{Array{Float32,2},1},Array{Float32,1},Array{Array{Array{Float32,2},1},1},OrdinaryDiffEq.Tsit5ConstantCache{Float32,Float32}},DiffEqBase.DEStats},Nothing,Nothing},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},OrdinaryDiffEq.Tsit5Cache{Array{Float32,1},Array{Float32,1},Array{Float32,1},OrdinaryDiffEq.Tsit5ConstantCache{Float32,Float32}},OrdinaryDiffEq.DEOptions{Float64,Float64,Float32,Float32,typeof(DiffEqBase.ODE_DEFAULT_NORM),typeof(LinearAlgebra.o
pnorm),CallbackSet{Tuple{},Tuple{DiscreteCallback{DiffEqCallbacks.var"#33#36"{Base.RefValue{Union{Nothing, Float32}}},DiffEqCallbacks.var"#34#37"{DiffEqSensitivity.var"#40#42"{Base.RefValue{Int64},StepRangeLen{Float64,Base.TwicePrecision{Float64},Base.TwicePrecision{Float64}}},DiffEqSensitivity.var"#41#43"{DiffEqSensitivity.var"#df#60"{Array{Float32,3},Array{Float32,2}},Bool,StepRangeLen{Float64,Base.TwicePrecision{Float64},Base.TwicePrecision{Float64}},Array{Float32,2},Base.RefValue{Int64},Int64},Base.RefValue{Union{Nothing, Float32}}},DiffEqCallbacks.var"#35#38"{typeof(DiffEqBase.INITIALIZE_DEFAULT),Bool,DiffEqSensitivity.var"#40#42"{Base.RefValue{Int64},StepRangeLen{Float64,Base.TwicePrecision{Float64},Base.TwicePrecision{Float64}}},Base.RefValue{Union{Nothing, Float32}},DiffEqCallbacks.var"#34#37"{DiffEqSensitivity.var"#40#42"{Base.RefValue{Int64},StepRangeLen{Float64,Base.TwicePrecision{Float64},Base.TwicePrecision{Float64}}},DiffEqSensitivity.var"#41#43"{DiffEqSensitivity.var"#df#60"{Array{Float32,3},Array{Float32,2}},Bool,StepRangeLen{Float64,Base.TwicePrecision{Float64},Base.TwicePrecision{Float64}},Array{Float32,2},Base.RefValue{Int64},Int64},Base.RefValue{Union{Nothing, Float32}}}}}}},typeof(DiffEqBase.ODE_DEFAULT_ISOUTOFDOMAIN),typeof(DiffEqBase.ODE_DEFAULT_PROG_MESSAGE),typeof(DiffEqBase.ODE_DEFAULT_UNSTABLE_CHECK),DataStructures.BinaryHeap{Float32,DataStructures.LessThan},DataStructures.BinaryHeap{Float32,DataStructures.LessThan},Nothing,Nothing,Int64,Array{Float32,1},Array{Float32,1},Array{Float32,1}},Array{Float32,1},Float32,Nothing}, ::OrdinaryDiffEq.Tsit5Cache{Array{Float32,1},Array{Float32,1},Array{Float32,1},OrdinaryDiffEq.Tsit5ConstantCache{Float32,Float32}}) at /Users/cems/.julia/packages/OrdinaryDiffEq/8Pn99/src/perform_step/low_order_rk_perform_step.jl:623
 [23] #__init#329(::Array{Float32,1}, ::Array{Float32,1}, ::Array{Float32,1}, ::Nothing, ::Bool, ::Bool, ::Bool, ::Bool, ::CallbackSet{Tuple{},Tuple{DiscreteCallback{DiffEqCallbacks.var"#33#36"{Base.RefValue{Union{Nothing, Float32}}},DiffEqCallbacks.var"#34#37"{DiffEqSensitivity.var"#40#42"{Base.RefValue{Int64},StepRangeLen{Float64,Base.TwicePrecision{Float64},Base.TwicePrecision{Float64}}},DiffEqSensitivity.var"#41#43"{DiffEqSensitivity.var"#df#60"{Array{Float32,3},Array{Float32,2}},Bool,StepRangeLen{Float64,Base.TwicePrecision{Float64},Base.TwicePrecision{Float64}},Array{Float32,2},Base.RefValue{Int64},Int64},Base.RefValue{Union{Nothing, Float32}}},DiffEqCallbacks.var"#35#38"{typeof(DiffEqBase.INITIALIZE_DEFAULT),Bool,DiffEqSensitivity.var"#40#42"{Base.RefValue{Int64},StepRangeLen{Float64,Base.TwicePrecision{Float64},Base.TwicePrecision{Float64}}},Base.RefValue{Union{Nothing, Float32}},DiffEqCallbacks.var"#34#37"{DiffEqSensitivity.var"#40#42"{Base.RefValue{Int64},StepRangeLen{Float64,Base.TwicePrecision{Float64},Base.TwicePrecision{Float64}}},DiffEqSensitivity.var"#41#43"{DiffEqSensitivity.var"#df#60"{Array{Float32,3},Array{Float32,2}},Bool,StepRangeLen{Float64,Base.TwicePrecision{Float64},Base.TwicePrecision{Float64}},Array{Float32,2},Base.RefValue{Int64},Int64},Base.RefValue{Union{Nothing, Float32}}}}}}}, ::Bool, ::Bool, ::Float32, ::Float32, ::Float32, ::Bool, ::Bool, ::Rational{Int64}, ::Float64, ::Float64, ::Rational{Int64}, ::Int64, ::Int64, ::Int64, ::Rational{Int64}, ::Bool, ::Int64, ::Nothing, ::Nothing, ::Int64, ::typeof(DiffEqBase.ODE_DEFAULT_NORM), ::typeof(LinearAlgebra.opnorm), ::typeof(DiffEqBase.ODE_DEFAULT_ISOUTOFDOMAIN), ::typeof(DiffEqBase.ODE_DEFAULT_UNSTABLE_CHECK), ::Bool, ::Bool, ::Bool, ::Bool, ::Bool, ::Bool, ::Bool, ::Int64, ::String, ::typeof(DiffEqBase.ODE_DEFAULT_PROG_MESSAGE), ::Nothing, ::Bool, ::Bool, ::Bool, ::Bool, ::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}, ::typeof(DiffEqBase.__init), 
::ODEProblem{Array{Float32,1},Tuple{Float32,Float32},true,Array{Float32,1},ODEFunction{true,DiffEqSensitivity.ODEInterpolatingAdjointSensitivityFunction{DiffEqSensitivity.AdjointDiffCache{Nothing,Nothing,Nothing,Nothing,Nothing,Array{Float32,1},Nothing,Nothing,Nothing,Array{Float32,2},Nothing},DiffEqSensitivity.InterpolatingAdjoint{0,true,Val{:central}},Array{Float32,2},ODESolution{Float32,3,Array{Array{Float32,2},1},Nothing,Nothing,Array{Float32,1},Array{Array{Array{Float32,2},1},1},ODEProblem{Array{Float32,2},Tuple{Float32,Float32},false,Array{Float32,1},ODEFunction{false,var"#f#36"{FastChain{Tuple{FastDense{typeof(identity),DiffEqFlux.var"#initial_params#48"{typeof(Flux.glorot_uniform),typeof(Flux.zeros),Int64,Int64}}}}},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},DiffEqBase.StandardODEProblem},Tsit5,OrdinaryDiffEq.InterpolationData{ODEFunction{false,var"#f#36"{FastChain{Tuple{FastDense{typeof(identity),DiffEqFlux.var"#initial_params#48"{typeof(Flux.glorot_uniform),typeof(Flux.zeros),Int64,Int64}}}}},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Array{Array{Float32,2},1},Array{Float32,1},Array{Array{Array{Float32,2},1},1},OrdinaryDiffEq.Tsit5ConstantCache{Float32,Float32}},DiffEqBase.DEStats},Nothing,Nothing},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Symbol,CallbackSet{Tuple{},Tuple{DiscreteCallback{DiffEqCallbacks.var"#33#36"{Base.RefValue{Union{Nothing, 
Float32}}},DiffEqCallbacks.var"#34#37"{DiffEqSensitivity.var"#40#42"{Base.RefValue{Int64},StepRangeLen{Float64,Base.TwicePrecision{Float64},Base.TwicePrecision{Float64}}},DiffEqSensitivity.var"#41#43"{DiffEqSensitivity.var"#df#60"{Array{Float32,3},Array{Float32,2}},Bool,StepRangeLen{Float64,Base.TwicePrecision{Float64},Base.TwicePrecision{Float64}},Array{Float32,2},Base.RefValue{Int64},Int64},Base.RefValue{Union{Nothing, Float32}}},DiffEqCallbacks.var"#35#38"{typeof(DiffEqBase.INITIALIZE_DEFAULT),Bool,DiffEqSensitivity.var"#40#42"{Base.RefValue{Int64},StepRangeLen{Float64,Base.TwicePrecision{Float64},Base.TwicePrecision{Float64}}},Base.RefValue{Union{Nothing, Float32}},DiffEqCallbacks.var"#34#37"{DiffEqSensitivity.var"#40#42"{Base.RefValue{Int64},StepRangeLen{Float64,Base.TwicePrecision{Float64},Base.TwicePrecision{Float64}}},DiffEqSensitivity.var"#41#43"{DiffEqSensitivity.var"#df#60"{Array{Float32,3},Array{Float32,2}},Bool,StepRangeLen{Float64,Base.TwicePrecision{Float64},Base.TwicePrecision{Float64}},Array{Float32,2},Base.RefValue{Int64},Int64},Base.RefValue{Union{Nothing, Float32}}}}}}},Tuple{Symbol},NamedTuple{(:callback,),Tuple{CallbackSet{Tuple{},Tuple{DiscreteCallback{DiffEqCallbacks.var"#33#36"{Base.RefValue{Union{Nothing, Float32}}},DiffEqCallbacks.var"#34#37"{DiffEqSensitivity.var"#40#42"{Base.RefValue{Int64},StepRangeLen{Float64,Base.TwicePrecision{Float64},Base.TwicePrecision{Float64}}},DiffEqSensitivity.var"#41#43"{DiffEqSensitivity.var"#df#60"{Array{Float32,3},Array{Float32,2}},Bool,StepRangeLen{Float64,Base.TwicePrecision{Float64},Base.TwicePrecision{Float64}},Array{Float32,2},Base.RefValue{Int64},Int64},Base.RefValue{Union{Nothing, Float32}}},DiffEqCallbacks.var"#35#38"{typeof(DiffEqBase.INITIALIZE_DEFAULT),Bool,DiffEqSensitivity.var"#40#42"{Base.RefValue{Int64},StepRangeLen{Float64,Base.TwicePrecision{Float64},Base.TwicePrecision{Float64}}},Base.RefValue{Union{Nothing, 
Float32}},DiffEqCallbacks.var"#34#37"{DiffEqSensitivity.var"#40#42"{Base.RefValue{Int64},StepRangeLen{Float64,Base.TwicePrecision{Float64},Base.TwicePrecision{Float64}}},DiffEqSensitivity.var"#41#43"{DiffEqSensitivity.var"#df#60"{Array{Float32,3},Array{Float32,2}},Bool,StepRangeLen{Float64,Base.TwicePrecision{Float64},Base.TwicePrecision{Float64}},Array{Float32,2},Base.RefValue{Int64},Int64},Base.RefValue{Union{Nothing, Float32}}}}}}}}}},DiffEqBase.StandardODEProblem}, ::Tsit5, ::Array{Array{Float32,1},1}, ::Array{Float32,1}, ::Array{Any,1}, ::Type{Val{true}}) at /Users/cems/.julia/packages/OrdinaryDiffEq/8Pn99/src/solve.jl:386
 [24] (::DiffEqBase.var"#kw##__init")(::NamedTuple{(:callback, :save_everystep, :save_start, :saveat, :tstops, :abstol, :reltol),Tuple{CallbackSet{Tuple{},Tuple{DiscreteCallback{DiffEqCallbacks.var"#33#36"{Base.RefValue{Union{Nothing, Float32}}},DiffEqCallbacks.var"#34#37"{DiffEqSensitivity.var"#40#42"{Base.RefValue{Int64},StepRangeLen{Float64,Base.TwicePrecision{Float64},Base.TwicePrecision{Float64}}},DiffEqSensitivity.var"#41#43"{DiffEqSensitivity.var"#df#60"{Array{Float32,3},Array{Float32,2}},Bool,StepRangeLen{Float64,Base.TwicePrecision{Float64},Base.TwicePrecision{Float64}},Array{Float32,2},Base.RefValue{Int64},Int64},Base.RefValue{Union{Nothing, Float32}}},DiffEqCallbacks.var"#35#38"{typeof(DiffEqBase.INITIALIZE_DEFAULT),Bool,DiffEqSensitivity.var"#40#42"{Base.RefValue{Int64},StepRangeLen{Float64,Base.TwicePrecision{Float64},Base.TwicePrecision{Float64}}},Base.RefValue{Union{Nothing, Float32}},DiffEqCallbacks.var"#34#37"{DiffEqSensitivity.var"#40#42"{Base.RefValue{Int64},StepRangeLen{Float64,Base.TwicePrecision{Float64},Base.TwicePrecision{Float64}}},DiffEqSensitivity.var"#41#43"{DiffEqSensitivity.var"#df#60"{Array{Float32,3},Array{Float32,2}},Bool,StepRangeLen{Float64,Base.TwicePrecision{Float64},Base.TwicePrecision{Float64}},Array{Float32,2},Base.RefValue{Int64},Int64},Base.RefValue{Union{Nothing, Float32}}}}}}},Bool,Bool,Array{Float32,1},Array{Float32,1},Float64,Float64}}, ::typeof(DiffEqBase.__init), 
::ODEProblem{Array{Float32,1},Tuple{Float32,Float32},true,Array{Float32,1},ODEFunction{true,DiffEqSensitivity.ODEInterpolatingAdjointSensitivityFunction{DiffEqSensitivity.AdjointDiffCache{Nothing,Nothing,Nothing,Nothing,Nothing,Array{Float32,1},Nothing,Nothing,Nothing,Array{Float32,2},Nothing},DiffEqSensitivity.InterpolatingAdjoint{0,true,Val{:central}},Array{Float32,2},ODESolution{Float32,3,Array{Array{Float32,2},1},Nothing,Nothing,Array{Float32,1},Array{Array{Array{Float32,2},1},1},ODEProblem{Array{Float32,2},Tuple{Float32,Float32},false,Array{Float32,1},ODEFunction{false,var"#f#36"{FastChain{Tuple{FastDense{typeof(identity),DiffEqFlux.var"#initial_params#48"{typeof(Flux.glorot_uniform),typeof(Flux.zeros),Int64,Int64}}}}},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},DiffEqBase.StandardODEProblem},Tsit5,OrdinaryDiffEq.InterpolationData{ODEFunction{false,var"#f#36"{FastChain{Tuple{FastDense{typeof(identity),DiffEqFlux.var"#initial_params#48"{typeof(Flux.glorot_uniform),typeof(Flux.zeros),Int64,Int64}}}}},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Array{Array{Float32,2},1},Array{Float32,1},Array{Array{Array{Float32,2},1},1},OrdinaryDiffEq.Tsit5ConstantCache{Float32,Float32}},DiffEqBase.DEStats},Nothing,Nothing},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Symbol,CallbackSet{Tuple{},Tuple{DiscreteCallback{DiffEqCallbacks.var"#33#36"{Base.RefValue{Union{Nothing, 
Float32}}},DiffEqCallbacks.var"#34#37"{DiffEqSensitivity.var"#40#42"{Base.RefValue{Int64},StepRangeLen{Float64,Base.TwicePrecision{Float64},Base.TwicePrecision{Float64}}},DiffEqSensitivity.var"#41#43"{DiffEqSensitivity.var"#df#60"{Array{Float32,3},Array{Float32,2}},Bool,StepRangeLen{Float64,Base.TwicePrecision{Float64},Base.TwicePrecision{Float64}},Array{Float32,2},Base.RefValue{Int64},Int64},Base.RefValue{Union{Nothing, Float32}}},DiffEqCallbacks.var"#35#38"{typeof(DiffEqBase.INITIALIZE_DEFAULT),Bool,DiffEqSensitivity.var"#40#42"{Base.RefValue{Int64},StepRangeLen{Float64,Base.TwicePrecision{Float64},Base.TwicePrecision{Float64}}},Base.RefValue{Union{Nothing, Float32}},DiffEqCallbacks.var"#34#37"{DiffEqSensitivity.var"#40#42"{Base.RefValue{Int64},StepRangeLen{Float64,Base.TwicePrecision{Float64},Base.TwicePrecision{Float64}}},DiffEqSensitivity.var"#41#43"{DiffEqSensitivity.var"#df#60"{Array{Float32,3},Array{Float32,2}},Bool,StepRangeLen{Float64,Base.TwicePrecision{Float64},Base.TwicePrecision{Float64}},Array{Float32,2},Base.RefValue{Int64},Int64},Base.RefValue{Union{Nothing, Float32}}}}}}},Tuple{Symbol},NamedTuple{(:callback,),Tuple{CallbackSet{Tuple{},Tuple{DiscreteCallback{DiffEqCallbacks.var"#33#36"{Base.RefValue{Union{Nothing, Float32}}},DiffEqCallbacks.var"#34#37"{DiffEqSensitivity.var"#40#42"{Base.RefValue{Int64},StepRangeLen{Float64,Base.TwicePrecision{Float64},Base.TwicePrecision{Float64}}},DiffEqSensitivity.var"#41#43"{DiffEqSensitivity.var"#df#60"{Array{Float32,3},Array{Float32,2}},Bool,StepRangeLen{Float64,Base.TwicePrecision{Float64},Base.TwicePrecision{Float64}},Array{Float32,2},Base.RefValue{Int64},Int64},Base.RefValue{Union{Nothing, Float32}}},DiffEqCallbacks.var"#35#38"{typeof(DiffEqBase.INITIALIZE_DEFAULT),Bool,DiffEqSensitivity.var"#40#42"{Base.RefValue{Int64},StepRangeLen{Float64,Base.TwicePrecision{Float64},Base.TwicePrecision{Float64}}},Base.RefValue{Union{Nothing, 
Float32}},DiffEqCallbacks.var"#34#37"{DiffEqSensitivity.var"#40#42"{Base.RefValue{Int64},StepRangeLen{Float64,Base.TwicePrecision{Float64},Base.TwicePrecision{Float64}}},DiffEqSensitivity.var"#41#43"{DiffEqSensitivity.var"#df#60"{Array{Float32,3},Array{Float32,2}},Bool,StepRangeLen{Float64,Base.TwicePrecision{Float64},Base.TwicePrecision{Float64}},Array{Float32,2},Base.RefValue{Int64},Int64},Base.RefValue{Union{Nothing, Float32}}}}}}}}}},DiffEqBase.StandardODEProblem}, ::Tsit5, ::Array{Array{Float32,1},1}, ::Array{Float32,1}, ::Array{Any,1}, ::Type{Val{true}}) at ./none:0 (repeats 4 times)
 [25] #__solve#328 at /Users/cems/.julia/packages/OrdinaryDiffEq/8Pn99/src/solve.jl:4 [inlined]
 [26] #__solve at ./none:0 [inlined]
 [27] #solve_call#442(::Bool, ::Base.Iterators.Pairs{Symbol,Any,NTuple{6,Symbol},NamedTuple{(:save_everystep, :save_start, :saveat, :tstops, :abstol, :reltol),Tuple{Bool,Bool,Array{Float32,1},Array{Float32,1},Float64,Float64}}}, ::typeof(DiffEqBase.solve_call), ::ODEProblem{Array{Float32,1},Tuple{Float32,Float32},true,Array{Float32,1},ODEFunction{true,DiffEqSensitivity.ODEInterpolatingAdjointSensitivityFunction{DiffEqSensitivity.AdjointDiffCache{Nothing,Nothing,Nothing,Nothing,Nothing,Array{Float32,1},Nothing,Nothing,Nothing,Array{Float32,2},Nothing},DiffEqSensitivity.InterpolatingAdjoint{0,true,Val{:central}},Array{Float32,2},ODESolution{Float32,3,Array{Array{Float32,2},1},Nothing,Nothing,Array{Float32,1},Array{Array{Array{Float32,2},1},1},ODEProblem{Array{Float32,2},Tuple{Float32,Float32},false,Array{Float32,1},ODEFunction{false,var"#f#36"{FastChain{Tuple{FastDense{typeof(identity),DiffEqFlux.var"#initial_params#48"{typeof(Flux.glorot_uniform),typeof(Flux.zeros),Int64,Int64}}}}},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},DiffEqBase.StandardODEProblem},Tsit5,OrdinaryDiffEq.InterpolationData{ODEFunction{false,var"#f#36"{FastChain{Tuple{FastDense{typeof(identity),DiffEqFlux.var"#initial_params#48"{typeof(Flux.glorot_uniform),typeof(Flux.zeros),Int64,Int64}}}}},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Array{Array{Float32,2},1},Array{Float32,1},Array{Array{Array{Float32,2},1},1},OrdinaryDiffEq.Tsit5ConstantCache{Float32,Float32}},DiffEqBase.DEStats},Nothing,Nothing},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Symbol,CallbackSet{Tuple{},Tuple{DiscreteCallback{DiffEqCallbacks.var"#33#36"{Base.RefValue{Union{Nothing, 
Float32}}},DiffEqCallbacks.var"#34#37"{DiffEqSensitivity.var"#40#42"{Base.RefValue{Int64},StepRangeLen{Float64,Base.TwicePrecision{Float64},Base.TwicePrecision{Float64}}},DiffEqSensitivity.var"#41#43"{DiffEqSensitivity.var"#df#60"{Array{Float32,3},Array{Float32,2}},Bool,StepRangeLen{Float64,Base.TwicePrecision{Float64},Base.TwicePrecision{Float64}},Array{Float32,2},Base.RefValue{Int64},Int64},Base.RefValue{Union{Nothing, Float32}}},DiffEqCallbacks.var"#35#38"{typeof(DiffEqBase.INITIALIZE_DEFAULT),Bool,DiffEqSensitivity.var"#40#42"{Base.RefValue{Int64},StepRangeLen{Float64,Base.TwicePrecision{Float64},Base.TwicePrecision{Float64}}},Base.RefValue{Union{Nothing, Float32}},DiffEqCallbacks.var"#34#37"{DiffEqSensitivity.var"#40#42"{Base.RefValue{Int64},StepRangeLen{Float64,Base.TwicePrecision{Float64},Base.TwicePrecision{Float64}}},DiffEqSensitivity.var"#41#43"{DiffEqSensitivity.var"#df#60"{Array{Float32,3},Array{Float32,2}},Bool,StepRangeLen{Float64,Base.TwicePrecision{Float64},Base.TwicePrecision{Float64}},Array{Float32,2},Base.RefValue{Int64},Int64},Base.RefValue{Union{Nothing, Float32}}}}}}},Tuple{Symbol},NamedTuple{(:callback,),Tuple{CallbackSet{Tuple{},Tuple{DiscreteCallback{DiffEqCallbacks.var"#33#36"{Base.RefValue{Union{Nothing, Float32}}},DiffEqCallbacks.var"#34#37"{DiffEqSensitivity.var"#40#42"{Base.RefValue{Int64},StepRangeLen{Float64,Base.TwicePrecision{Float64},Base.TwicePrecision{Float64}}},DiffEqSensitivity.var"#41#43"{DiffEqSensitivity.var"#df#60"{Array{Float32,3},Array{Float32,2}},Bool,StepRangeLen{Float64,Base.TwicePrecision{Float64},Base.TwicePrecision{Float64}},Array{Float32,2},Base.RefValue{Int64},Int64},Base.RefValue{Union{Nothing, Float32}}},DiffEqCallbacks.var"#35#38"{typeof(DiffEqBase.INITIALIZE_DEFAULT),Bool,DiffEqSensitivity.var"#40#42"{Base.RefValue{Int64},StepRangeLen{Float64,Base.TwicePrecision{Float64},Base.TwicePrecision{Float64}}},Base.RefValue{Union{Nothing, 
Float32}},DiffEqCallbacks.var"#34#37"{DiffEqSensitivity.var"#40#42"{Base.RefValue{Int64},StepRangeLen{Float64,Base.TwicePrecision{Float64},Base.TwicePrecision{Float64}}},DiffEqSensitivity.var"#41#43"{DiffEqSensitivity.var"#df#60"{Array{Float32,3},Array{Float32,2}},Bool,StepRangeLen{Float64,Base.TwicePrecision{Float64},Base.TwicePrecision{Float64}},Array{Float32,2},Base.RefValue{Int64},Int64},Base.RefValue{Union{Nothing, Float32}}}}}}}}}},DiffEqBase.StandardODEProblem}, ::Tsit5) at /Users/cems/.julia/packages/DiffEqBase/LNYfU/src/solve.jl:44
 [28] (::DiffEqBase.var"#kw##solve")(::NamedTuple{(:save_everystep, :save_start, :saveat, :tstops, :abstol, :reltol),Tuple{Bool,Bool,Array{Float32,1},Array{Float32,1},Float64,Float64}}, ::typeof(solve), ::ODEProblem{Array{Float32,1},Tuple{Float32,Float32},true,Array{Float32,1},ODEFunction{true,DiffEqSensitivity.ODEInterpolatingAdjointSensitivityFunction{DiffEqSensitivity.AdjointDiffCache{Nothing,Nothing,Nothing,Nothing,Nothing,Array{Float32,1},Nothing,Nothing,Nothing,Array{Float32,2},Nothing},DiffEqSensitivity.InterpolatingAdjoint{0,true,Val{:central}},Array{Float32,2},ODESolution{Float32,3,Array{Array{Float32,2},1},Nothing,Nothing,Array{Float32,1},Array{Array{Array{Float32,2},1},1},ODEProblem{Array{Float32,2},Tuple{Float32,Float32},false,Array{Float32,1},ODEFunction{false,var"#f#36"{FastChain{Tuple{FastDense{typeof(identity),DiffEqFlux.var"#initial_params#48"{typeof(Flux.glorot_uniform),typeof(Flux.zeros),Int64,Int64}}}}},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},DiffEqBase.StandardODEProblem},Tsit5,OrdinaryDiffEq.InterpolationData{ODEFunction{false,var"#f#36"{FastChain{Tuple{FastDense{typeof(identity),DiffEqFlux.var"#initial_params#48"{typeof(Flux.glorot_uniform),typeof(Flux.zeros),Int64,Int64}}}}},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Array{Array{Float32,2},1},Array{Float32,1},Array{Array{Array{Float32,2},1},1},OrdinaryDiffEq.Tsit5ConstantCache{Float32,Float32}},DiffEqBase.DEStats},Nothing,Nothing},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Symbol,CallbackSet{Tuple{},Tuple{DiscreteCallback{DiffEqCallbacks.var"#33#36"{Base.RefValue{Union{Nothing, 
Float32}}},DiffEqCallbacks.var"#34#37"{DiffEqSensitivity.var"#40#42"{Base.RefValue{Int64},StepRangeLen{Float64,Base.TwicePrecision{Float64},Base.TwicePrecision{Float64}}},DiffEqSensitivity.var"#41#43"{DiffEqSensitivity.var"#df#60"{Array{Float32,3},Array{Float32,2}},Bool,StepRangeLen{Float64,Base.TwicePrecision{Float64},Base.TwicePrecision{Float64}},Array{Float32,2},Base.RefValue{Int64},Int64},Base.RefValue{Union{Nothing, Float32}}},DiffEqCallbacks.var"#35#38"{typeof(DiffEqBase.INITIALIZE_DEFAULT),Bool,DiffEqSensitivity.var"#40#42"{Base.RefValue{Int64},StepRangeLen{Float64,Base.TwicePrecision{Float64},Base.TwicePrecision{Float64}}},Base.RefValue{Union{Nothing, Float32}},DiffEqCallbacks.var"#34#37"{DiffEqSensitivity.var"#40#42"{Base.RefValue{Int64},StepRangeLen{Float64,Base.TwicePrecision{Float64},Base.TwicePrecision{Float64}}},DiffEqSensitivity.var"#41#43"{DiffEqSensitivity.var"#df#60"{Array{Float32,3},Array{Float32,2}},Bool,StepRangeLen{Float64,Base.TwicePrecision{Float64},Base.TwicePrecision{Float64}},Array{Float32,2},Base.RefValue{Int64},Int64},Base.RefValue{Union{Nothing, Float32}}}}}}},Tuple{Symbol},NamedTuple{(:callback,),Tuple{CallbackSet{Tuple{},Tuple{DiscreteCallback{DiffEqCallbacks.var"#33#36"{Base.RefValue{Union{Nothing, Float32}}},DiffEqCallbacks.var"#34#37"{DiffEqSensitivity.var"#40#42"{Base.RefValue{Int64},StepRangeLen{Float64,Base.TwicePrecision{Float64},Base.TwicePrecision{Float64}}},DiffEqSensitivity.var"#41#43"{DiffEqSensitivity.var"#df#60"{Array{Float32,3},Array{Float32,2}},Bool,StepRangeLen{Float64,Base.TwicePrecision{Float64},Base.TwicePrecision{Float64}},Array{Float32,2},Base.RefValue{Int64},Int64},Base.RefValue{Union{Nothing, Float32}}},DiffEqCallbacks.var"#35#38"{typeof(DiffEqBase.INITIALIZE_DEFAULT),Bool,DiffEqSensitivity.var"#40#42"{Base.RefValue{Int64},StepRangeLen{Float64,Base.TwicePrecision{Float64},Base.TwicePrecision{Float64}}},Base.RefValue{Union{Nothing, 
Float32}},DiffEqCallbacks.var"#34#37"{DiffEqSensitivity.var"#40#42"{Base.RefValue{Int64},StepRangeLen{Float64,Base.TwicePrecision{Float64},Base.TwicePrecision{Float64}}},DiffEqSensitivity.var"#41#43"{DiffEqSensitivity.var"#df#60"{Array{Float32,3},Array{Float32,2}},Bool,StepRangeLen{Float64,Base.TwicePrecision{Float64},Base.TwicePrecision{Float64}},Array{Float32,2},Base.RefValue{Int64},Int64},Base.RefValue{Union{Nothing, Float32}}}}}}}}}},DiffEqBase.StandardODEProblem}, ::Tsit5) at ./none:0
 [29] #_adjoint_sensitivities#13(::Float64, ::Float64, ::Array{Float32,1}, ::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}, ::typeof(DiffEqSensitivity._adjoint_sensitivities), ::ODESolution{Float32,3,Array{Array{Float32,2},1},Nothing,Nothing,Array{Float32,1},Array{Array{Array{Float32,2},1},1},ODEProblem{Array{Float32,2},Tuple{Float32,Float32},false,Array{Float32,1},ODEFunction{false,var"#f#36"{FastChain{Tuple{FastDense{typeof(identity),DiffEqFlux.var"#initial_params#48"{typeof(Flux.glorot_uniform),typeof(Flux.zeros),Int64,Int64}}}}},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},DiffEqBase.StandardODEProblem},Tsit5,OrdinaryDiffEq.InterpolationData{ODEFunction{false,var"#f#36"{FastChain{Tuple{FastDense{typeof(identity),DiffEqFlux.var"#initial_params#48"{typeof(Flux.glorot_uniform),typeof(Flux.zeros),Int64,Int64}}}}},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Array{Array{Float32,2},1},Array{Float32,1},Array{Array{Array{Float32,2},1},1},OrdinaryDiffEq.Tsit5ConstantCache{Float32,Float32}},DiffEqBase.DEStats}, ::DiffEqSensitivity.InterpolatingAdjoint{0,true,Val{:central}}, ::Tsit5, ::DiffEqSensitivity.var"#df#60"{Array{Float32,3},Array{Float32,2}}, ::StepRangeLen{Float64,Base.TwicePrecision{Float64},Base.TwicePrecision{Float64}}, ::Nothing) at /Users/cems/.julia/packages/DiffEqSensitivity/ZX2U1/src/local_sensitivity/sensitivity_interface.jl:16
 [30] _adjoint_sensitivities(::ODESolution{Float32,3,Array{Array{Float32,2},1},Nothing,Nothing,Array{Float32,1},Array{Array{Array{Float32,2},1},1},ODEProblem{Array{Float32,2},Tuple{Float32,Float32},false,Array{Float32,1},ODEFunction{false,var"#f#36"{FastChain{Tuple{FastDense{typeof(identity),DiffEqFlux.var"#initial_params#48"{typeof(Flux.glorot_uniform),typeof(Flux.zeros),Int64,Int64}}}}},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},DiffEqBase.StandardODEProblem},Tsit5,OrdinaryDiffEq.InterpolationData{ODEFunction{false,var"#f#36"{FastChain{Tuple{FastDense{typeof(identity),DiffEqFlux.var"#initial_params#48"{typeof(Flux.glorot_uniform),typeof(Flux.zeros),Int64,Int64}}}}},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Array{Array{Float32,2},1},Array{Float32,1},Array{Array{Array{Float32,2},1},1},OrdinaryDiffEq.Tsit5ConstantCache{Float32,Float32}},DiffEqBase.DEStats}, ::DiffEqSensitivity.InterpolatingAdjoint{0,true,Val{:central}}, ::Tsit5, ::Function, ::StepRangeLen{Float64,Base.TwicePrecision{Float64},Base.TwicePrecision{Float64}}, ::Nothing) at /Users/cems/.julia/packages/DiffEqSensitivity/ZX2U1/src/local_sensitivity/sensitivity_interface.jl:13 (repeats 2 times)
 [31] #adjoint_sensitivities#12(::DiffEqSensitivity.InterpolatingAdjoint{0,true,Val{:central}}, ::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}, ::typeof(DiffEqSensitivity.adjoint_sensitivities), ::ODESolution{Float32,3,Array{Array{Float32,2},1},Nothing,Nothing,Array{Float32,1},Array{Array{Array{Float32,2},1},1},ODEProblem{Array{Float32,2},Tuple{Float32,Float32},false,Array{Float32,1},ODEFunction{false,var"#f#36"{FastChain{Tuple{FastDense{typeof(identity),DiffEqFlux.var"#initial_params#48"{typeof(Flux.glorot_uniform),typeof(Flux.zeros),Int64,Int64}}}}},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},DiffEqBase.StandardODEProblem},Tsit5,OrdinaryDiffEq.InterpolationData{ODEFunction{false,var"#f#36"{FastChain{Tuple{FastDense{typeof(identity),DiffEqFlux.var"#initial_params#48"{typeof(Flux.glorot_uniform),typeof(Flux.zeros),Int64,Int64}}}}},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Array{Array{Float32,2},1},Array{Float32,1},Array{Array{Array{Float32,2},1},1},OrdinaryDiffEq.Tsit5ConstantCache{Float32,Float32}},DiffEqBase.DEStats}, ::Tsit5, ::Vararg{Any,N} where N) at /Users/cems/.julia/packages/DiffEqSensitivity/ZX2U1/src/local_sensitivity/sensitivity_interface.jl:6
 [32] (::DiffEqSensitivity.var"#kw##adjoint_sensitivities")(::NamedTuple{(:sensealg,),Tuple{DiffEqSensitivity.InterpolatingAdjoint{0,true,Val{:central}}}}, ::typeof(DiffEqSensitivity.adjoint_sensitivities), ::ODESolution{Float32,3,Array{Array{Float32,2},1},Nothing,Nothing,Array{Float32,1},Array{Array{Array{Float32,2},1},1},ODEProblem{Array{Float32,2},Tuple{Float32,Float32},false,Array{Float32,1},ODEFunction{false,var"#f#36"{FastChain{Tuple{FastDense{typeof(identity),DiffEqFlux.var"#initial_params#48"{typeof(Flux.glorot_uniform),typeof(Flux.zeros),Int64,Int64}}}}},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},DiffEqBase.StandardODEProblem},Tsit5,OrdinaryDiffEq.InterpolationData{ODEFunction{false,var"#f#36"{FastChain{Tuple{FastDense{typeof(identity),DiffEqFlux.var"#initial_params#48"{typeof(Flux.glorot_uniform),typeof(Flux.zeros),Int64,Int64}}}}},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Array{Array{Float32,2},1},Array{Float32,1},Array{Array{Array{Float32,2},1},1},OrdinaryDiffEq.Tsit5ConstantCache{Float32,Float32}},DiffEqBase.DEStats}, ::Tsit5, ::Vararg{Any,N} where N) at ./none:0
 [33] (::DiffEqSensitivity.var"#adjoint_sensitivity_backpass#59"{Tsit5,DiffEqSensitivity.InterpolatingAdjoint{0,true,Val{:central}},Array{Float32,2},Array{Float32,1},Tuple{}})(::Array{Float32,3}) at /Users/cems/.julia/packages/DiffEqSensitivity/ZX2U1/src/local_sensitivity/concrete_solve.jl:67
 [34] #554#back at /Users/cems/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:55 [inlined]
 [35] pred at /Users/cems/Documents/rusty science fair 2019/Zygote_batch_MWE_sub1_error.jl:25 [inlined]
 [36] (::typeof(∂(λ)))(::Array{Float32,3}) at /Users/cems/.julia/packages/Zygote/tJj2w/src/compiler/interface2.jl:0
 [37] loss_vec_batch at /Users/cems/Documents/rusty science fair 2019/Zygote_batch_MWE_sub1_error.jl:33 [inlined]
 [38] (::typeof(∂(λ)))(::Float32) at /Users/cems/.julia/packages/Zygote/tJj2w/src/compiler/interface2.jl:0
 [39] (::Zygote.var"#165#166"{typeof(∂(λ)),Tuple{Tuple{Nothing},Tuple{Nothing}}})(::Float32) at /Users/cems/.julia/packages/Zygote/tJj2w/src/lib/lib.jl:156
 [40] #321#back at /Users/cems/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:49 [inlined]
 [41] loss_vec at /Users/cems/Documents/rusty science fair 2019/Zygote_batch_MWE_sub1_error.jl:34 [inlined]
 [42] (::typeof(∂(λ)))(::Float32) at /Users/cems/.julia/packages/Zygote/tJj2w/src/compiler/interface2.jl:0
 [43] #35 at /Users/cems/Documents/rusty science fair 2019/Zygote_batch_MWE_sub1_error.jl:45 [inlined]
 [44] (::typeof(∂(λ)))(::Float32) at /Users/cems/.julia/packages/Zygote/tJj2w/src/compiler/interface2.jl:0
 [45] (::Zygote.var"#46#47"{Zygote.Params,Zygote.Context,typeof(∂(λ))})(::Float32) at /Users/cems/.julia/packages/Zygote/tJj2w/src/compiler/interface.jl:101
 [46] gradient(::Function, ::Zygote.Params) at /Users/cems/.julia/packages/Zygote/tJj2w/src/compiler/interface.jl:47
 [47] do_it() at /Users/cems/Documents/rusty science fair 2019/Zygote_batch_MWE_sub1_error.jl:44
 [48] top-level scope at none:0

Versions


Status `~/.julia/environments/v1.3/Project.toml`
[c52e3926] Atom v0.12.3
[aae7a2af] DiffEqFlux v1.3.2 #master (https://github.com/JuliaDiffEq/DiffEqFlux.jl.git)
[0c46a032] DifferentialEquations v6.11.0
[587475ba] Flux v0.10.1
[7073ff75] IJulia v1.21.1
[e5e0dc1b] Juno v0.7.2
[429524aa] Optim v0.20.1
[1dea7af3] OrdinaryDiffEq v5.29.0
[91a5bcdd] Plots v0.29.1
[d330b81b] PyPlot v2.8.2

lungd commented 4 years ago

I think you can define an in-place problem function to make this work. This is a workaround that avoids the issue rather than a fix.

using Flux 
using OrdinaryDiffEq
using DiffEqFlux

model1 = FastDense(2,2)
model2 = FastChain(model1)

#initial conditions
u0 = Float32[1.0,1.0]
u1 = Float32[1.5,1.5]
u2 = Float32[2.0,2.0]

#minibatch = [(u0[:,:],)]  # note this is 2x1.  this matters to error
minibatch = [([u0 u1 u2],)]

# test both models
function do_it()
   for model in (model1,model2)
      println("\nmodel ",model)
      p = initial_params(model)

      #set up ODE Problem
      #f(x,p,t) = model(x,p) # error traceback: Juno highlights this in red

      function f(dx,x,p,t)
          dx .= model(x,p)
          nothing
      end

      prob = ODEProblem(f,minibatch[1][1],(0.0f0,1.0f0),p)  # u0 is a placeholder.
      pred(u,p) = concrete_solve(prob,Tsit5(),u,p, saveat=0.01)

      #validate prediction working
      pred(u0,p)
      pred(minibatch[1][1],p)

      # set up loss function
      loss_vec_batch(p,mb) =  sum(abs, pred(mb,p).- mb) # vectorized over batch
      loss_vec(p) = loss_vec_batch(p,minibatch[1]...)
      loss = loss_vec # rename

      println(loss(p))

      #validate Loss Function working
      loss_vec_batch(p,minibatch[1]...) == loss_vec(p)

      # set up the gradient parameters
      ps = Flux.params(p)

      #use zygote on vector case
      gsv = Flux.Zygote.gradient(ps) do
         x = loss(p)
         first(x)
       end

       println("\nMODEL RAN CLEANLY.  ",gsv.grads)

   end
end

do_it()
julia> do_it()

model FastDense{typeof(identity),DiffEqFlux.var"#initial_params#48"{typeof(Flux.glorot_uniform),typeof(Flux.zeros),Int64,Int64}}(2, 2, identity, DiffEqFlux.var"#initial_params#48"{typeof(Flux.glorot_uniform),typeof(Flux.zeros),Int64,Int64}(Flux.glorot_uniform, Flux.zeros, 2, 2))
76.2751

MODEL RAN CLEANLY.  IdDict{Any,Any}(Float32[0.7815552, -0.026053874, -0.60071194, -0.05056247, 0.0, 0.0] => Float32[325.44238, -300.28796, 296.0859, -272.6577, 202.24295, -186.32872])

model FastChain{Tuple{FastDense{typeof(identity),DiffEqFlux.var"#initial_params#48"{typeof(Flux.glorot_uniform),typeof(Flux.zeros),Int64,Int64}}}}((FastDense{typeof(identity),DiffEqFlux.var"#initial_params#48"{typeof(Flux.glorot_uniform),typeof(Flux.zeros),Int64,Int64}}(2, 2, identity, DiffEqFlux.var"#initial_params#48"{typeof(Flux.glorot_uniform),typeof(Flux.zeros),Int64,Int64}(Flux.glorot_uniform, Flux.zeros, 2, 2)),))
251.10022

MODEL RAN CLEANLY.  IdDict{Any,Any}(Float32[0.41792107, 0.47331968, 0.24224015, -0.25262898, 0.0, 0.0] => Float32[374.25388, 288.6827, 330.43677, 253.09106, 202.13618, 153.87003])
cems2 commented 4 years ago

Three questions:

  1. Can you give me a clue as to why you guessed the in-place form would work?
  2. Out of curiosity, how does ODEProblem tell whether the function it is handed takes 3 or 4 arguments (returning a value versus writing the result in place)? Introspection? Testing return values? Something else?
  3. What does this finding tell you about the origin of the error? Why would adding a Chain matter?

And finally, thank you for taking a look and finding this workaround.

ChrisRackauckas commented 4 years ago

You shouldn't have to transform this to in-place. I'll look into this today.

> Out of curiosity how does ODEProblem tell if the function handed in is a 3 or 4 parameter function (return value or inplace result). Introspection? Testing return values? how?

We keep a sorcerer who specializes in dark magic on retainer for whenever someone defines a DEProblem. Though he recently digitized himself into method table inspection code:

https://github.com/JuliaDiffEq/DiffEqBase.jl/blob/master/src/utils.jl#L1-L43
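
As a rough illustration of what method-table inspection can look like, here is a minimal sketch using only Base Julia. The helper name `guess_inplace` is hypothetical, and this is not the DiffEqBase code linked above. It only assumes the convention used in this thread: `f(u,p,t)` is out-of-place and `f(du,u,p,t)` is in-place.

```julia
# Hypothetical helper for illustration; not the DiffEqBase implementation.
function guess_inplace(f)
    # Each Method records its argument count in `nargs`, which includes
    # the function object itself, hence the `- 1`.
    arities = [m.nargs - 1 for m in methods(f)]
    if 4 in arities
        return true        # some method looks like f(du, u, p, t)
    elseif 3 in arities
        return false       # some method looks like f(u, p, t)
    else
        error("expected a method with 3 or 4 arguments, got $(arities)")
    end
end

# Out-of-place: returns the derivative.
f_oop(u, p, t) = 2 .* u

# In-place: writes the derivative into du and returns nothing.
function f_iip(du, u, p, t)
    du .= 2 .* u
    nothing
end

println(guess_inplace(f_oop))  # false
println(guess_inplace(f_iip))  # true
```

The point of the sketch is that the decision is made from the signatures in the method table, without ever calling the function or testing its return value.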

ChrisRackauckas commented 4 years ago

This problem is the same as https://github.com/JuliaDiffEq/DiffEqFlux.jl/issues/160

cems2 commented 4 years ago

It is? They both involve Zygote.gradient, but in the case of #171 it seems that FastChain() is somehow an important factor.

On Mar 3, 2020, at 10:01 AM, Christopher Rackauckas wrote:

> This problem is the same as #160 https://github.com/JuliaDiffEq/DiffEqFlux.jl/issues/160

ChrisRackauckas commented 4 years ago

Because of how the output is handled.

gabrevaya commented 4 years ago

I think that the error from this code is also caused by the same issue.

using DiffEqFlux, OrdinaryDiffEq, Flux, Optim, DiffEqSensitivity
using Flux, Flux.Data.MNIST
using Flux: onehotbatch, onecold, crossentropy, throttle
using Base.Iterators: repeated

# testing on MNIST
# Classify MNIST digits
imgs = MNIST.images()
# Stack images into one large batch
X = hcat(float.(reshape.(imgs, :))...)
X = convert(Array{Float32,2}, X)

labels = MNIST.labels()
# One-hot-encode the labels
Y = onehotbatch(labels, 0:9)

import NNlib.softmax
softmax(x::AbstractArray, p::AbstractArray) = softmax(x)
N = 32

model = FastChain(FastDense(28^2,N,tanh),
                FastDense(N,N,tanh),
                FastDense(N,10),
                softmax)

p0=initial_params(model)
model(X[:,1:100], p0)

dataset = repeated((X, Y), 5)

function loss(p,x,y)
    pred = model(x,p)
    l = crossentropy(pred, y)
    l, pred
end

for mb in dataset
    display(loss(p0,mb...)[1])
end

cb = function (p,l,pred)
  display(l)
  return false
end

res1 = DiffEqFlux.sciml_train(loss, p0, ADAM(), dataset, cb = cb, maxiters = 100)

This throws the following error:

ERROR: ArgumentError: number of columns of each array must match (got (1, 60000))
Stacktrace:
 [1] _typed_vcat(::Type{Float32}, ::Tuple{Array{Float32,1},Array{Float32,2}}) at ./abstractarray.jl:1359
 [2] typed_vcat at ./abstractarray.jl:1373 [inlined]
 [3] vcat at /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.3/SparseArrays/src/sparsevector.jl:1079 [inlined]
 [4] (::DiffEqFlux.var"#FastDense_adjoint#49"{FastDense{typeof(identity),DiffEqFlux.var"#initial_params#48"{typeof(Flux.glorot_uniform),typeof(Flux.zeros),Int64,Int64}},Array{Float32,2},Array{Float32,2},Array{Float32,2},Array{Float32,2}})(::Array{Float32,2}) at /Users/ger/.julia/packages/DiffEqFlux/dZmLq/src/fast_layers.jl:53
 [5] #167#back at /Users/ger/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:49 [inlined]
 [6] applychain at /Users/ger/.julia/packages/DiffEqFlux/dZmLq/src/fast_layers.jl:20 [inlined]
 [7] (::typeof(∂(applychain)))(::Array{Float32,2}) at /Users/ger/.julia/packages/Zygote/tJj2w/src/compiler/interface2.jl:0
 [8] applychain at /Users/ger/.julia/packages/DiffEqFlux/dZmLq/src/fast_layers.jl:20 [inlined]
 [9] (::typeof(∂(applychain)))(::Array{Float32,2}) at /Users/ger/.julia/packages/Zygote/tJj2w/src/compiler/interface2.jl:0
 [10] applychain at /Users/ger/.julia/packages/DiffEqFlux/dZmLq/src/fast_layers.jl:20 [inlined]
 [11] (::typeof(∂(applychain)))(::Array{Float32,2}) at /Users/ger/.julia/packages/Zygote/tJj2w/src/compiler/interface2.jl:0
 [12] FastChain at /Users/ger/.julia/packages/DiffEqFlux/dZmLq/src/fast_layers.jl:21 [inlined]
 [13] (::typeof(∂(λ)))(::Array{Float32,2}) at /Users/ger/.julia/packages/Zygote/tJj2w/src/compiler/interface2.jl:0
 [14] loss at /Users/ger/Documents/Issues/for_posting_minibatch.jl:32 [inlined]
 [15] (::typeof(∂(loss)))(::Tuple{Float32,Nothing}) at /Users/ger/.julia/packages/Zygote/tJj2w/src/compiler/interface2.jl:0
 [16] #165 at /Users/ger/.julia/packages/Zygote/tJj2w/src/lib/lib.jl:156 [inlined]
 [17] (::Zygote.var"#321#back#167"{Zygote.var"#165#166"{typeof(∂(loss)),Tuple{Tuple{Nothing},Tuple{Nothing,Nothing}}}})(::Tuple{Float32,Nothing}) at /Users/ger/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:49
 [18] #18 at /Users/ger/.julia/packages/DiffEqFlux/dZmLq/src/train.jl:50 [inlined]
 [19] (::typeof(∂(λ)))(::Float32) at /Users/ger/.julia/packages/Zygote/tJj2w/src/compiler/interface2.jl:0
 [20] (::Zygote.var"#46#47"{Zygote.Params,Zygote.Context,typeof(∂(λ))})(::Float32) at /Users/ger/.julia/packages/Zygote/tJj2w/src/compiler/interface.jl:101
 [21] gradient(::Function, ::Zygote.Params) at /Users/ger/.julia/packages/Zygote/tJj2w/src/compiler/interface.jl:47
 [22] macro expansion at /Users/ger/.julia/packages/DiffEqFlux/dZmLq/src/train.jl:49 [inlined]
 [23] macro expansion at /Users/ger/.julia/packages/Juno/oLB1d/src/progress.jl:119 [inlined]
 [24] #sciml_train#16(::Function, ::Int64, ::typeof(DiffEqFlux.sciml_train), ::Function, ::Array{Float32,1}, ::ADAM, ::Base.Iterators.Take{Base.Iterators.Repeated{Tuple{Array{Float32,2},Flux.OneHotMatrix{Array{Flux.OneHotVector,1}}}}}) at /Users/ger/.julia/packages/DiffEqFlux/dZmLq/src/train.jl:48
 [25] (::DiffEqFlux.var"#kw##sciml_train")(::NamedTuple{(:cb, :maxiters),Tuple{var"#15#16",Int64}}, ::typeof(DiffEqFlux.sciml_train), ::Function, ::Array{Float32,1}, ::ADAM, ::Base.Iterators.Take{Base.Iterators.Repeated{Tuple{Array{Float32,2},Flux.OneHotMatrix{Array{Flux.OneHotVector,1}}}}}) at ./none:0
 [26] top-level scope at none:0

I also tried the minibatch approach suggested in #167, and it throws the same error.

I hope this different example (although maybe not MWE 😅) could be somewhat useful.
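
Frame [1] of the trace above shows `_typed_vcat` being called on a `Array{Float32,1}` and a `Array{Float32,2}`. The sketch below reproduces that kind of mismatch with plain Base arrays. It is an illustration of the error mechanism only, with made-up array sizes and variable names, and it is not the fix that was applied in DiffEqFlux.

```julia
# Illustration only: plain Base arrays, no DiffEqFlux involved.
grad_vec = zeros(Float32, 4)      # 1-D array, treated as a single column by vcat
grad_mat = zeros(Float32, 2, 3)   # 2-D array with three columns

# The column counts differ (1 vs 3), so vertical concatenation throws.
failed = try
    vcat(grad_vec, grad_mat)
    false
catch err
    println(sprint(showerror, err))
    true
end

# Flattening both pieces first always yields a well-defined flat vector.
flat = vcat(vec(grad_vec), vec(grad_mat))
println(failed, " ", length(flat))
```

With a batch of 60000 columns, as in the MNIST example, the same mismatch would be reported as `(1, 60000)`.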

ChrisRackauckas commented 4 years ago

Closing as a duplicate issue. It all boils down to the DiffEqArray adjoint.

ChrisRackauckas commented 4 years ago

That issue was resolved by a fix to the adjoints.

Minibatching is being tracked in a few other places as a tutorial update.