SciML / DiffEqFlux.jl

Pre-built implicit layer architectures with O(1) backprop, GPUs, and stiff+non-stiff DE solvers, demonstrating scientific machine learning (SciML) and physics-informed machine learning methods
https://docs.sciml.ai/DiffEqFlux/stable
MIT License
871 stars · 157 forks · source link

MethodError in sequential learning problem #543

Closed — gideonsimpson closed this issue 3 years ago

gideonsimpson commented 3 years ago

I was trying to set up a sequential learning problem as follows:

using Flux, DiffEqFlux
using DifferentialEquations
using Random

ndata= 11;
tspan = (0.0f0, 1.0f0);
tdata = range(tspan[1], tspan[2], length = ndata);

a = Float32(2.);
u0 = Float32[1.]

function f!(du, u, p, t)
   du .= -a * u; 
end

ode_prob = ODEProblem(f!, u0, tspan)

n = 10^2;
Random.seed!(100)
u0data = randn(Float32, n);
function ic_func(prob, i, repeat)
    remake(prob, u0=[u0data[i]])
end
ensemble_prob = EnsembleProblem(ode_prob, prob_func = ic_func)
ensemble_soln = solve(ensemble_prob, Tsit5(),trajectories = n,saveat=tdata);
ensemble_sum = EnsembleSummary(ensemble_soln);
ode_data = Array(ensemble_soln);

Random.seed!(200)
f_nn = FastChain(FastDense(1,20,tanh),FastDense(20,1));
p = initial_params(f_nn);
n_ode = NeuralODE(f_nn, tspan, Tsit5(), saveat = tdata);

train_data = zip(u0data, ode_data);

function predict(u0_)
    Array(n_ode(u0_, p)) 
end

function loss(u_, u0_)
    pred = predict(u0_)
    sum(abs2, u_ .- pred)
end

opt=ADAM(0.05)
Flux.train!(loss, Flux.params(p), train_data, opt)

If I check this code before training, everything appears to work (i.e., the `loss` and `predict` functions execute without error). But trying to train it produces the error:

MethodError: Cannot `convert` an object of type Matrix{Float32} to an object of type Float32
Closest candidates are:
  convert(::Type{T}, ::VectorizationBase.AbstractSIMD) where T<:Union{Bool, Float32, Float64, Int16, Int32, Int64, Int8, UInt16, UInt32, UInt64, UInt8, VectorizationBase.Bit} at /Users/guardian/.julia/packages/VectorizationBase/4k8ZW/src/base_defs.jl:150
  convert(::Type{T}, ::LLVM.GenericValue, ::LLVM.LLVMType) where T<:AbstractFloat at /Users/guardian/.julia/packages/LLVM/7Q46C/src/execution.jl:39
  convert(::Type{T}, ::LLVM.ConstantFP) where T<:AbstractFloat at /Users/guardian/.julia/packages/LLVM/7Q46C/src/core/value/constant.jl:98
  ...

