SciML / DiffEqFlux.jl

Pre-built implicit layer architectures with O(1) backprop, GPUs, and stiff+non-stiff DE solvers, demonstrating scientific machine learning (SciML) and physics-informed machine learning methods
https://docs.sciml.ai/DiffEqFlux/stable
MIT License

Training a UODE and getting error: Cannot `convert` an object of type Nothing to an object of type Float32 #703

Open 00krishna opened 2 years ago

00krishna commented 2 years ago

Hello. I was training a simple UODE and am encountering this error when running sciml_train(). The error seems to be somewhere in the interface between DiffEqFlux and GalacticOptim. Below are the error message and some code to replicate the problem.

Here is the error message text:

ERROR: MethodError: Cannot `convert` an object of type Nothing to an object of type Float32
Closest candidates are:
  convert(::Type{T}, ::Static.StaticFloat64{N}) where {N, T<:AbstractFloat} at ~/.julia/packages/Static/8hh0B/src/float.jl:26
  convert(::Type{T}, ::LLVM.GenericValue, ::LLVM.LLVMType) where T<:AbstractFloat at ~/.julia/packages/LLVM/tVv0H/src/execution.jl:39
  convert(::Type{T}, ::LLVM.ConstantFP) where T<:AbstractFloat at ~/.julia/packages/LLVM/tVv0H/src/core/value/constant.jl:103
  ...
Stacktrace:
  [1] fill!(dest::Vector{Float32}, x::Nothing)
    @ Base ./array.jl:351
  [2] copyto!
    @ ./broadcast.jl:921 [inlined]
  [3] materialize!
    @ ./broadcast.jl:871 [inlined]
  [4] materialize!(dest::Vector{Float32}, bc::Base.Broadcast.Broad)
    @ Base.Broadcast ./broadcast.jl:868
  [5] (::GalacticOptim.var"#268#278"{GalacticOptim.var"#267#277"{OptimizationFunction{true, GalacticOptim.AutoZygote, DiffEqFlux.var"#84#89"{typeof(loss_uode)}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, Nothing}})(::Vector{Float32}, ::Vector{Float32})
    @ GalacticOptim ~/.julia/packages/GalacticOptim/fow0r/src/function/zygote.jl:8
  [6] macro expansion
    @ ~/.julia/packages/GalacticOptim/fow0r/src/solve/flux.jl:27 [inlined]
  [7] macro expansion
    @ ~/.julia/packages/GalacticOptim/fow0r/src/utils.jl:35 [inlined]
  [8] __solve(prob::OptimizationProblem{, opt::ADAM, data::Base.Iterators.Cycle; maxiters::Int64, cb::Function, progress::Bool, save_best::Bool, kwargs::Base.Pairs{Symbol, U)
    @ GalacticOptim ~/.julia/packages/GalacticOptim/fow0r/src/solve/flux.jl:25
  [9] #solve#482
    @ ~/.julia/packages/SciMLBase/OHiiA/src/solve.jl:3 [inlined]
 [10] sciml_train(::typeof(loss_uode), ::Vector{Float32}, ::ADAM, ::Nothing; lower_bounds::Nothing, upper_bounds::Nothing, maxiters::Int64, kwargs::Base.Pairs{Symbol, U)
    @ DiffEqFlux ~/.julia/packages/DiffEqFlux/vJuRw/src/train.jl:91
 [11] top-level scope
    @ ~/Dropbox/sandbox/julia_gend_univ/experiments/uode_benchmarking/mwe.jl:71
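Frames [1] to [5] suggest how this error arises: GalacticOptim's Zygote wrapper (zygote.jl:8) broadcasts the computed gradient into a preallocated `Vector{Float32}`, and the value being broadcast is `nothing`. Zygote returns `nothing` as the gradient of an argument the output does not depend on. The Base-only sketch below reproduces just that last step, assuming the gradient arrived as `nothing`; the buffer name `G` is illustrative.

```julia
# Preallocated gradient buffer, like the Vector{Float32} in frames [1]-[4].
G = zeros(Float32, 3)

# Broadcasting treats `nothing` as a scalar, so Base lowers this to
# fill!(G, nothing), which must convert `nothing` to Float32 and fails.
err = try
    G .= nothing
    nothing
catch e
    e
end

println(err isa MethodError)  # prints true
```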

Here is the code to replicate the problem. I hardcoded the data so that anyone can replicate the issue exactly. The code is based on some of the examples in the Universal ODE GitHub repo.

using StatsBase
using Plots 
using DifferentialEquations
using DiffEqFlux
using StatsPlots
using ComponentArrays
using Distributions
using DifferentialEquations.EnsembleAnalysis
using SciMLBase
using DiffEqCallbacks
using GalacticOptim
using NLopt
using NNlib
using Statistics
using Flux
gr();

rbf(x) = exp.(-(x.^2))

function get_nn(l::Integer, nodes::Integer)
    FastChain(
        FastDense(6, nodes, rbf),
        (FastDense(nodes, nodes, rbf) for _ in 1:l)...,
        FastDense(nodes, 6))
end;

U = get_nn(1, 10)

pinit = initial_params(U)

theta = ComponentVector{Float64}(rattr_f1 = 0.042687970552806695, rattr_f2 = -0.13366589419787755, rattr_f3 = -0.005666828159994795, rattr_m1 = -0.001969619987196797, rattr_m2 = -0.057685466652091864, rattr_m3 = 0.03484072549732818, rhire_f1 = -0.02268797055304856, rhire_f2 = 0.14466589419764714, rhire_f3 = 0.01666682815999726, rhire_m1 = 0.021969619987239963, rhire_m2 = 0.07768546665210119, rhire_m3 = -0.01484072549733023, rprom_f1 = -0.12398041712603916, rprom_f2 = 0.05063523685785915, rprom_f3 = 0.0, rprom_m1 = -0.023528128109958992, rprom_m2 = 0.11427635100343723, rprom_m3 = 0.0, growth_rate_linear = 0.01)

raw_data = [3.3064507529453775, 2.411147842897867, 3.8541953698304243, 8.515747060715876, 18.034486297095288, 32.886093862709245, 4.48549960824742, 2.5562275466943074, 4.058708308896289, 11.51123402926337, 18.08620682567805, 33.57792521585569, 6.0468671504292155, 2.6907051307235283, 4.397601585885496, 14.814845644929573, 18.08529694293005, 34.252912163683725, 7.192672397387545, 2.937028837016381, 4.709731483233916, 18.723508048394834, 17.436642346365506, 34.89649452732027, 7.352580291945779, 3.4231308506611935, 4.888737719134746, 23.479663335819453, 16.209855858422877, 35.33931656393878, 7.2331655891210325, 4.131368800956079, 4.925093311266256, 28.673073555019325, 15.298779271098546, 35.543004556821245, 7.369935260789615, 4.7141389130161775, 4.678267058817686, 32.549852535972846, 14.849690091055214, 35.522268257020734, 8.307311635173907, 4.68681368146203, 4.477904053217947, 34.35541746900183, 15.452784944041495, 35.78857036748951, 10.143679063434147, 4.503984620423812, 4.419740291236412, 35.61459125937783, 16.042619087690753, 36.75482632364377, 11.98711821793013, 4.408672346554513, 4.895176737702706, 37.02628984889759, 16.003122755644856, 37.83154651006416, 13.654191894783967, 4.321165531477556, 5.602231066759056, 39.42383172231546, 15.980116848398685, 38.360016616294025, 15.07127136879804, 4.495640140672605, 6.186352487270125, 41.47585581747662, 16.727453775574347, 37.662018081417905, 15.243668538849033, 5.042530500538143, 6.328609512332072, 41.12613313537764, 19.000477841200336, 36.2447375671493, 14.613889931882971, 5.935769188767932, 6.199844065932502, 37.692953653917044, 22.89764644624489, 35.626941568081406, 13.694132960603485, 7.265264704518828, 5.931120003995033, 32.681378551084386, 26.61944037304275, 36.04946436642747, 12.989840749786113, 8.41647817962317, 5.874204204500485, 28.09604869428785, 28.0865626319234, 36.84136730252351, 12.460187701403568, 8.963034340620466, 6.496033181057384, 24.678755668173096, 27.472004139049975, 37.611189827329774, 
12.60904088368306, 9.267867186695572, 7.099502466318395, 23.266513884738565, 25.25166618362897, 38.75528565771812, 13.178193080833045, 9.741361564156664, 7.786274544862907, 23.17812389142553, 23.045696601549917, 39.809506965115844, 13.990735014147583, 10.013169718744082, 8.982072355245323, 23.985936218734057, 20.230054255412583, 41.73550295219455, 14.99439640550598, 9.347386530026874, 10.941111084027392, 24.595027358946172, 17.757242325900023, 43.47253111918933, 14.877425968003662, 8.016118599612227, 13.350653962924648, 24.511605087832592, 15.5650983784468, 43.9015285476356, 14.563830542982078, 6.542545706368523, 15.775065661175315, 24.378778838851183, 14.539469120814292, 42.98157944534187, 14.057476431389063, 5.696368873306736, 17.02074365038269, 23.889658063571883, 14.724942331998033, 41.03721512324303, 14.113967621903988, 5.306870615252368, 17.62434605139526, 23.76767610436216, 15.317268849390269, 39.450457715937425, 14.229422124813736, 5.118940992390327, 17.828337152051862, 24.02635906876211, 16.164701004347, 38.93877340118323, 13.655222792709683, 5.07773075279201, 17.35735286039698, 23.928076206731614, 16.565210177074103, 39.28698134784897, 13.319193222138987, 4.665122538459729, 16.671204260330004, 23.8006495075749, 17.059360191354823, 39.78697306031208, 12.456792318264581, 4.308599089715807, 15.665121353876547, 24.123939210471914, 17.9263437802465, 40.263216511865735, 11.711382659973731, 3.6816219282402844, 15.099750324843555, 24.956045460013552, 19.03089773993017, 41.29472719013936, 11.215236446812579, 2.904896534655703, 14.69177798658438, 25.553490886574302, 20.365736659078202, 42.549245707105925]

full_data = reshape(raw_data, (6, 31))

function genduniv_ode_ude!(du, u, p, t, q)

    û = U(u, p) # network prediction 

    du[1] = q.rhire_f1*u[1] - q.rattr_f1*u[1] - q.rprom_f1*u[1] - û[1]
    du[2] = q.rhire_f2*u[2] + q.rprom_f1*u[1] - q.rattr_f2*u[2] - q.rprom_f2*u[2] + û[2] 
    du[3] = q.rhire_f3*u[3] + q.rprom_f2*u[2] - q.rattr_f3*u[3] - û[3]
    du[4] = q.rhire_m1*u[4] - q.rattr_m1*u[4] - q.rprom_m1*u[4] - û[4]
    du[5] = q.rhire_m2*u[5] + q.rprom_m1*u[4] - q.rattr_m2*u[5] - q.rprom_m2*u[5] + û[5]
    du[6] = q.rhire_m3*u[6] + q.rprom_m2*u[5] - q.rattr_m3*u[6] + û[6]
    du
end;

nn_dynamics!(du, u, p, t) = genduniv_ode_ude!(du, u, p, t, theta)

u0 = full_data[:, 1]
tspan = (1.0, 31.0)

prob_nn = ODEProblem(nn_dynamics!, u0, tspan, pinit)

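function predict_uode(params)
    remake(prob_nn, p=params)  # note: remake returns a new problem, and that result is not used below
    Array(solve(prob_nn, Vern7(), saveat = 1.0))  # so this solves prob_nn with its original parameters (pinit)
end;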

function loss_uode(p)
    pred = predict_uode(p)
    loss = sum(abs2, full_data .- pred) #+ Float32(1e-4)*sum(abs2, p)/length(p) # Just sum of squared error
    return loss, pred
end;

callback_display = function (p, l, pred)
    display(l)
    return false  
end

result_ode_uode = DiffEqFlux.sciml_train(loss_uode, pinit, ADAM(0.1), cb = callback_display, maxiters = 100) 
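A possible cause worth checking in the code above: `remake` returns a new problem rather than modifying `prob_nn`, and `predict_uode` never uses that new problem, so `solve` always runs with `pinit`. The loss is then constant in `p`, which Zygote would likely report as a `nothing` gradient, matching the `fill!(..., x::Nothing)` frame. The plain-Julia sketch below illustrates the pattern; `toy_problem`, `with_params`, and `toy_solve` are hypothetical stand-ins, not SciML APIs, and a finite difference replaces AD.

```julia
# Hypothetical stand-in for an ODE problem: an immutable NamedTuple.
const toy_problem = (u0 = 1.0, p = 0.5)

# Non-mutating update (in the spirit of remake): returns a NEW NamedTuple.
with_params(prob, p) = merge(prob, (p = p,))

# Hypothetical "solve": exact solution of du/dt = p*u evaluated at t = 1.
toy_solve(prob) = prob.u0 * exp(prob.p * 1.0)

function predict_discarded(params)
    with_params(toy_problem, params)  # new problem is created and thrown away
    return toy_solve(toy_problem)     # still uses the original p = 0.5
end

function predict_used(params)
    prob = with_params(toy_problem, params)  # keep the updated problem
    return toy_solve(prob)
end

# Central finite difference: does the output change when params changes?
central_diff(f, x; h = 1e-6) = (f(x + h) - f(x - h)) / (2h)

println(central_diff(predict_discarded, 0.3))  # prints 0.0
println(central_diff(predict_used, 0.3))       # close to exp(0.3)
```

In the MWE, the analogous change would be to bind the result, e.g. `prob = remake(prob_nn, p = params)`, and then call `solve(prob, Vern7(), saveat = 1.0)`.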
00krishna commented 2 years ago

I kept working on this problem. Stepping through the error, I thought it might be related to the AD backend, but the error still occurs even with ForwardDiff:

result_ode_uode = DiffEqFlux.sciml_train(loss_uode, 
                                            pinit, 
                                            ADAM(0.1), 
                                            adtype=GalacticOptim.AutoForwardDiff(), 
                                            cb = callback_display, 
                                            maxiters = 100)
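The stack trace in the first post shows `sciml_train(::typeof(loss_uode), ::Vector{Float32}, ::ADAM, ::Nothing; ...)`, i.e. a fourth positional argument that defaults to `nothing`, and the later call that works passes `adtype` in that position. If `adtype` is positional, passing `adtype = ...` as a keyword may just be collected into `kwargs...` and never used, which could explain why the backend setting above had no effect. The toy function `train` below is hypothetical and only mimics that signature shape.

```julia
# Hypothetical function with the same *shape* as the signature in the stack
# trace: adtype is the 4th positional argument, extra keywords are slurped.
function train(loss, θ, opt = :adam, adtype = nothing; maxiters = 100, kwargs...)
    return (adtype = adtype, ignored = keys(kwargs))
end

# Keyword form: `adtype` lands in kwargs, the positional default stays `nothing`.
r1 = train(identity, [1.0f0], :adam; adtype = :forwarddiff, maxiters = 10)

# Positional form: adtype is actually received.
r2 = train(identity, [1.0f0], :adam, :forwarddiff; maxiters = 10)

println(r1.adtype)  # prints nothing
println(r2.adtype)  # prints forwarddiff
```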
00krishna commented 2 years ago

Okay, this problem also seems to happen in the demo hudson_bay.jl code, at the first BFGS optimization call using shooting_loss.

Here is the code that I ran, followed by the error message. It is the same error I received when running the code in the original post.

## Environment and packages
cd(@__DIR__)
using Pkg; Pkg.activate("."); Pkg.instantiate()

using DifferentialEquations
using ModelingToolkit
using DataDrivenDiffEq
using LinearAlgebra, Optim
using DiffEqFlux 
using Flux
using Plots
gr()
using JLD2, FileIO
using Statistics
using DelimitedFiles
# Set a random seed for reproducible behaviour
using Random
Random.seed!(5443)

#### NOTE
# Since the release of DataDrivenDiffEq v0.6.0, which completely overhauled the optimizers,
# STLSQ performs better than the previously used SR3, so the code has been switched to STLSQ.
# Additionally, the behaviour of the optimization has changed slightly; this has been adjusted
# by decreasing the tolerance of the gradient.

svname = "HudsonBay"
## Data Preprocessing
# The data has been taken from https://jmahaffy.sdsu.edu/courses/f00/math122/labs/labj/q3v1.htm
# Originally published in E. P. Odum (1953), Fundamentals of Ecology, Philadelphia, W. B. Saunders
hudson_bay_data = readdlm("hudson_bay_data.dat", '\t', Float32, '\n')
# Measurements of prey and predator
Xₙ = Matrix(transpose(hudson_bay_data[:, 2:3]))
t = hudson_bay_data[:, 1] .- hudson_bay_data[1, 1]
# Normalize the data; since the data domain is strictly positive
# we just need to divide by the maximum
xscale = maximum(Xₙ, dims =2)
Xₙ .= 1f0 ./ xscale .* Xₙ
# Time from 0 -> n
tspan = (t[1], t[end])

# Plot the data
scatter(t, transpose(Xₙ), xlabel = "t", ylabel = "x(t), y(t)")
plot!(t, transpose(Xₙ), xlabel = "t", ylabel = "x(t), y(t)")

## Direct Identification via SINDy + Collocation

# Create the problem using a gaussian kernel for collocation
full_problem = ContinuousDataDrivenProblem(Xₙ, t, DataDrivenDiffEq.GaussianKernel())
# Look at the collocation
plot(full_problem.t, full_problem.X')
plot(full_problem.t, full_problem.DX')

# Create a Basis
@variables u[1:2]

# Generate the basis functions, multivariate polynomials up to deg 5
# and sine
b = [polynomial_basis(u, 5); sin.(u)]
basis = Basis(b, u)

# Create the thresholds which should be used in the search process
λ = Float32.(exp10.(-7:0.1:5))
# Create an optimizer for the SINDy problem
opt = STLSQ(λ)

# Best result so far
full_res = solve(full_problem, basis, opt, maxiter = 10000, progress = true, denoise = true, normalize = true)

println(full_res)
println(result(full_res))

## Define the network
# Gaussian RBF as activation
rbf(x) = exp.(-(x.^2))

# Define the network 2->5->5->5->2
U = FastChain(
    FastDense(2,5,rbf), FastDense(5,5, rbf), FastDense(5,5, tanh), FastDense(5,2)
)

# Get the initial parameters; the first two are the linear birth / decay rates of prey and predator
p = [rand(Float32,2); initial_params(U)]

# Define the hybrid model
function ude_dynamics!(du,u, p, t)
    û = U(u, p[3:end]) # Network prediction
    # We assume a linear birth rate for the prey
    du[1] = p[1]*u[1] + û[1]
    # We assume a linear decay rate for the predator
    du[2] = -p[2]*u[2] + û[2]
end

# Define the problem
prob_nn = ODEProblem(ude_dynamics!,Xₙ[:, 1], tspan, p)

## Function to train the network
# Define a predictor
function predict(θ, X = Xₙ[:,1], T = t)
    Array(solve(prob_nn, Vern7(), u0 = X, p=θ,
                tspan = (T[1], T[end]), saveat = T,
                abstol=1e-6, reltol=1e-6,
                sensealg = ForwardDiffSensitivity()
                ))
end

# Define parameters for Multiple Shooting
group_size = 5
continuity_term = 200.0f0

function loss(data, pred)
    return sum(abs2, data - pred)
end

function shooting_loss(p)
    return multiple_shoot(p, Xₙ, t, prob_nn, loss, Vern7(),
                          group_size; continuity_term)
end

function loss(θ)
    X̂ = predict(θ)
    sum(abs2, Xₙ - X̂) / size(Xₙ, 2) + convert(eltype(θ), 1e-3)*sum(abs2, θ[3:end]) ./ length(θ[3:end])
end

# Container to track the losses
losses = Float32[]

# Callback to show the loss during training
callback(θ,args...) = begin
    l = loss(θ) # Equivalent L2 loss
    push!(losses, l)
    if length(losses)%5==0
        println("Current loss after $(length(losses)) iterations: $(losses[end])")
    end
    false
end

## Training -> First shooting / batching to get a rough estimate

# First train with ADAM for better convergence -> move the parameters into a
# favourable starting position for BFGS
res1 = DiffEqFlux.sciml_train(shooting_loss, p, ADAM(0.1f0), cb=callback, maxiters = 100)
println("Training loss after $(length(losses)) iterations: $(losses[end])")
# Train with BFGS to achieve partial fit of the data
res2 = DiffEqFlux.sciml_train(shooting_loss, res1.minimizer, BFGS(initial_stepnorm=0.01f0), cb=callback, maxiters = 500)  # <- THIS IS WHERE THE ERROR OCCURS
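For context on `group_size` and `continuity_term`: multiple shooting splits the time series into short overlapping segments, integrates each segment separately, and adds `continuity_term` times a penalty on the mismatch where consecutive segments meet. The sketch below only shows one way to form such overlapping index ranges, assuming consecutive groups share their boundary point; `overlapping_groups` and `continuity_penalty` are hypothetical helpers, and the exact grouping and penalty used by DiffEqFlux's `multiple_shoot` should be checked against its documentation.

```julia
# Hypothetical helper: split 1:n into windows of `group_size` points where each
# window starts at the last index of the previous one (neighbours overlap by one).
function overlapping_groups(n::Integer, group_size::Integer)
    group_size >= 2 || throw(ArgumentError("group_size must be at least 2"))
    return [i:min(n, i + group_size - 1) for i in 1:(group_size - 1):(n - 1)]
end

# Toy continuity penalty between the end of one segment and the start of the next.
continuity_penalty(seg_end, next_start) = sum(abs, seg_end .- next_start)

groups = overlapping_groups(10, 5)
println(groups)

# Weighted mismatch at one segment boundary, with the weight used in the script.
boundary_cost = 200.0f0 * continuity_penalty([1.0, 2.0], [1.5, 2.0])
```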

Here is the error message.

ERROR: MethodError: Cannot `convert` an object of type Nothing to an object of type Float32
Closest candidates are:
  convert(::Type{T}, ::Static.StaticFloat64) where T<:AbstractFloat at ~/.julia/packages/Static/R0QTo/src/float.jl:22
  convert(::Type{T}, ::LLVM.GenericValue, ::LLVM.LLVMType) where T<:AbstractFloat at ~/.julia/packages/LLVM/tVv0H/src/execution.jl:39
  convert(::Type{T}, ::LLVM.ConstantFP) where T<:AbstractFloat at ~/.julia/packages/LLVM/tVv0H/src/core/value/constant.jl:103
  ...
Stacktrace:
  [1] fill!(dest::Vector{Float32}, x::Nothing)
    @ Base ./array.jl:351
  [2] copyto!
    @ ./broadcast.jl:921 [inlined]
  [3] materialize!
    @ ./broadcast.jl:871 [inlined]
  [4] materialize!(dest::Vector{Float32}, bc::Base.Broadcast.Broad)
    @ Base.Broadcast ./broadcast.jl:868
  [5] (::GalacticOptim.var"#268#278"{GalacticOptim.var"#267#277"{OptimizationFunction{true, GalacticOptim.AutoZygote, DiffEqFlux.var"#84#89"{typeof(shooting_loss)}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, Nothing}})(::Vector{Float32}, ::Vector{Float32})
    @ GalacticOptim ~/.julia/packages/GalacticOptim/fow0r/src/function/zygote.jl:8
  [6] (::GalacticOptim.var"#144#152"{OptimizationProblem{false, OptimizationFunction{false, GalacticOptim.AutoZygote, OptimizationFunction{true, GalacticOptim.AutoZygote, DiffEqFlux.var"#84#89"{typeof(shooting_loss)}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, GalacticOptim.var"#268#278"{GalacticOptim.var"#267#277"{OptimizationFunction{true, GalacticOptim.AutoZygote, DiffEqFlux.var"#84#89"{typeof(shooting_loss)}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, Nothing}}, GalacticOptim.var"#271#281"{GalacticOptim.var"#267#277"{OptimizationFunction{true, GalacticOptim.AutoZygote, DiffEqFlux.var"#84#89"{typeof(shooting_loss)}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, Nothing}}, GalacticOptim.var"#276#286", Nothing, Nothing, Nothing}, Vector{Float32}, SciMLBase.NullParameters, Nothing, Nothing, Nothing, Nothing, Base.Pairs{Symbol, typeof(callback), Tuple{Symbol}, NamedTuple{(:cb,), Tuple{typeof(callback)}}}}, GalacticOptim.var"#143#151"{OptimizationProblem{false, OptimizationFunction{false, GalacticOptim.AutoZygote, OptimizationFunction{true, GalacticOptim.AutoZygote, DiffEqFlux.var"#84#89"{typeof(shooting_loss)}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, GalacticOptim.var"#268#278"{GalacticOptim.var"#267#277"{OptimizationFunction{true, GalacticOptim.AutoZygote, DiffEqFlux.var"#84#89"{typeof(shooting_loss)}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, Nothing}}, GalacticOptim.var"#271#281"{GalacticOptim.var"#267#277"{OptimizationFunction{true, GalacticOptim.AutoZygote, DiffEqFlux.var"#84#89"{typeof(shooting_loss)}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, Nothing}}, GalacticOptim.var"#276#286", Nothing, Nothing, Nothing}, Vector{Float32}, SciMLBase.NullParameters, Nothing, Nothing, Nothing, Nothing, Base.Pairs{Symbol, typeof(callback), Tuple{Symbol}, NamedTuple{(:cb,), Tuple{typeof(callback)}}}}, OptimizationFunction{false, GalacticOptim.AutoZygote, OptimizationFunction{false, 
GalacticOptim.AutoZygote, OptimizationFunction{true, GalacticOptim.AutoZygote, DiffEqFlux.var"#84#89"{typeof(shooting_loss)}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, GalacticOptim.var"#268#278"{GalacticOptim.var"#267#277"{OptimizationFunction{true, GalacticOptim.AutoZygote, DiffEqFlux.var"#84#89"{typeof(shooting_loss)}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, Nothing}}, GalacticOptim.var"#271#281"{GalacticOptim.var"#267#277"{OptimizationFunction{true, GalacticOptim.AutoZygote, DiffEqFlux.var"#84#89"{typeof(shooting_loss)}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, Nothing}}, GalacticOptim.var"#276#286", Nothing, Nothing, Nothing}, GalacticOptim.var"#268#278"{GalacticOptim.var"#267#277"{OptimizationFunction{true, GalacticOptim.AutoZygote, DiffEqFlux.var"#84#89"{typeof(shooting_loss)}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, Nothing}}, GalacticOptim.var"#271#281"{GalacticOptim.var"#267#277"{OptimizationFunction{true, GalacticOptim.AutoZygote, DiffEqFlux.var"#84#89"{typeof(shooting_loss)}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, Nothing}}, GalacticOptim.var"#276#286", Nothing, Nothing, Nothing}}, OptimizationFunction{false, GalacticOptim.AutoZygote, OptimizationFunction{false, GalacticOptim.AutoZygote, OptimizationFunction{true, GalacticOptim.AutoZygote, DiffEqFlux.var"#84#89"{typeof(shooting_loss)}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, GalacticOptim.var"#268#278"{GalacticOptim.var"#267#277"{OptimizationFunction{true, GalacticOptim.AutoZygote, DiffEqFlux.var"#84#89"{typeof(shooting_loss)}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, Nothing}}, GalacticOptim.var"#271#281"{GalacticOptim.var"#267#277"{OptimizationFunction{true, GalacticOptim.AutoZygote, DiffEqFlux.var"#84#89"{typeof(shooting_loss)}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, Nothing}}, GalacticOptim.var"#276#286", Nothing, Nothing, Nothing}, 
GalacticOptim.var"#268#278"{GalacticOptim.var"#267#277"{OptimizationFunction{true, GalacticOptim.AutoZygote, DiffEqFlux.var"#84#89"{typeof(shooting_loss)}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, Nothing}}, GalacticOptim.var"#271#281"{GalacticOptim.var"#267#277"{OptimizationFunction{true, GalacticOptim.AutoZygote, DiffEqFlux.var"#84#89"{typeof(shooting_loss)}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, Nothing}}, GalacticOptim.var"#276#286", Nothing, Nothing, Nothing}})(G::Vector{Float32}, θ::Vector{Float32})
    @ GalacticOptim ~/.julia/packages/GalacticOptim/fow0r/src/solve/optim.jl:93
  [7] value_gradient!!(obj::TwiceDifferentiable{, x::Vector{Float32})
    @ NLSolversBase ~/.julia/packages/NLSolversBase/cfJrN/src/interface.jl:82
  [8] value_gradient!(obj::TwiceDifferentiable{, x::Vector{Float32})
    @ NLSolversBase ~/.julia/packages/NLSolversBase/cfJrN/src/interface.jl:69
  [9] value_gradient!(obj::Optim.ManifoldObject, x::Vector{Float32})
    @ Optim ~/.julia/packages/Optim/wFOeG/src/Manifolds.jl:50
 [10] (::LineSearches.var"#ϕdϕ#6"{Optim.ManifoldObjective{TwiceDifferentiable{Float32, Vector{Float32}, Matrix{Float32}, Vector{Float32}}}, Vector{Float32}, Vector{Float32}, Vector{Float32}})(α::Float32)
    @ LineSearches ~/.julia/packages/LineSearches/Ki4c5/src/LineSearches.jl:84
 [11] (::LineSearches.HagerZhang{Float64, Base.RefValue{Bool}})(ϕ::Function, ϕdϕ::LineSearches.var"#ϕd, c::Float32, phi_0::Float32, dphi_0::Float32)
    @ LineSearches ~/.julia/packages/LineSearches/Ki4c5/src/hagerzhang.jl:139
 [12] HagerZhang
    @ ~/.julia/packages/LineSearches/Ki4c5/src/hagerzhang.jl:101 [inlined]
 [13] perform_linesearch!(state::Optim.BFGSState{Vect, method::BFGS{LineSearches.In, d::Optim.ManifoldObject)
    @ Optim ~/.julia/packages/Optim/wFOeG/src/utilities/perform_linesearch.jl:59
 [14] update_state!(d::TwiceDifferentiable{, state::Optim.BFGSState{Vect, method::BFGS{LineSearches.In)
    @ Optim ~/.julia/packages/Optim/wFOeG/src/multivariate/solvers/first_order/bfgs.jl:139
 [15] optimize(d::TwiceDifferentiable{, initial_x::Vector{Float32}, method::BFGS{LineSearches.In, options::Optim.Options{Float6, state::Optim.BFGSState{Vect)
    @ Optim ~/.julia/packages/Optim/wFOeG/src/multivariate/optimize/optimize.jl:54
 [16] optimize(d::TwiceDifferentiable{, initial_x::Vector{Float32}, method::BFGS{LineSearches.In, options::Optim.Options{Float6)
    @ Optim ~/.julia/packages/Optim/wFOeG/src/multivariate/optimize/optimize.jl:36
 [17] ___solve(prob::OptimizationProblem{, opt::BFGS{LineSearches.In, data::Base.Iterators.Cycle; cb::Function, maxiters::Int64, maxtime::Nothing, abstol::Nothing, reltol::Nothing, progress::Bool, kwargs::Base.Pairs{Symbol, U)
    @ GalacticOptim ~/.julia/packages/GalacticOptim/fow0r/src/solve/optim.jl:129
 [18] #__solve#141
    @ ~/.julia/packages/GalacticOptim/fow0r/src/solve/optim.jl:49 [inlined]
 [19] #solve#482
    @ ~/.julia/packages/SciMLBase/nbKmA/src/solve.jl:3 [inlined]
 [20] sciml_train(::typeof(shooting_loss, ::Vector{Float32}, ::BFGS{LineSearches.In, ::Nothing; lower_bounds::Nothing, upper_bounds::Nothing, maxiters::Int64, kwargs::Base.Pairs{Symbol, t)
    @ DiffEqFlux ~/.julia/packages/DiffEqFlux/gH716/src/train.jl:91
 [21] top-level scope
    @ ~/Dropbox/sandbox/julia/universal_differential_equations/LotkaVolterra/hudson_bay.jl:146
00krishna commented 2 years ago

@AlCap23, hopefully you see this mention. This is the issue I mentioned to Chris. For some reason the hudson_bay.jl optimization fails on the first BFGS run. The error message above says Cannot convert an object of type Nothing to an object of type Float32, but I think the real issue is something else. I am wondering whether it has to do with using ForwardDiffSensitivity() as the sensealg, since both my own code and the hudson_bay code did that. We can chat about it if you like.

00krishna commented 2 years ago

The problem seems to be the choice of AD backend. After I manually set the AD backend to AutoForwardDiff, it worked. With the default AutoZygote, or even AutoTracker, the code failed with the error message above.

Here is the updated code.

using GalacticOptim

adtype = GalacticOptim.AutoForwardDiff()
res1 = DiffEqFlux.sciml_train(shooting_loss, p, ADAM(0.1f0), adtype, cb=callback, maxiters = 100)
println("Training loss after $(length(losses)) iterations: $(losses[end])")
# Train with BFGS to achieve partial fit of the data
res2 = DiffEqFlux.sciml_train(shooting_loss, res1.minimizer, BFGS(initial_stepnorm=0.005f0), adtype, cb=callback, maxiters = 500)
println("Training loss after $(length(losses)) iterations: $(losses[end])")
# Full L2-Loss for full prediction
res3 = DiffEqFlux.sciml_train(loss, res2.minimizer, BFGS(initial_stepnorm=0.01f0), cb=callback, maxiters = 10000)
println("Final training loss after $(length(losses)) iterations: $(losses[end])")

This is a change from the original code. I am waiting to hear back from Chris and Julius on whether they want me to just adjust the UODE repo code, or to dig deeper into what causes the problem with Zygote and Tracker, ...
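One possible reading of why AutoForwardDiff runs while AutoZygote fails: forward-mode AD typically yields a numeric (possibly all-zero) gradient even when the loss does not depend on some parameters, whereas Zygote signals "no dependence" with `nothing`, which then cannot be written into the Float32 buffer. If that is what happens here, switching to ForwardDiff would hide the symptom rather than fix the cause. The Base-only guard below is a hypothetical sketch that makes that case visible; it is not GalacticOptim code.

```julia
# Hypothetical helper: copy an AD result into a preallocated buffer, handling
# `nothing` ("the loss does not depend on these parameters") explicitly.
function store_gradient!(G::AbstractVector, g; warn_on_nothing::Bool = true)
    if g === nothing
        warn_on_nothing && @warn "gradient is `nothing`: loss does not depend on the parameters"
        return fill!(G, zero(eltype(G)))
    end
    G .= g  # elements are converted to eltype(G) on assignment
    return G
end

G = ones(Float32, 3)
store_gradient!(G, [0.5, 1.0, 2.0])                   # numeric gradient, stored as Float32
store_gradient!(G, nothing; warn_on_nothing = false)  # explicit zeros instead of a MethodError
```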

ChrisRackauckas commented 2 years ago

What's the reason for mixing Float64 and Float32? What happens if you make it all the same type?
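For reference on this question: in the first post, `initial_params(U)` yields a Float32 vector (the trace shows `::Vector{Float32}` as the second argument), while `theta` is a `ComponentVector{Float64}` and `u0` comes from Float64 data, so arithmetic between them is promoted to Float64. The Base-only snippet below shows that promotion and two ways to settle on a single type up front; the variable names and values are illustrative.

```julia
p32 = Float32[0.1, 0.2, 0.3]   # like the Float32 network parameters
θ64 = [0.042, -0.13, 0.01]     # like the Float64 physical parameters
u64 = [3.3, 2.4, 3.8]          # like the Float64 initial state

# Mixing element types promotes to the wider one.
mixed = p32 .* u64
println(eltype(mixed))  # prints Float64

# Option A: convert everything to Float32 before building the problem.
u32 = Float32.(u64)
println(eltype(p32 .* u32))  # prints Float32

# Option B: keep everything in Float64 instead.
p64 = Float64.(p32)
println(eltype(p64 .* θ64))  # prints Float64
```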

00krishna commented 2 years ago

@ChrisRackauckas, @AlCap23 and I have been looking at this. So far the error seems to be coming from GalacticOptim, and potentially from these functions:

https://github.com/SciML/SciMLBase.jl/blob/181f0eb96b40e98eff2bae1268248f9e7f5460cc/src/problems/basic_problems.jl#L113

Julius thinks that one of the dispatches is not working as expected.

I can try explicitly setting everything to Float32 or Float64 and see what happens. We are continuing to trace the problem, and I will update the issue as we figure it out.

ChrisRackauckas commented 2 years ago

Okay, I'm travelling and a bit behind, so I'm going to archive this assuming @AlCap23 has it under control; if that's not the case, just ping me.

ChrisRackauckas commented 2 years ago

What's the current status here?

00krishna commented 2 years ago

We have not found the root of the problem yet. We think it may have something to do with the order in which the AD rules are applied, or with the type that is being identified for dispatch. I was wondering whether @avik-pal might help us figure this out, as he seems to have encountered a similar error last week, so I figured I would ping him.