SciML / NeuralPDE.jl

Physics-Informed Neural Networks (PINN) Solvers of (Partial) Differential Equations for Scientific Machine Learning (SciML) accelerated simulation
https://docs.sciml.ai/NeuralPDE/stable/

NNODE training fails with `autodiff=true` #725

Open sathvikbhagavan opened 11 months ago

sathvikbhagavan commented 11 months ago

MWE (running one of the tests in `NNODE_tests.jl`):

```julia
using Flux
using Random, NeuralPDE
using OrdinaryDiffEq, Statistics
import OptimizationOptimisers

Random.seed!(100)

# Run a solve on scalars
linear = (u, p, t) -> cos(2pi * t)
tspan = (0.0f0, 1.0f0)
u0 = 0.0f0
prob = ODEProblem(linear, u0, tspan)
chain = Flux.Chain(Dense(1, 5, σ), Dense(5, 1))
opt = OptimizationOptimisers.Adam(0.1, (0.9, 0.95))
```

This works:

```julia
sol = solve(prob, NeuralPDE.NNODE(chain, opt), dt = 1 / 20.0f0, verbose = true,
            abstol = 1.0f-10, maxiters = 200)
```

This errors out:

```julia
sol = solve(prob, NeuralPDE.NNODE(chain, opt; autodiff=true), dt = 1 / 20.0f0, verbose = true,
            abstol = 1.0f-10, maxiters = 200)
```

Stacktrace:

```julia
julia> sol = solve(prob, NeuralPDE.NNODE(chain, opt; autodiff=true), dt = 1 / 20.0f0, verbose = true, abstol = 1.0f-10, maxiters = 200)
WARNING: both DomainSets and SciMLBase export "islinear"; uses of it in module NeuralPDE must be qualified
WARNING: both DomainSets and SciMLBase export "isconstant"; uses of it in module NeuralPDE must be qualified
WARNING: both DomainSets and SciMLBase export "issquare"; uses of it in module NeuralPDE must be qualified
┌ Warning: `ForwardDiff.jacobian(f, x)` within Zygote cannot track gradients with respect to `f`,
│ and `f` appears to be a closure, or a struct with fields (according to `issingletontype(typeof(f))`).
│ typeof(f) = NeuralPDE.var"#150#151"{NeuralPDE.ODEPhi{Optimisers.Restructure{Chain{Tuple{Dense{typeof(σ), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, NamedTuple{(:layers,), Tuple{Tuple{NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}, NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}}}}}, Float32, Float32, Nothing}, Vector{Float32}}
└ @ Zygote ~/.julia/packages/Zygote/4rucm/src/lib/forward.jl:150
ERROR: MethodError: Cannot `convert` an object of type Nothing to an object of type Float32

Closest candidates are:
  convert(::Type{T}, ::Unitful.Gain) where T<:Real
   @ Unitful ~/.julia/packages/Unitful/PMWWU/src/logarithm.jl:62
  convert(::Type{T}, ::Unitful.Level) where T<:Real
   @ Unitful ~/.julia/packages/Unitful/PMWWU/src/logarithm.jl:22
  convert(::Type{T}, ::Unitful.Quantity) where T<:Real
   @ Unitful ~/.julia/packages/Unitful/PMWWU/src/conversion.jl:139
  ...

Stacktrace:
  [1] fill!(dest::Vector{Float32}, x::Nothing)
    @ Base ./array.jl:347
  [2] copyto!
    @ ./broadcast.jl:934 [inlined]
  [3] materialize!
    @ ./broadcast.jl:884 [inlined]
  [4] materialize!(dest::Vector{Float32}, bc::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{0}, Nothing, typeof(identity), Tuple{Base.RefValue{Nothing}}})
    @ Base.Broadcast ./broadcast.jl:881
  [5] (::OptimizationZygoteExt.var"#20#29"{…})(::Vector{Float32}, ::Vector{Float32})
    @ OptimizationZygoteExt ~/.julia/packages/Optimization/72eCu/ext/OptimizationZygoteExt.jl:56
  [6] macro expansion
    @ ~/.julia/packages/OptimizationOptimisers/wD0eI/src/OptimizationOptimisers.jl:65 [inlined]
  [7] macro expansion
    @ ~/.julia/packages/Optimization/72eCu/src/utils.jl:37 [inlined]
  [8] __solve(cache::Optimization.OptimizationCache{…})
    @ OptimizationOptimisers ~/.julia/packages/OptimizationOptimisers/wD0eI/src/OptimizationOptimisers.jl:63
  [9] solve!(cache::Optimization.OptimizationCache{…})
    @ SciMLBase ~/.julia/packages/SciMLBase/kTUaf/src/solve.jl:162
 [10] solve(::OptimizationProblem{…}, ::Optimisers.Adam{Float64}; kwargs::Base.Pairs{Symbol, Any, Tuple{Symbol, Symbol}, NamedTuple{(:callback, :maxiters), Tuple{NeuralPDE.var"#176#180"{Float32, Bool}, Int64}}})
    @ SciMLBase ~/.julia/packages/SciMLBase/kTUaf/src/solve.jl:83
 [11] __solve(::ODEProblem{…}, ::NNODE{…}; dt::Float32, timeseries_errors::Bool, save_everystep::Bool, adaptive::Bool, abstol::Float32, reltol::Float32, verbose::Bool, saveat::Nothing, maxiters::Int64)
    @ NeuralPDE ~/NeuralPDE.jl/src/ode_solve.jl:455
 [12] __solve
    @ ~/NeuralPDE.jl/src/ode_solve.jl:356 [inlined]
 [13] #solve_call#33
    @ ~/.julia/packages/DiffEqBase/DEv7n/src/solve.jl:511 [inlined]
 [14] solve_call
    @ ~/.julia/packages/DiffEqBase/DEv7n/src/solve.jl:481 [inlined]
 [15] solve_up(prob::ODEProblem{…}, sensealg::Nothing, u0::Float32, p::SciMLBase.NullParameters, args::NNODE{…}; kwargs::Base.Pairs{Symbol, Real, NTuple{4, Symbol}, NamedTuple{(:dt, :verbose, :abstol, :maxiters), Tuple{Float32, Bool, Float32, Int64}}})
    @ DiffEqBase ~/.julia/packages/DiffEqBase/DEv7n/src/solve.jl:972
 [16] solve_up
    @ ~/.julia/packages/DiffEqBase/DEv7n/src/solve.jl:945 [inlined]
 [17] #solve#39
    @ ~/.julia/packages/DiffEqBase/DEv7n/src/solve.jl:882 [inlined]
 [18] top-level scope
    @ REPL[13]:1
```
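The Zygote warning is probably the key clue. As I read Zygote's rule for `ForwardDiff.jacobian` (the `forward.jl:150` location in the warning), it only returns a gradient for the input `x` and not for the closure `f`. Any parameters captured by the closure would then get no gradient at all. A `nothing` gradient would also explain the `fill!(dest::Vector{Float32}, x::Nothing)` frame. Below is a minimal sketch of the same pattern without NeuralPDE. `phi` is a hypothetical two-parameter stand-in for the network, and the collocation point is arbitrary:

```julia
using Zygote, ForwardDiff

# Hypothetical stand-in for the network: u(t; θ) = θ[1] * sin(θ[2] * t)
phi(t, θ) = θ[1] * sin(θ[2] * t)

# Squared ODE residual du/dt - cos(2πt) at one collocation point.
# du/dt comes from ForwardDiff through a closure that captures θ,
# which is the pattern the warning complains about.
function residual_loss(θ)
    t = [0.25]
    dudt = ForwardDiff.jacobian(x -> [phi(x[1], θ)], t)[1]
    return abs2(dudt - cos(2pi * t[1]))
end

θ0 = [1.0, 2.0]
g = Zygote.gradient(residual_loss, θ0)  # emits the same closure warning
```

If that reading is right, `g[1]` comes back as `nothing` even though the loss clearly depends on `θ`.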
sathvikbhagavan commented 11 months ago

@ChrisRackauckas, is this a known issue?

ChrisRackauckas commented 11 months ago

It wasn't, but now it is.

sathvikbhagavan commented 6 months ago

Updated MWE:

```julia
using Flux
using Random, NeuralPDE
using OrdinaryDiffEq, Statistics
import OptimizationOptimisers

Random.seed!(100)

# Run a solve on scalars
linear = (u, p, t) -> cos(2pi * t)
tspan = (0.0f0, 1.0f0)
u0 = 0.0f0
prob = ODEProblem(linear, u0, tspan)
chain = Flux.Chain(Dense(1, 5, σ), Dense(5, 1))
opt = OptimizationOptimisers.Adam(0.1, (0.9, 0.95))

sol = solve(prob, NeuralPDE.NNODE(chain, opt; autodiff=true), dt = 1 / 20.0f0, verbose = true,
            abstol = 1.0f-10, maxiters = 200)
```

no longer hits the same error, because https://github.com/SciML/NeuralPDE.jl/pull/783 added a check that throws an error whenever `autodiff = true`.

Removing the check gives me this:

```julia
julia> sol = solve(prob, NeuralPDE.NNODE(chain, opt; autodiff=true), dt = 1 / 20.0f0, verbose = true,
                   abstol = 1.0f-10, maxiters = 200)
┌ Warning: `ForwardDiff.jacobian(f, x)` within Zygote cannot track gradients with respect to `f`,
│ and `f` appears to be a closure, or a struct with fields (according to `issingletontype(typeof(f))`).
│ typeof(f) = NeuralPDE.var"#163#164"{NeuralPDE.ODEPhi{Optimisers.Restructure{Chain{Tuple{Dense{typeof(σ), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, @NamedTuple{layers::Tuple{@NamedTuple{weight::Int64, bias::Int64, σ::Tuple{}}, @NamedTuple{weight::Int64, bias::Int64, σ::Tuple{}}}}}, Float32, Float32, Nothing}, Vector{Float32}}
└ @ Zygote ~/.julia/packages/Zygote/WOy6z/src/lib/forward.jl:150
Current loss is: 121.3777458193083, Iteration: 1
Current loss is: 121.3777458193083, Iteration: 2
Current loss is: 121.3777458193083, Iteration: 3
⋮
Current loss is: 121.3777458193083, Iteration: 200
Current loss is: 121.3777458193083, Iteration: 201
retcode: Success
Interpolation: Trained neural network interpolation
t: 0.0f0:0.05f0:1.0f0
u: 21-element Vector{Float32}:
  0.0
  0.006315714
  0.011539376
  0.015674114
  0.018725103
  0.020699587
  0.021606745
  0.021457678
  ⋮
 -0.007863174
 -0.015861165
 -0.024744594
 -0.034488622
 -0.04506795
 -0.05645652
 -0.06862798
```

The loss stays constant across all iterations, so the NNODE is not being trained.
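One way to check that the nested ForwardDiff call, and not the optimiser, is what stalls training is to compute du/dt for a toy residual with a forward finite difference instead. My assumption is that this is roughly what the default `autodiff = false` path does, which would fit the first MWE training fine. Zygote then sees `θ` directly and should return a real gradient. `phi`, `fd_dudt` and `residual_loss` are hypothetical names for illustration:

```julia
using Zygote

# Hypothetical stand-in for the network: u(t; θ) = θ[1] * sin(θ[2] * t)
phi(t, θ) = θ[1] * sin(θ[2] * t)

# Forward finite difference for du/dt with step sqrt(eps), so no ForwardDiff is involved
function fd_dudt(t, θ)
    h = sqrt(eps(typeof(t)))
    return (phi(t + h, θ) - phi(t, θ)) / h
end

# Mean squared residual of du/dt = cos(2πt) on the same dt = 1/20 grid as the MWE
const ts = collect(0.0:0.05:1.0)
residual_loss(θ) = sum(abs2, [fd_dudt(t, θ) - cos(2pi * t) for t in ts]) / length(ts)

θ0 = [1.0, 2.0]
g = Zygote.gradient(residual_loss, θ0)[1]  # a Vector{Float64}, not nothing
```

The finite-difference step trades some accuracy for a loss that Zygote can differentiate end to end. This does not fix `autodiff = true`, but it isolates the problem to the Zygote-over-ForwardDiff nesting.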

ChrisRackauckas commented 6 months ago

The Flux type conversion drops the Dual numbers, so that's the first thing to remove. Let's start by moving everything to Lux, clean up and delete code, and then isolate the problem.
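The Lux half of this suggestion can be sanity-checked in isolation. With Lux and `ComponentArray` parameters, a Dual number for `t` should survive the forward pass instead of being converted away. This is a minimal sketch, assuming the current Lux (`Lux.setup`, `Dense(in => out, act)`) and ComponentArrays APIs. `tanh` replaces `σ` to avoid an extra import, and the names `model`, `u` and `dudt` are illustrative. It only covers the inner derivative in `t`; by itself it does not change how Zygote treats the closure in the outer gradient.

```julia
using Lux, ComponentArrays, ForwardDiff, Random

rng = Random.MersenneTwister(100)
# Same 1 -> 5 -> 1 architecture as the MWE, written with Lux (tanh instead of σ)
model = Lux.Chain(Lux.Dense(1 => 5, tanh), Lux.Dense(5 => 1))
ps, st = Lux.setup(rng, model)
θ = ComponentArray(ps)

# Network output at a scalar time t; the model returns (output, state)
u(t) = first(first(model([t], θ, st)))

# ForwardDiff pushes a Dual through the Float32 weights; no conversion drops it
dudt = ForwardDiff.derivative(u, 0.5f0)
```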

sathvikbhagavan commented 5 months ago

Now that Flux has been removed (https://github.com/SciML/NeuralPDE.jl/pull/789), I revisited this to see what was happening.

With Optimization@3.20,

julia> sol = solve(prob, NeuralPDE.NNODE(luxchain, opt, autodiff = true), dt = 1 / 20.0f0, verbose = true,
                   abstol = 1.0f-10, maxiters = 200)
┌ Warning: `ForwardDiff.jacobian(f, x)` within Zygote cannot track gradients with respect to `f`,
│ and `f` appears to be a closure, or a struct with fields (according to `issingletontype(typeof(f))`).
│ typeof(f) = NeuralPDE.var"#163#164"{NeuralPDE.ODEPhi{Lux.Chain{@NamedTuple{layer_1::Lux.Dense{true, typeof(sigmoid_fast), typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32)}, layer_2::Lux.Dense{true, typeof(identity), typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32)}}, Nothing}, Float32, Float32, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}}}, ComponentArrays.ComponentVector{Float32, Vector{Float32}, Tuple{ComponentArrays.Axis{(layer_1 = ViewAxis(1:10, Axis(weight = ViewAxis(1:5, ShapedAxis((5, 1), NamedTuple())), bias = ViewAxis(6:10, ShapedAxis((5, 1), NamedTuple())))), layer_2 = ViewAxis(11:16, Axis(weight = ViewAxis(1:5, ShapedAxis((1, 5), NamedTuple())), bias = ViewAxis(6:6, ShapedAxis((1, 1), NamedTuple())))))}}}}
└ @ Zygote ~/.julia/packages/Zygote/jxHJc/src/lib/forward.jl:150
Current loss is: 133.45653590504156, Iteration: 1
Current loss is: 133.45653590504156, Iteration: 2
Current loss is: 133.45653590504156, Iteration: 3
⋮
Current loss is: 133.45653590504156, Iteration: 70
Current loss is: 133.45653590504156, Iteration: 71
Current loss is: 133.45653590504156, Iteration: 72
Current loss is: 133.45653590504156, Iteration: 73
Current loss is: 133.45653590504156, Iteration: 74
Current loss is: 133.45653590504156, Iteration: 75
Current loss is: 133.45653590504156, Iteration: 76
Current loss is: 133.45653590504156, Iteration: 77
Current loss is: 133.45653590504156, Iteration: 78
Current loss is: 133.45653590504156, Iteration: 79
Current loss is: 133.45653590504156, Iteration: 80
Current loss is: 133.45653590504156, Iteration: 81
Current loss is: 133.45653590504156, Iteration: 82
Current loss is: 133.45653590504156, Iteration: 83
Current loss is: 133.45653590504156, Iteration: 84
Current loss is: 133.45653590504156, Iteration: 85
Current loss is: 133.45653590504156, Iteration: 86
Current loss is: 133.45653590504156, Iteration: 87
Current loss is: 133.45653590504156, Iteration: 88
Current loss is: 133.45653590504156, Iteration: 89
Current loss is: 133.45653590504156, Iteration: 90
Current loss is: 133.45653590504156, Iteration: 91
Current loss is: 133.45653590504156, Iteration: 92
Current loss is: 133.45653590504156, Iteration: 93
Current loss is: 133.45653590504156, Iteration: 94
Current loss is: 133.45653590504156, Iteration: 95
Current loss is: 133.45653590504156, Iteration: 96
Current loss is: 133.45653590504156, Iteration: 97
Current loss is: 133.45653590504156, Iteration: 98
Current loss is: 133.45653590504156, Iteration: 99
Current loss is: 133.45653590504156, Iteration: 100
Current loss is: 133.45653590504156, Iteration: 101
Current loss is: 133.45653590504156, Iteration: 102
Current loss is: 133.45653590504156, Iteration: 103
Current loss is: 133.45653590504156, Iteration: 104
Current loss is: 133.45653590504156, Iteration: 105
Current loss is: 133.45653590504156, Iteration: 106
Current loss is: 133.45653590504156, Iteration: 107
Current loss is: 133.45653590504156, Iteration: 108
Current loss is: 133.45653590504156, Iteration: 109
Current loss is: 133.45653590504156, Iteration: 110
Current loss is: 133.45653590504156, Iteration: 111
Current loss is: 133.45653590504156, Iteration: 112
Current loss is: 133.45653590504156, Iteration: 113
Current loss is: 133.45653590504156, Iteration: 114
Current loss is: 133.45653590504156, Iteration: 115
Current loss is: 133.45653590504156, Iteration: 116
Current loss is: 133.45653590504156, Iteration: 117
Current loss is: 133.45653590504156, Iteration: 118
Current loss is: 133.45653590504156, Iteration: 119
Current loss is: 133.45653590504156, Iteration: 120
Current loss is: 133.45653590504156, Iteration: 121
Current loss is: 133.45653590504156, Iteration: 122
Current loss is: 133.45653590504156, Iteration: 123
Current loss is: 133.45653590504156, Iteration: 124
Current loss is: 133.45653590504156, Iteration: 125
Current loss is: 133.45653590504156, Iteration: 126
Current loss is: 133.45653590504156, Iteration: 127
Current loss is: 133.45653590504156, Iteration: 128
Current loss is: 133.45653590504156, Iteration: 129
Current loss is: 133.45653590504156, Iteration: 130
Current loss is: 133.45653590504156, Iteration: 131
Current loss is: 133.45653590504156, Iteration: 132
Current loss is: 133.45653590504156, Iteration: 133
Current loss is: 133.45653590504156, Iteration: 134
Current loss is: 133.45653590504156, Iteration: 135
Current loss is: 133.45653590504156, Iteration: 136
Current loss is: 133.45653590504156, Iteration: 137
Current loss is: 133.45653590504156, Iteration: 138
Current loss is: 133.45653590504156, Iteration: 139
Current loss is: 133.45653590504156, Iteration: 140
Current loss is: 133.45653590504156, Iteration: 141
Current loss is: 133.45653590504156, Iteration: 142
Current loss is: 133.45653590504156, Iteration: 143
Current loss is: 133.45653590504156, Iteration: 144
Current loss is: 133.45653590504156, Iteration: 145
Current loss is: 133.45653590504156, Iteration: 146
Current loss is: 133.45653590504156, Iteration: 147
Current loss is: 133.45653590504156, Iteration: 148
Current loss is: 133.45653590504156, Iteration: 149
Current loss is: 133.45653590504156, Iteration: 150
Current loss is: 133.45653590504156, Iteration: 151
Current loss is: 133.45653590504156, Iteration: 152
Current loss is: 133.45653590504156, Iteration: 153
Current loss is: 133.45653590504156, Iteration: 154
Current loss is: 133.45653590504156, Iteration: 155
Current loss is: 133.45653590504156, Iteration: 156
Current loss is: 133.45653590504156, Iteration: 157
Current loss is: 133.45653590504156, Iteration: 158
Current loss is: 133.45653590504156, Iteration: 159
Current loss is: 133.45653590504156, Iteration: 160
Current loss is: 133.45653590504156, Iteration: 161
Current loss is: 133.45653590504156, Iteration: 162
Current loss is: 133.45653590504156, Iteration: 163
Current loss is: 133.45653590504156, Iteration: 164
Current loss is: 133.45653590504156, Iteration: 165
Current loss is: 133.45653590504156, Iteration: 166
Current loss is: 133.45653590504156, Iteration: 167
Current loss is: 133.45653590504156, Iteration: 168
Current loss is: 133.45653590504156, Iteration: 169
Current loss is: 133.45653590504156, Iteration: 170
Current loss is: 133.45653590504156, Iteration: 171
Current loss is: 133.45653590504156, Iteration: 172
Current loss is: 133.45653590504156, Iteration: 173
Current loss is: 133.45653590504156, Iteration: 174
Current loss is: 133.45653590504156, Iteration: 175
Current loss is: 133.45653590504156, Iteration: 176
Current loss is: 133.45653590504156, Iteration: 177
Current loss is: 133.45653590504156, Iteration: 178
Current loss is: 133.45653590504156, Iteration: 179
Current loss is: 133.45653590504156, Iteration: 180
Current loss is: 133.45653590504156, Iteration: 181
Current loss is: 133.45653590504156, Iteration: 182
Current loss is: 133.45653590504156, Iteration: 183
Current loss is: 133.45653590504156, Iteration: 184
Current loss is: 133.45653590504156, Iteration: 185
Current loss is: 133.45653590504156, Iteration: 186
Current loss is: 133.45653590504156, Iteration: 187
Current loss is: 133.45653590504156, Iteration: 188
Current loss is: 133.45653590504156, Iteration: 189
Current loss is: 133.45653590504156, Iteration: 190
Current loss is: 133.45653590504156, Iteration: 191
Current loss is: 133.45653590504156, Iteration: 192
Current loss is: 133.45653590504156, Iteration: 193
Current loss is: 133.45653590504156, Iteration: 194
Current loss is: 133.45653590504156, Iteration: 195
Current loss is: 133.45653590504156, Iteration: 196
Current loss is: 133.45653590504156, Iteration: 197
Current loss is: 133.45653590504156, Iteration: 198
Current loss is: 133.45653590504156, Iteration: 199
Current loss is: 133.45653590504156, Iteration: 200
Current loss is: 133.45653590504156, Iteration: 201
retcode: Success
Interpolation: Trained neural network interpolation
t: 0.0f0:0.05f0:1.0f0
u: 21-element Vector{Float32}:
  0.0
 -0.026961738
 -0.054784633
 -0.08346676
 -0.1130049
 -0.14339462
 -0.17463014
 -0.20670454
 -0.23960975
 -0.27333638
 -0.30787417
 -0.3432115
 -0.37933594
 -0.41623402
 -0.4538912
 -0.49229237
 -0.53142136
 -0.5712613
 -0.61179453
 -0.6530033
 -0.6948684

where the loss remains constant.

But with Optimization@3.21, I get an error:

julia> sol = solve(prob, NeuralPDE.NNODE(luxchain, opt, autodiff = true), dt = 1 / 20.0f0, verbose = true,
                   abstol = 1.0f-10, maxiters = 200)
┌ Warning: `ForwardDiff.jacobian(f, x)` within Zygote cannot track gradients with respect to `f`,
│ and `f` appears to be a closure, or a struct with fields (according to `issingletontype(typeof(f))`).
│ typeof(f) = NeuralPDE.var"#163#164"{NeuralPDE.ODEPhi{Lux.Chain{@NamedTuple{layer_1::Lux.Dense{true, typeof(sigmoid_fast), typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32)}, layer_2::Lux.Dense{true, typeof(identity), typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32)}}, Nothing}, Float32, Float32, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}}}, ComponentArrays.ComponentVector{Float32, Vector{Float32}, Tuple{ComponentArrays.Axis{(layer_1 = ViewAxis(1:10, Axis(weight = ViewAxis(1:5, ShapedAxis((5, 1), NamedTuple())), bias = ViewAxis(6:10, ShapedAxis((5, 1), NamedTuple())))), layer_2 = ViewAxis(11:16, Axis(weight = ViewAxis(1:5, ShapedAxis((1, 5), NamedTuple())), bias = ViewAxis(6:6, ShapedAxis((1, 1), NamedTuple())))))}}}}
└ @ Zygote ~/.julia/packages/Zygote/jxHJc/src/lib/forward.jl:150
ERROR: MethodError: no method matching zero(::Type{ComponentArrays.ComponentVector{Float32, Vector{Float32}, Tuple{ComponentArrays.Axis{…}}}})

Closest candidates are:
  zero(::Type{Union{}}, Any...)
   @ Base number.jl:310
  zero(::Type{Dates.Time})
   @ Dates ~/.julia/juliaup/julia-1.10.0+0.x64.linux.gnu/share/julia/stdlib/v1.10/Dates/src/types.jl:440
  zero(::Type{Pkg.Resolve.FieldValue})
   @ Pkg ~/.julia/juliaup/julia-1.10.0+0.x64.linux.gnu/share/julia/stdlib/v1.10/Pkg/src/Resolve/fieldvalues.jl:38
  ...

Stacktrace:
  [1] (::OptimizationZygoteExt.var"#38#56"{OptimizationZygoteExt.var"#37#55"{…}})(::ComponentArrays.ComponentVector{Float32, Vector{…}, Tuple{…}}, ::ComponentArrays.ComponentVector{Float32, Vector{…}, Tuple{…}})
    @ OptimizationZygoteExt ~/.julia/packages/Optimization/79XSq/ext/OptimizationZygoteExt.jl:93
  [2] macro expansion
    @ ~/.julia/packages/OptimizationOptimisers/AOkbT/src/OptimizationOptimisers.jl:68 [inlined]
  [3] macro expansion
    @ ~/.julia/packages/Optimization/79XSq/src/utils.jl:41 [inlined]
  [4] __solve(cache::Optimization.OptimizationCache{OptimizationFunction{…}, Optimization.ReInitCache{…}, Nothing, Nothing, Nothing, Nothing, Nothing, Optimisers.Adam, Base.Iterators.Cycle{…}, Bool, NeuralPDE.var"#192#196"{…}})
    @ OptimizationOptimisers ~/.julia/packages/OptimizationOptimisers/AOkbT/src/OptimizationOptimisers.jl:66
  [5] solve!(cache::Optimization.OptimizationCache{OptimizationFunction{…}, Optimization.ReInitCache{…}, Nothing, Nothing, Nothing, Nothing, Nothing, Optimisers.Adam, Base.Iterators.Cycle{…}, Bool, NeuralPDE.var"#192#196"{…}})
    @ SciMLBase ~/.julia/packages/SciMLBase/slQep/src/solve.jl:179
  [6] solve(::OptimizationProblem{true, OptimizationFunction{…}, ComponentArrays.ComponentVector{…}, SciMLBase.NullParameters, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, @Kwargs{}}, ::Optimisers.Adam; kwargs::@Kwargs{callback::NeuralPDE.var"#192#196"{…}, maxiters::Int64})
    @ SciMLBase ~/.julia/packages/SciMLBase/slQep/src/solve.jl:96
  [7] __solve(::ODEProblem{…}, ::NNODE{…}; dt::Float32, timeseries_errors::Bool, save_everystep::Bool, adaptive::Bool, abstol::Float32, reltol::Float32, verbose::Bool, saveat::Nothing, maxiters::Int64, tstops::Nothing)
    @ NeuralPDE ~/NeuralPDE.jl/src/ode_solve.jl:489
  [8] __solve
    @ ~/NeuralPDE.jl/src/ode_solve.jl:373 [inlined]
  [9] solve_call(_prob::ODEProblem{…}, args::NNODE{…}; merge_callbacks::Bool, kwargshandle::Nothing, kwargs::@Kwargs{…})
    @ DiffEqBase ~/.julia/packages/DiffEqBase/eLhx9/src/solve.jl:609
 [10] solve_call
    @ DiffEqBase ~/.julia/packages/DiffEqBase/eLhx9/src/solve.jl:567 [inlined]
 [11] #solve_up#42
    @ DiffEqBase ~/.julia/packages/DiffEqBase/eLhx9/src/solve.jl:1058 [inlined]
 [12] solve_up
    @ DiffEqBase ~/.julia/packages/DiffEqBase/eLhx9/src/solve.jl:1044 [inlined]
 [13] #solve#40
    @ DiffEqBase ~/.julia/packages/DiffEqBase/eLhx9/src/solve.jl:981 [inlined]
 [14] top-level scope
    @ REPL[14]:1
Some type information was truncated. Use `show(err)` to see complete types.

This is because of https://github.com/SciML/Optimization.jl/pull/679. It traces back to the `ForwardDiff.jacobian` call used to compute the derivatives in the equation.

ChrisRackauckas commented 5 months ago

Instead of using ForwardDiff.jacobian, we could do the dual evaluation directly. It's the same as this:

https://github.com/SciML/OrdinaryDiffEq.jl/blob/master/src/derivative_wrappers.jl#L84-L103

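        # seed one partial in the t direction; the custom tag guards against perturbation confusion
        T = typeof(ForwardDiff.Tag(NeuralPDETag(), eltype(t)))
        tdual = ForwardDiff.Dual{T, eltype(t), 1}(t, ForwardDiff.Partials((one(typeof(t)),)))
        # evaluate phi once on the dual and read d(phi)/dt from each output's partials
        first.(ForwardDiff.partials.(phi(tdual, θ)))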

and add `struct NeuralPDETag end`. Doing it this way keeps the math intact but bypasses the higher-level `jacobian` interface, so we differentiate `phi` directly. Since this definition is completely non-mutating, it should just work.
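
The seeding idea can be illustrated without depending on package internals. The following is a minimal pure-Julia sketch, not NeuralPDE's implementation. It uses a hand-rolled single-partial dual number, `Dual1`, which is hypothetical and stands in for `ForwardDiff.Dual{T, V, 1}`. It also uses a hypothetical toy trial function `phi(t, θ) = u0 + (t - t0) * NN(t, θ)` with a 1-5-1 sigmoid network, which is assumed here only for illustration. Evaluating `phi` once at a dual time whose partial is seeded with `1` leaves `dphi/dt` in the partial slot, with no `jacobian` call involved. The sketch omits the tag that `NeuralPDETag` provides in ForwardDiff, which keeps nested dual numbers from being confused with each other.

```julia
# Minimal forward-mode sketch: hand-rolled single-partial dual number.
# Dual1 is hypothetical and only stands in for ForwardDiff.Dual{T, V, 1}.
struct Dual1{V<:Real}
    value::V     # primal value
    partial::V   # derivative along the seeded direction (here d/dt)
end
Dual1(v::Real, p::Real) = Dual1(promote(v, p)...)   # allow mixed Float32/Float64 inputs

# Arithmetic rules needed by the toy trial function below.
Base.:+(a::Dual1, b::Dual1) = Dual1(a.value + b.value, a.partial + b.partial)
Base.:+(a::Dual1, b::Real) = Dual1(a.value + b, a.partial)
Base.:+(a::Real, b::Dual1) = b + a
Base.:-(a::Dual1, b::Real) = Dual1(a.value - b, a.partial)
Base.:*(a::Dual1, b::Dual1) = Dual1(a.value * b.value, a.partial * b.value + a.value * b.partial)
Base.:*(a::Real, b::Dual1) = Dual1(a * b.value, a * b.partial)

# Logistic sigmoid, with the chain rule sigm'(x) = sigm(x) * (1 - sigm(x)) for duals.
sigm(x::Real) = 1 / (1 + exp(-x))
function sigm(x::Dual1)
    s = sigm(x.value)
    return Dual1(s, s * (1 - s) * x.partial)
end

# Hypothetical trial solution phi(t, θ) = u0 + (t - t0) * NN(t, θ) with a 1-5-1 network.
function phi(t, θ; u0 = 0.0, t0 = 0.0)
    nn = θ.b2
    for i in eachindex(θ.w1)
        nn = nn + θ.w2[i] * sigm(θ.w1[i] * t + θ.b1[i])
    end
    return u0 + (t - t0) * nn
end

# Direct dual evaluation: seed partial = 1 in the t direction, evaluate once, read the partial.
dphidt(t, θ) = phi(Dual1(t, one(t)), θ).partial

θ = (w1 = [0.5, -1.0, 0.3, 0.8, -0.2], b1 = [0.1, 0.0, -0.4, 0.2, 0.3],
     w2 = [1.0, 0.5, -0.7, 0.2, 0.9], b2 = 0.05)

t, h = 0.4, 1e-6
fd = (phi(t + h, θ) - phi(t - h, θ)) / (2h)   # central finite difference for comparison
println("dual: ", dphidt(t, θ), "  finite difference: ", fd)
```

The factor `(t - t0)` vanishes at `t = t0`. The sketch's derivative there therefore reduces to the network output `NN(t0, θ)`, which gives a quick sanity check alongside the finite-difference comparison.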

sathvikbhagavan commented 5 months ago

Ok, will try this.