Stacktrace:
  [1] setproperty!(x::OrdinaryDiffEq.ODEIntegrator{Tsit5, false, Float32, Nothing, Float32, Vector{Float32}, Float32, Float32, Float32, Vector{Float32}, ODESolution{Float32, 1, Vector{Float32}, Nothing, Nothing, Vector{Float32}, Vector{Vector{Float32}}, ODEProblem{Float32, Tuple{Float32, Float32}, false, Vector{Float32}, ODEFunction{false, DiffEqFlux.var"#dudt_#88"{NeuralODE{FastChain{Tuple{FastDense{typeof(tanh), DiffEqFlux.var"#initial_params#73"{Vector{Float32}}}, FastDense{typeof(identity), DiffEqFlux.var"#initial_params#73"{Vector{Float32}}}}}, Vector{Float32}, Nothing, Tuple{Float32, Float32}, Tuple{Tsit5}, Base.Iterators.Pairs{Symbol, StepRangeLen{Float32, Float64, Float64}, Tuple{Symbol}, NamedTuple{(:saveat,), Tuple{StepRangeLen{Float32, Float64, Float64}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, typeof(DiffEqFlux.basic_tgrad), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, Tsit5, OrdinaryDiffEq.InterpolationData{ODEFunction{false, DiffEqFlux.var"#dudt_#88"{NeuralODE{FastChain{Tuple{FastDense{typeof(tanh), DiffEqFlux.var"#initial_params#73"{Vector{Float32}}}, FastDense{typeof(identity), DiffEqFlux.var"#initial_params#73"{Vector{Float32}}}}}, Vector{Float32}, Nothing, Tuple{Float32, Float32}, Tuple{Tsit5}, Base.Iterators.Pairs{Symbol, StepRangeLen{Float32, Float64, Float64}, Tuple{Symbol}, NamedTuple{(:saveat,), Tuple{StepRangeLen{Float32, Float64, Float64}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, typeof(DiffEqFlux.basic_tgrad), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Vector{Float32}, Vector{Float32}, Vector{Vector{Float32}}, OrdinaryDiffEq.Tsit5ConstantCache{Float32, Float32}}, DiffEqBase.DEStats}, ODEFunction{false, 
DiffEqFlux.var"#dudt_#88"{NeuralODE{FastChain{Tuple{FastDense{typeof(tanh), DiffEqFlux.var"#initial_params#73"{Vector{Float32}}}, FastDense{typeof(identity), DiffEqFlux.var"#initial_params#73"{Vector{Float32}}}}}, Vector{Float32}, Nothing, Tuple{Float32, Float32}, Tuple{Tsit5}, Base.Iterators.Pairs{Symbol, StepRangeLen{Float32, Float64, Float64}, Tuple{Symbol}, NamedTuple{(:saveat,), Tuple{StepRangeLen{Float32, Float64, Float64}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, typeof(DiffEqFlux.basic_tgrad), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, OrdinaryDiffEq.Tsit5ConstantCache{Float32, Float32}, OrdinaryDiffEq.DEOptions{Float32, Float32, Float32, Float32, typeof(DiffEqBase.ODE_DEFAULT_NORM), typeof(LinearAlgebra.opnorm), Bool, CallbackSet{Tuple{}, Tuple{}}, typeof(DiffEqBase.ODE_DEFAULT_ISOUTOFDOMAIN), typeof(DiffEqBase.ODE_DEFAULT_PROG_MESSAGE), typeof(DiffEqBase.ODE_DEFAULT_UNSTABLE_CHECK), DataStructures.BinaryMinHeap{Float32}, DataStructures.BinaryMinHeap{Float32}, Nothing, Nothing, Int64, Tuple{}, Tuple{}, Tuple{}}, Float32, Float32, Nothing, OrdinaryDiffEq.DefaultInit}, f::Symbol, v::Matrix{Float32})
    @ Base ./Base.jl:34
  [2] initialize!(integrator::OrdinaryDiffEq.ODEIntegrator{Tsit5, false, Float32, Nothing, Float32, Vector{Float32}, Float32, Float32, Float32, Vector{Float32}, ODESolution{Float32, 1, Vector{Float32}, Nothing, Nothing, Vector{Float32}, Vector{Vector{Float32}}, ODEProblem{Float32, Tuple{Float32, Float32}, false, Vector{Float32}, ODEFunction{false, DiffEqFlux.var"#dudt_#88"{NeuralODE{FastChain{Tuple{FastDense{typeof(tanh), DiffEqFlux.var"#initial_params#73"{Vector{Float32}}}, FastDense{typeof(identity), DiffEqFlux.var"#initial_params#73"{Vector{Float32}}}}}, Vector{Float32}, Nothing, Tuple{Float32, Float32}, Tuple{Tsit5}, Base.Iterators.Pairs{Symbol, StepRangeLen{Float32, Float64, Float64}, Tuple{Symbol}, NamedTuple{(:saveat,), Tuple{StepRangeLen{Float32, Float64, Float64}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, typeof(DiffEqFlux.basic_tgrad), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, Tsit5, OrdinaryDiffEq.InterpolationData{ODEFunction{false, DiffEqFlux.var"#dudt_#88"{NeuralODE{FastChain{Tuple{FastDense{typeof(tanh), DiffEqFlux.var"#initial_params#73"{Vector{Float32}}}, FastDense{typeof(identity), DiffEqFlux.var"#initial_params#73"{Vector{Float32}}}}}, Vector{Float32}, Nothing, Tuple{Float32, Float32}, Tuple{Tsit5}, Base.Iterators.Pairs{Symbol, StepRangeLen{Float32, Float64, Float64}, Tuple{Symbol}, NamedTuple{(:saveat,), Tuple{StepRangeLen{Float32, Float64, Float64}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, typeof(DiffEqFlux.basic_tgrad), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Vector{Float32}, Vector{Float32}, Vector{Vector{Float32}}, OrdinaryDiffEq.Tsit5ConstantCache{Float32, Float32}}, DiffEqBase.DEStats}, ODEFunction{false, 
DiffEqFlux.var"#dudt_#88"{NeuralODE{FastChain{Tuple{FastDense{typeof(tanh), DiffEqFlux.var"#initial_params#73"{Vector{Float32}}}, FastDense{typeof(identity), DiffEqFlux.var"#initial_params#73"{Vector{Float32}}}}}, Vector{Float32}, Nothing, Tuple{Float32, Float32}, Tuple{Tsit5}, Base.Iterators.Pairs{Symbol, StepRangeLen{Float32, Float64, Float64}, Tuple{Symbol}, NamedTuple{(:saveat,), Tuple{StepRangeLen{Float32, Float64, Float64}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, typeof(DiffEqFlux.basic_tgrad), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, OrdinaryDiffEq.Tsit5ConstantCache{Float32, Float32}, OrdinaryDiffEq.DEOptions{Float32, Float32, Float32, Float32, typeof(DiffEqBase.ODE_DEFAULT_NORM), typeof(LinearAlgebra.opnorm), Bool, CallbackSet{Tuple{}, Tuple{}}, typeof(DiffEqBase.ODE_DEFAULT_ISOUTOFDOMAIN), typeof(DiffEqBase.ODE_DEFAULT_PROG_MESSAGE), typeof(DiffEqBase.ODE_DEFAULT_UNSTABLE_CHECK), DataStructures.BinaryMinHeap{Float32}, DataStructures.BinaryMinHeap{Float32}, Nothing, Nothing, Int64, Tuple{}, Tuple{}, Tuple{}}, Float32, Float32, Nothing, OrdinaryDiffEq.DefaultInit}, cache::OrdinaryDiffEq.Tsit5ConstantCache{Float32, Float32})
    @ OrdinaryDiffEq ~/.julia/packages/OrdinaryDiffEq/kAbV7/src/perform_step/low_order_rk_perform_step.jl:565
  [3] __init(prob::ODEProblem{Float32, Tuple{Float32, Float32}, false, Vector{Float32}, ODEFunction{false, DiffEqFlux.var"#dudt_#88"{NeuralODE{FastChain{Tuple{FastDense{typeof(tanh), DiffEqFlux.var"#initial_params#73"{Vector{Float32}}}, FastDense{typeof(identity), DiffEqFlux.var"#initial_params#73"{Vector{Float32}}}}}, Vector{Float32}, Nothing, Tuple{Float32, Float32}, Tuple{Tsit5}, Base.Iterators.Pairs{Symbol, StepRangeLen{Float32, Float64, Float64}, Tuple{Symbol}, NamedTuple{(:saveat,), Tuple{StepRangeLen{Float32, Float64, Float64}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, typeof(DiffEqFlux.basic_tgrad), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, alg::Tsit5, timeseries_init::Tuple{}, ts_init::Tuple{}, ks_init::Tuple{}, recompile::Type{Val{true}}; saveat::Tuple{}, tstops::Tuple{}, d_discontinuities::Tuple{}, save_idxs::Nothing, save_everystep::Bool, save_on::Bool, save_start::Bool, save_end::Bool, callback::Nothing, dense::Bool, calck::Bool, dt::Float32, dtmin::Nothing, dtmax::Float32, force_dtmin::Bool, adaptive::Bool, gamma::Rational{Int64}, abstol::Nothing, reltol::Nothing, qmin::Rational{Int64}, qmax::Int64, qsteady_min::Int64, qsteady_max::Int64, qoldinit::Rational{Int64}, fullnormalize::Bool, failfactor::Int64, beta1::Nothing, beta2::Nothing, maxiters::Int64, internalnorm::typeof(DiffEqBase.ODE_DEFAULT_NORM), internalopnorm::typeof(LinearAlgebra.opnorm), isoutofdomain::typeof(DiffEqBase.ODE_DEFAULT_ISOUTOFDOMAIN), unstable_check::typeof(DiffEqBase.ODE_DEFAULT_UNSTABLE_CHECK), verbose::Bool, timeseries_errors::Bool, dense_errors::Bool, advance_to_tstop::Bool, stop_at_next_tstop::Bool, initialize_save::Bool, progress::Bool, progress_steps::Int64, progress_name::String, progress_message::typeof(DiffEqBase.ODE_DEFAULT_PROG_MESSAGE), userdata::Nothing, 
allow_extrapolation::Bool, initialize_integrator::Bool, alias_u0::Bool, alias_du0::Bool, initializealg::OrdinaryDiffEq.DefaultInit, kwargs::Base.Iterators.Pairs{Symbol, Bool, Tuple{Symbol}, NamedTuple{(:save_noise,), Tuple{Bool}}})
    @ OrdinaryDiffEq ~/.julia/packages/OrdinaryDiffEq/kAbV7/src/solve.jl:433
  [4] #__solve#403
    @ ~/.julia/packages/OrdinaryDiffEq/kAbV7/src/solve.jl:4 [inlined]
  [5] #solve_call#56
    @ ~/.julia/packages/DiffEqBase/jhLIm/src/solve.jl:61 [inlined]
  [6] solve_up(prob::ODEProblem{Float32, Tuple{Float32, Float32}, false, Vector{Float32}, ODEFunction{false, DiffEqFlux.var"#dudt_#88"{NeuralODE{FastChain{Tuple{FastDense{typeof(tanh), DiffEqFlux.var"#initial_params#73"{Vector{Float32}}}, FastDense{typeof(identity), DiffEqFlux.var"#initial_params#73"{Vector{Float32}}}}}, Vector{Float32}, Nothing, Tuple{Float32, Float32}, Tuple{Tsit5}, Base.Iterators.Pairs{Symbol, StepRangeLen{Float32, Float64, Float64}, Tuple{Symbol}, NamedTuple{(:saveat,), Tuple{StepRangeLen{Float32, Float64, Float64}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, typeof(DiffEqFlux.basic_tgrad), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, sensealg::Nothing, u0::Float32, p::Vector{Float32}, args::Tsit5; kwargs::Base.Iterators.Pairs{Symbol, Bool, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:save_noise, :save_start, :save_end), Tuple{Bool, Bool, Bool}}})
    @ DiffEqBase ~/.julia/packages/DiffEqBase/jhLIm/src/solve.jl:82
  [7] #solve#57
    @ ~/.julia/packages/DiffEqBase/jhLIm/src/solve.jl:70 [inlined]
  [8] _concrete_solve_adjoint(::ODEProblem{Float32, Tuple{Float32, Float32}, false, Vector{Float32}, ODEFunction{false, DiffEqFlux.var"#dudt_#88"{NeuralODE{FastChain{Tuple{FastDense{typeof(tanh), DiffEqFlux.var"#initial_params#73"{Vector{Float32}}}, FastDense{typeof(identity), DiffEqFlux.var"#initial_params#73"{Vector{Float32}}}}}, Vector{Float32}, Nothing, Tuple{Float32, Float32}, Tuple{Tsit5}, Base.Iterators.Pairs{Symbol, StepRangeLen{Float32, Float64, Float64}, Tuple{Symbol}, NamedTuple{(:saveat,), Tuple{StepRangeLen{Float32, Float64, Float64}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, typeof(DiffEqFlux.basic_tgrad), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, ::Tsit5, ::DiffEqSensitivity.InterpolatingAdjoint{0, true, Val{:central}, DiffEqSensitivity.ZygoteVJP, Bool}, ::Float32, ::Vector{Float32}; save_start::Bool, save_end::Bool, saveat::StepRangeLen{Float32, Float64, Float64}, save_idxs::Nothing, kwargs::Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ DiffEqSensitivity ~/.julia/packages/DiffEqSensitivity/LDOtY/src/concrete_solve.jl:71
  [9] #_solve_adjoint#81
    @ ~/.julia/packages/DiffEqBase/jhLIm/src/solve.jl:329 [inlined]
 [10] #adjoint#72
    @ ~/.julia/packages/DiffEqBase/jhLIm/src/solve.jl:293 [inlined]
 [11] _pullback(__context__::Zygote.Context, #unused#::DiffEqBase.var"#solve_up##kw", kw::NamedTuple{(:saveat,), Tuple{StepRangeLen{Float32, Float64, Float64}}}, 319::typeof(DiffEqBase.solve_up), prob::ODEProblem{Float32, Tuple{Float32, Float32}, false, Vector{Float32}, ODEFunction{false, DiffEqFlux.var"#dudt_#88"{NeuralODE{FastChain{Tuple{FastDense{typeof(tanh), DiffEqFlux.var"#initial_params#73"{Vector{Float32}}}, FastDense{typeof(identity), DiffEqFlux.var"#initial_params#73"{Vector{Float32}}}}}, Vector{Float32}, Nothing, Tuple{Float32, Float32}, Tuple{Tsit5}, Base.Iterators.Pairs{Symbol, StepRangeLen{Float32, Float64, Float64}, Tuple{Symbol}, NamedTuple{(:saveat,), Tuple{StepRangeLen{Float32, Float64, Float64}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, typeof(DiffEqFlux.basic_tgrad), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, sensealg::DiffEqSensitivity.InterpolatingAdjoint{0, true, Val{:central}, DiffEqSensitivity.ZygoteVJP, Bool}, u0::Float32, p::Vector{Float32}, args::Tsit5)
    @ DiffEqBase ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:63
 [12] _apply(::Function, ::Vararg{Any, N} where N)
    @ Core ./boot.jl:804
 [13] adjoint
    @ ~/.julia/packages/Zygote/6HN9x/src/lib/lib.jl:191 [inlined]
 [14] _pullback
    @ ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:57 [inlined]
 [15] _pullback
    @ ~/.julia/packages/DiffEqBase/jhLIm/src/solve.jl:70 [inlined]
 [16] _pullback(::Zygote.Context, ::DiffEqBase.var"##solve#57", ::DiffEqSensitivity.InterpolatingAdjoint{0, true, Val{:central}, DiffEqSensitivity.ZygoteVJP, Bool}, ::Nothing, ::Nothing, ::Base.Iterators.Pairs{Symbol, StepRangeLen{Float32, Float64, Float64}, Tuple{Symbol}, NamedTuple{(:saveat,), Tuple{StepRangeLen{Float32, Float64, Float64}}}}, ::typeof(solve), ::ODEProblem{Float32, Tuple{Float32, Float32}, false, Vector{Float32}, ODEFunction{false, DiffEqFlux.var"#dudt_#88"{NeuralODE{FastChain{Tuple{FastDense{typeof(tanh), DiffEqFlux.var"#initial_params#73"{Vector{Float32}}}, FastDense{typeof(identity), DiffEqFlux.var"#initial_params#73"{Vector{Float32}}}}}, Vector{Float32}, Nothing, Tuple{Float32, Float32}, Tuple{Tsit5}, Base.Iterators.Pairs{Symbol, StepRangeLen{Float32, Float64, Float64}, Tuple{Symbol}, NamedTuple{(:saveat,), Tuple{StepRangeLen{Float32, Float64, Float64}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, typeof(DiffEqFlux.basic_tgrad), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, ::Tsit5)
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface2.jl:0
 [17] _apply(::Function, ::Vararg{Any, N} where N)
    @ Core ./boot.jl:804
 [18] adjoint
    @ ~/.julia/packages/Zygote/6HN9x/src/lib/lib.jl:191 [inlined]
 [19] _pullback
    @ ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:57 [inlined]
 [20] _pullback
    @ ~/.julia/packages/DiffEqBase/jhLIm/src/solve.jl:68 [inlined]
 [21] _pullback(::Zygote.Context, ::CommonSolve.var"#solve##kw", ::NamedTuple{(:sensealg, :saveat), Tuple{DiffEqSensitivity.InterpolatingAdjoint{0, true, Val{:central}, DiffEqSensitivity.ZygoteVJP, Bool}, StepRangeLen{Float32, Float64, Float64}}}, ::typeof(solve), ::ODEProblem{Float32, Tuple{Float32, Float32}, false, Vector{Float32}, ODEFunction{false, DiffEqFlux.var"#dudt_#88"{NeuralODE{FastChain{Tuple{FastDense{typeof(tanh), DiffEqFlux.var"#initial_params#73"{Vector{Float32}}}, FastDense{typeof(identity), DiffEqFlux.var"#initial_params#73"{Vector{Float32}}}}}, Vector{Float32}, Nothing, Tuple{Float32, Float32}, Tuple{Tsit5}, Base.Iterators.Pairs{Symbol, StepRangeLen{Float32, Float64, Float64}, Tuple{Symbol}, NamedTuple{(:saveat,), Tuple{StepRangeLen{Float32, Float64, Float64}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, typeof(DiffEqFlux.basic_tgrad), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, ::Tsit5)
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface2.jl:0
 [22] _apply(::Function, ::Vararg{Any, N} where N)
    @ Core ./boot.jl:804
 [23] adjoint
    @ ~/.julia/packages/Zygote/6HN9x/src/lib/lib.jl:191 [inlined]
 [24] adjoint(::Zygote.Context, ::typeof(Core._apply_iterate), ::typeof(iterate), ::Function, ::Tuple{NamedTuple{(:sensealg, :saveat), Tuple{DiffEqSensitivity.InterpolatingAdjoint{0, true, Val{:central}, DiffEqSensitivity.ZygoteVJP, Bool}, StepRangeLen{Float32, Float64, Float64}}}, typeof(solve), ODEProblem{Float32, Tuple{Float32, Float32}, false, Vector{Float32}, ODEFunction{false, DiffEqFlux.var"#dudt_#88"{NeuralODE{FastChain{Tuple{FastDense{typeof(tanh), DiffEqFlux.var"#initial_params#73"{Vector{Float32}}}, FastDense{typeof(identity), DiffEqFlux.var"#initial_params#73"{Vector{Float32}}}}}, Vector{Float32}, Nothing, Tuple{Float32, Float32}, Tuple{Tsit5}, Base.Iterators.Pairs{Symbol, StepRangeLen{Float32, Float64, Float64}, Tuple{Symbol}, NamedTuple{(:saveat,), Tuple{StepRangeLen{Float32, Float64, Float64}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, typeof(DiffEqFlux.basic_tgrad), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}}, ::Tuple{Tsit5})
    @ Zygote ./none:0
 [25] _pullback
    @ ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:57 [inlined]
 [26] _pullback
    @ ~/.julia/packages/DiffEqFlux/alPQ3/src/neural_de.jl:77 [inlined]
 [27] _pullback(::Zygote.Context, ::NeuralODE{FastChain{Tuple{FastDense{typeof(tanh), DiffEqFlux.var"#initial_params#73"{Vector{Float32}}}, FastDense{typeof(identity), DiffEqFlux.var"#initial_params#73"{Vector{Float32}}}}}, Vector{Float32}, Nothing, Tuple{Float32, Float32}, Tuple{Tsit5}, Base.Iterators.Pairs{Symbol, StepRangeLen{Float32, Float64, Float64}, Tuple{Symbol}, NamedTuple{(:saveat,), Tuple{StepRangeLen{Float32, Float64, Float64}}}}}, ::Float32, ::Vector{Float32})
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface2.jl:0
 [28] _pullback
    @ ./In[21]:2 [inlined]
 [29] _pullback(ctx::Zygote.Context, f::typeof(predict), args::Float32)
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface2.jl:0
 [30] _pullback
    @ ./In[21]:6 [inlined]
 [31] _pullback(::Zygote.Context, ::typeof(loss), ::Float32, ::Float32)
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface2.jl:0
 [32] _apply
    @ ./boot.jl:804 [inlined]
 [33] adjoint
    @ ~/.julia/packages/Zygote/6HN9x/src/lib/lib.jl:191 [inlined]
 [34] _pullback
    @ ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:57 [inlined]
 [35] _pullback
    @ ~/.julia/packages/Flux/6BByF/src/optimise/train.jl:102 [inlined]
 [36] _pullback(::Zygote.Context, ::Flux.Optimise.var"#39#45"{typeof(loss), Tuple{Float32, Float32}})
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface2.jl:0
 [37] pullback(f::Function, ps::Zygote.Params)
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface.jl:247
 [38] gradient(f::Function, args::Zygote.Params)
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface.jl:58
 [39] macro expansion
    @ ~/.julia/packages/Flux/6BByF/src/optimise/train.jl:101 [inlined]
 [40] macro expansion
    @ ~/.julia/packages/Juno/n6wyj/src/progress.jl:134 [inlined]
 [41] train!(loss::Function, ps::Zygote.Params, data::Base.Iterators.Zip{Tuple{Vector{Float32}, Array{Float32, 3}}}, opt::ADAM; cb::Flux.Optimise.var"#40#46")
    @ Flux.Optimise ~/.julia/packages/Flux/6BByF/src/optimise/train.jl:99
 [42] train!(loss::Function, ps::Zygote.Params, data::Base.Iterators.Zip{Tuple{Vector{Float32}, Array{Float32, 3}}}, opt::ADAM)
    @ Flux.Optimise ~/.julia/packages/Flux/6BByF/src/optimise/train.jl:97
 [43] top-level scope
    @ In[22]:2
 [44] eval
    @ ./boot.jl:360 [inlined]
 [45] include_string(mapexpr::typeof(REPL.softscope), mod::Module, code::String, filename::String)
    @ Base ./loading.jl:1094
ChrisRackauckas commented 3 years ago

The `NeuralODE` struct is not made for scalar equations — it expects the initial condition to be an array, so wrap the scalar in a vector:

function predict(u0_)
    Array(n_ode([u0_], p))
end
ChrisRackauckas commented 3 years ago

That fixes the issue.