SciML / SciMLSensitivity.jl

A component of the DiffEq ecosystem for enabling sensitivity analysis for scientific machine learning (SciML). Optimize-then-discretize, discretize-then-optimize, adjoint methods, and more for ODEs, SDEs, DDEs, DAEs, etc.
https://docs.sciml.ai/SciMLSensitivity/stable/

Optimization not working as soon as Dense layer gets replaced with others (ex. RNN) #897

Open mariaade26 opened 10 months ago

mariaade26 commented 10 months ago

Hi everyone, I have a problem with my code, in particular while training the neural network. I've encountered a significant issue when attempting to replace a dense layer with other types of layers. Specifically, I've noticed that when I try to introduce different layers, such as Scale layers or recurrent layers, my code starts generating errors related to automatic differentiation (AD). These errors make it impossible to optimize the model and cause issues during gradient backpropagation. Despite my efforts to address this problem, it seems that AD is not compatible with the new layers, leading to a standstill in training. I'm sure I'm missing something, but I don't understand what. I would really appreciate some help. I'm attaching the code, the CSVs, and the error. Thank you!

Internal gains.csv
phi heating tutto.csv
phisun totale.csv
Testerna totale.csv
phi heating modificato.csv

Here is my code:

using DifferentialEquations, Lux, ComponentArrays, DiffEqFlux, Optimization,
      OptimizationOptimJL, Plots, Random, OptimizationOptimisers, DataInterpolations,
      CSV, DataFrames, Statistics, Flux  # Flux is needed for Flux.Dense / Flux.mae used below
data_fn = raw"File csv/Testerna totale.csv"
data_df= CSV.read(data_fn, DataFrame; header = false);
Testerna = convert.(Float32, data_df[!, 1])

phih_fn = raw"File csv/phi heating tutto.csv"
phih_df = CSV.read(phih_fn, DataFrame; header = false);

phih_fn2 = raw"File csv/phi heating modificato.csv"
phih_df2 = CSV.read(phih_fn2, DataFrame; header = false);

sun_fn = raw"File csv/phisun totale.csv"
sun_df = CSV.read(sun_fn, DataFrame; header = false);

int_fn= raw"File csv/Internal gains.csv"
int_df= CSV.read(int_fn, DataFrame; header = false);

phih = convert.(Float32, phih_df[!, 1]);
phisun = convert.(Float32, sun_df[!, 1]);
intg = convert.(Float32, int_df[!, 1]);
phih2= convert.(Float32, phih_df2[!,1]);

phisun= phisun .+ intg;
tspan= (0.0f0, 4320.0f0)
t= range(tspan[1], tspan[2], length=4320)

disturbance = LinearInterpolation(Testerna,t);
#disturbance2 = LinearInterpolation(phih,t);
disturbance2 = LinearInterpolation(phih2,t);
disturbance3 = LinearInterpolation(phisun,t);
disturbance4= LinearInterpolation(intg,t);

# Define the exogenous input functions
function Text(t)
    return disturbance(t)
end

function phi(t)
    return disturbance2(t)
end

function phis(t)
    return disturbance3(t)
end  

function intgain(t)
    return disturbance4(t)
end

function f(du, u, p, t)
    Rv = p[1]
    Ci = p[2]
    du[1] = 1/(Rv*Ci) .* (Text(t) .- u[1]) .+ phi(t)/Ci .+ phis(t)/Ci #.+ intgain(t)/Ci

end 

u0= [19.3]
p= [0.005, 27800.0]

# Solve the differential equation
prob = ODEProblem(f, u0, tspan, p)
#sol = solve(prob, VCABM(), saveat = t, reltol = sqrt(eps(one(Float32))), abstol = eps(one(Float32)))
sol = solve(prob, VCABM();
    saveat = t, reltol = sqrt(eps(one(Float32))), abstol = eps(one(Float32)))
tplot=t

# Plot the result
plot(sol,title = "Soluzione reale", xlabel="Tempo", ylabel="Temperatura [°C]", legend=false)
Tmedia= mean(Testerna)
Tstd=std(Testerna)

phihmedio= mean(phih)
phihstd= std(phih)

phisunmedio= mean(phisun)
phisunstd= std(phisun)

intgmedio= mean(intg)
intgstd= std(intg)

function Tnorm(t)
    return (Text(t) .- Tmedia) ./ Tstd
end

function phinorm(t)
    return (phi(t) .- phihmedio) ./ phihstd
end

function phisunnorm(t)
    return (phis(t) .- phisunmedio) ./ phisunstd
end

function intgnorm(t)
    return (intgain(t) .- intgmedio) ./ intgstd
end
rng=Random.default_rng()

#nn_model = Lux.Chain(Lux.Dense(4, 8, σ), Lux.Dense(8, 3), Lux.Dense(3,1))
#nn_model = Lux.Chain( Lux.Dense(4, 8, σ), Lux.Scale(2),  Lux.Dense(8, 3), Lux.Dense(3, 1))
#nn_model= Lux.Chain(Lux.GRUCell(4 => 8), Lux.Dense(8,1))

nn_model = Lux.Chain(
    Lux.RNNCell(4 => 8, σ),  # RNN layer with input size 4 and output size 8
    Flux.Dense(8, 1)
)
p_model, st = Lux.setup(rng, nn_model)

function dudt(u, p, t)
    global st
   #out, st = nn_model(vcat(u[1], Tnorm(t),phinorm(t), phisunnorm(t), intgnorm(t)), p, st)
      out, st = nn_model(vcat(u[1], Tnorm(t),phinorm(t), phisunnorm(t)), p, st)
   # out, st = nn_model(vcat(u[1], Text(t),phi(t), phis(t)), p, st)
    return out
end

prob = ODEProblem(dudt, u0, tspan, nothing)

function predict_neuralode(p)
    _prob = remake(prob; p = p)
    Array(solve(_prob, VCABM();
        saveat = t, reltol = sqrt(eps(one(Float32))), abstol = eps(one(Float32))))
end

function loss(p)
    pred = predict_neuralode(p)
    N = length(pred)
   #return sum(abs2.(sol[1,:] .- pred')) 
    return Flux.mae(sol[1,:], pred')
end

callback = function (p, l)
    println(l)
    return false
end

adtype = Optimization.AutoZygote()
pinit= ComponentArray(p_model)
optf = Optimization.OptimizationFunction((x, p) -> loss(x), adtype)
optprob = Optimization.OptimizationProblem(optf, pinit)

@time begin 
res0 = Optimization.solve(optprob, Optimisers.Adam(0.08), callback=callback, maxiters = 300)
end 

pred= predict_neuralode(res0.u)
pred=pred'
plot(t,pred, label= "prediction")
plot!(t, sol[1,:], label= "data", xlabel="Tempo", ylabel="Temperatura [°C]")

and the related error:

" ┌ Warning: ZygoteVJP tried and failed in the automated AD choice algorithm with the following error. (To turn off this printing, add verbose = false to the solve call) └ @ SciMLSensitivity C:\Users\Michele.julia\packages\SciMLSensitivity\NhfkF\src\concrete_solve.jl:94 MethodError: no method matching (::RNNCell{true, false, typeof(σ), typeof(WeightInitializers.zeros32), typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.ones32)})(::Vector{Float64}, ::ComponentVector{Float32, SubArray{Float32, 1, Vector{Float32}, Tuple{UnitRange{Int64}}, true}, Tuple{Axis{

(weight_ih = ViewAxis(1:32, ShapedAxis((8, 4), NamedTuple())), weight_hh = ViewAxis(33:96, ShapedAxis((8, 8), NamedTuple())), bias = 97:104)}}}, ::NamedTuple{(:rng,), Tuple{Xoshiro}})

Closest candidates are: (::RNNCell{true})(::Tuple{AbstractMatrix, Tuple{AbstractMatrix}}, ::Any, ::NamedTuple) @ Lux C:\Users\Michele.julia\packages\Lux\AU9zG\src\layers\recurrent.jl:271 (::RNNCell{use_bias, false})(::AbstractMatrix, ::Any, ::NamedTuple) where use_bias @ Lux C:\Users\Michele.julia\packages\Lux\AU9zG\src\layers\recurrent.jl:251 ┌ Warning: ReverseDiffVJP tried and failed in the automated AD choice algorithm with the following error. (To turn off this printing, add verbose = false to the solve call) └ @ SciMLSensitivity C:\Users\Michele.julia\packages\SciMLSensitivity\NhfkF\src\concrete_solve.jl:111 MethodError: no method matching gradient(::SciMLSensitivity.var"#258#264"{ODEProblem{Vector{Float64}, Tuple{Float32, Float32}, false, ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(layer_1 = ViewAxis(1:104, Axis(weight_ih = ViewAxis(1:32, ShapedAxis((8, 4), NamedTuple())), weight_hh = ViewAxis(33:96, ShapedAxis((8, 8), NamedTuple())), bias = 97:104)), layer_2 = 105:104)}}}, ODEFunction{false, SciMLBase.AutoSpecialize, typeof(dudt), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}}, ::Vector{Float64}, ::ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(layer_1 = ViewAxis(1:104, Axis(weight_ih = ViewAxis(1:32, ShapedAxis((8, 4), NamedTuple())), weight_hh = ViewAxis(33:96, ShapedAxis((8, 8), NamedTuple())), bias = 97:104)), layer_2 = 105:104)}}})

Closest candidates are: gradient(::Any, ::Any) @ ReverseDiff C:\Users\Michele.julia\packages\ReverseDiff\7pHoq\src\api\gradients.jl:21 gradient(::Any, ::Any, ::ReverseDiff.GradientConfig) @ ReverseDiff C:\Users\Michele.julia\packages\ReverseDiff\7pHoq\src\api\gradients.jl:21 ┌ Warning: TrackerVJP tried and failed in the automated AD choice algorithm with the following error. (To turn off this printing, add verbose = false to the solve call) └ @ SciMLSensitivity C:\Users\Michele.julia\packages\SciMLSensitivity\NhfkF\src\concrete_solve.jl:129 MethodError: no method matching (::RNNCell{true, false, typeof(σ), typeof(WeightInitializers.zeros32), typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.ones32)})(::TrackedArray{…,Vector{Float64}}, ::ComponentVector{Float32, TrackedArray{…,Vector{Float32}}, Tuple{Axis{(weight_ih = ViewAxis(1:32, ShapedAxis((8, 4), NamedTuple())), weight_hh = ViewAxis(33:96, ShapedAxis((8, 8), NamedTuple())), bias = 97:104)}}}, ::NamedTuple{(:rng,), Tuple{Xoshiro}})

Closest candidates are: (::RNNCell{true})(::Tuple{AbstractMatrix, Tuple{AbstractMatrix}}, ::Any, ::NamedTuple) @ Lux C:\Users\Michele.julia\packages\Lux\AU9zG\src\layers\recurrent.jl:271 (::RNNCell{use_bias, false})(::AbstractMatrix, ::Any, ::NamedTuple) where use_bias @ Lux C:\Users\Michele.julia\packages\Lux\AU9zG\src\layers\recurrent.jl:251 ┌ Warning: Reverse-Mode AD VJP choices all failed. Falling back to numerical VJPs └ @ SciMLSensitivity C:\Users\Michele.julia\packages\SciMLSensitivity\NhfkF\src\concrete_solve.jl:139 MethodError: no method matching (::RNNCell{true, false, typeof(σ), typeof(WeightInitializers.zeros32), typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.ones32)})(::Vector{Float64}, ::ComponentVector{Float32, SubArray{Float32, 1, Vector{Float32}, Tuple{UnitRange{Int64}}, true}, Tuple{Axis{(weight_ih = ViewAxis(1:32, ShapedAxis((8, 4), NamedTuple())), weight_hh = ViewAxis(33:96, ShapedAxis((8, 8), NamedTuple())), bias = 97:104)}}}, ::NamedTuple{(:rng,), Tuple{Xoshiro}}) "

prbzrg commented 10 months ago

Using

sensealg = SciMLSensitivity.ForwardDiffSensitivity()

in the first two solve calls may fix this issue.

UPDATE: this only helps if the direct (non-AD) solve call works and the problem is in the AD step.
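For concreteness, a minimal sketch of where that keyword would go, reusing the names from the post above (untested; only the solve that is differentiated through actually needs it):

using SciMLSensitivity  # brings ForwardDiffSensitivity into scope

function predict_neuralode(p)
    _prob = remake(prob; p = p)
    # Ask for forward-mode sensitivities instead of letting the automated VJP choice
    # fall through Zygote/ReverseDiff/Tracker
    Array(solve(_prob, VCABM();
        saveat = t, reltol = sqrt(eps(one(Float32))), abstol = eps(one(Float32)),
        sensealg = ForwardDiffSensitivity()))
end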

ChrisRackauckas commented 10 months ago

Forward mode would likely be faster for a model of this size, yes.

That said, the real issue is that what's being passed doesn't match the interface. Look at the error message:

(::RNNCell{true})(::Tuple{AbstractMatrix, Tuple{AbstractMatrix}}, ::Any, ::NamedTuple)

This is saying what the input of an RNNCell has to be. Now look at what's passed in:

out, st = nn_model(vcat(u[1], Tnorm(t),phinorm(t), phisunnorm(t)), p, st)

It says it wants a Tuple{AbstractMatrix, Tuple{AbstractMatrix}} as the first argument, but the first argument here is a vector built by vcat, so that's clearly not following Lux's interface. Look at the documentation here and follow its input specification:

https://lux.csail.mit.edu/dev/api/Lux/layers#Lux.RNNCell

(@avik-pal)

Optimization not working as soon as Dense layer gets replaced with others (ex. RNN)

So this title doesn't really make sense. Other layers work; there are lots of examples of this. For example, https://docs.sciml.ai/DiffEqFlux/dev/examples/mnist_conv_neural_ode/ uses convolutional layers, so "oh no, nothing other than Dense works" is just false. What is true is that when you change the layer type to something different, the neural network library may require a slightly different input (as is the case with RNNCell), and you need to modify the ODE definition so that the network is called the way the neural network library expects. This has nothing to do with SciML, though; this would error outside of the ODE even when calling the neural network directly.
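To see that the failure is independent of the solver, here is a sketch of the direct call, reusing the model, parameters, and interpolation helpers from the post above (not a fix, just a reproduction of the interface mismatch):

# The same call made inside dudt, with no ODE solve involved:
x = vcat(u0[1], Tnorm(0.0f0), phinorm(0.0f0), phisunnorm(0.0f0))  # a plain Vector, length 4
nn_model(x, p_model, st)  # MethodError: no method matching (::RNNCell{...})(::Vector, ...), as in the log above
# The candidate methods accept an AbstractMatrix (features x batch) or a (matrix, (carry,)) tuple,
# so the input has to be reshaped to match that interface.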

mariaade26 commented 10 months ago

Thank you.

avik-pal commented 10 months ago

You need to have a batch dimension for the layer to work.

Also, note that your current code doesn't do what you want to do. It always treats the input as the first input in the sequence, which makes using an RNNCell quite pointless. You might want to look at https://lux.csail.mit.edu/dev/api/Lux/layers#Lux.StatefulRecurrentCell.

That being said, you need to turn off the adaptivity of the solver; otherwise a stateful RNN doesn't make sense, since with an adaptive solver (which can reject and retry steps) the times at which the right-hand side is called are not monotonically increasing.
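Putting those three points together, a rough and untested sketch of that direction, reusing the names from the original post (the Lux.Dense swap for Flux.Dense, the Euler step size, and the reshape are illustrative choices, not prescriptions):

# Wrap the cell so the carry lives in `st` and only the output flows through the chain
nn_model = Lux.Chain(
    Lux.StatefulRecurrentCell(Lux.RNNCell(4 => 8, σ)),
    Lux.Dense(8, 1))
p_model, st = Lux.setup(rng, nn_model)

function dudt(u, p, t)
    global st
    x = reshape(vcat(u[1], Tnorm(t), phinorm(t), phisunnorm(t)), 4, 1)  # add the batch dimension
    out, st = nn_model(x, p, st)  # keep the updated state so the recurrence actually carries over
    return vec(out)
end

prob = ODEProblem(dudt, u0, tspan, ComponentArray(p_model))
# Fixed-step solve so the network sees strictly increasing, evenly spaced times;
# training would remake this problem with the optimizer's parameters as before
sol_nn = solve(prob, Euler(); dt = 1.0f0, saveat = t)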