FluxML / Zygote.jl

21st century AD
https://fluxml.ai/Zygote.jl/
Other
1.48k stars 213 forks source link

try/catch is not supported when attempting to use `remake` with Zygote #1479

Open jarroyoe opened 11 months ago

jarroyoe commented 11 months ago

I have code that ran fine 6 months ago but no longer runs. The code

using OrdinaryDiffEq, DiffEqFlux
using Optimization, OptimizationOptimisers, OptimizationOptimJL
using ComponentArrays, Lux, Zygote, Random, CUDA, LinearAlgebra
rng = Random.default_rng()

function trainUDEModel(neuralNetwork,knownDynamics,training_data;needed_ps = Float64[],p_true = Float64[])
    pinit, st = Lux.setup(rng,neuralNetwork)
    st = st |> Lux.gpu
    p64 = Float64.(Lux.gpu(ComponentArray(pinit)))
    training_data = Float64.(Lux.gpu(training_data))
    x0 = Float64.(Lux.gpu(training_data[:,1]))

   function ude(du,u,p,t,q)
        knownPred = convert(CuArray,knownDynamics(u,nothing,q))
        nnPred = convert(CuArray,first(neuralNetwork(u,p,st)))

        du .= knownPred .+ nnPred
    end

    # Closure with the known parameter
    nn_dynamics(du,u,p,t) = ude(du,u,p,t,p_true)
    # Define the problem
    prob_nn = ODEProblem(nn_dynamics,x0, (Float64(1),Float64(size(training_data,2))), p64)
    ## Function to train the network
    # Define a predictor
    function predict(p, X = x0)
        _prob = remake(prob_nn, u0 = X, tspan = (Float64(1),Float64(size(training_data,2))), p = p)
        CUDA.@allowscalar convert(CuArray,solve(_prob, AutoTsit5(Rosenbrock23()), saveat = 1.,
                abstol=1e-6, reltol=1e-6
                ))
    end

    lipschitz_regularizer = 0.5
    function loss_function(p)
        W1 = p.layer_1.weight
        W2 = p.layer_2.weight
        lipschitz_constant = spectralRadius(W1)*spectralRadius(W2)

        pred = predict(p)
        loss = sum(abs2,training_data .- pred)/size(training_data,2) + lipschitz_regularizer*lipschitz_constant
        return loss
    end

    losses = Float64[]

    callback = function (p, l)
      push!(losses, l)
    if length(losses)%50==0
          println("Current loss after $(length(losses)) iterations: $(losses[end])")
    end
    return false
    end

    ## Training
    #callback(pinit, loss_function(pinit)...; doplot=true)

    adtype = Optimization.AutoZygote()
    optf = Optimization.OptimizationFunction((x, p) -> loss_function(x), adtype)
    optprob = Optimization.OptimizationProblem(optf, p64)

    result_neuralode = Optimization.solve(optprob,
                                           ADAM(),
                                           #callback = callback,
                                           maxiters = 300)

    optprob2 = remake(optprob,u0 = result_neuralode.u)
    result_neuralode2 = Optimization.solve(optprob2,
                                            Optim.BFGS(initial_stepnorm=0.01),
                                            #callback=callback,
                                            allow_f_increases = false)

    return result_neuralode2.u
end

function spectralRadius(X,niters=10)
    y = randn!(similar(X, size(X, 2)))
    tmp = X * y
    for i in 1:niters
        tmp = X*y
        tmp = tmp / norm(tmp)
        y = X' * tmp
        y = y / norm(y)
    end
    return norm(X*y)
end

training_data = rand(4,10)
neuralnetwork = Lux.Chain(Lux.Dense(4,5),Lux.Dense(5,4))
knownDynamics(x,p,q)=-q
trainUDEModel(neuralnetwork,knownDynamics,training_data;p_true=1)
println("Done!")

yields

┌ Info: The GPU function is being called but the GPU is not accessible.
│ Defaulting back to the CPU. (No action is required if you want
└ to run on the CPU).
┌ Info: The GPU function is being called but the GPU is not accessible.
│ Defaulting back to the CPU. (No action is required if you want
└ to run on the CPU).
ERROR: LoadError: Compiling Tuple{var"#predict#6"{Vector{Float64}, Matrix{Float64}, ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, true, ComponentVector{Float64, Vector{Float64}, Tuple{Axis{(layer_1 = ViewAxis(1:25, Axis(weight = ViewAxis(1:20, ShapedAxis((5, 4), NamedTuple())), bias = ViewAxis(21:25, ShapedAxis((5, 1), NamedTuple())))), layer_2 = ViewAxis(26:49, Axis(weight = ViewAxis(1:20, ShapedAxis((4, 5), NamedTuple())), bias = ViewAxis(21:24, ShapedAxis((4, 1), NamedTuple())))))}}}, ODEFunction{true, SciMLBase.AutoSpecialize, var"#nn_dynamics#5"{Int64, var"#ude#4"{Lux.Chain{NamedTuple{(:layer_1, :layer_2), Tuple{Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, typeof(knownDynamics)}}, UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}}, ComponentVector{Float64, Vector{Float64}, Tuple{Axis{(layer_1 = ViewAxis(1:25, Axis(weight = ViewAxis(1:20, ShapedAxis((5, 4), NamedTuple())), bias = ViewAxis(21:25, ShapedAxis((5, 1), NamedTuple())))), layer_2 = ViewAxis(26:49, Axis(weight = ViewAxis(1:20, ShapedAxis((4, 5), NamedTuple())), bias = ViewAxis(21:24, ShapedAxis((4, 1), NamedTuple())))))}}}, Vector{Float64}}: try/catch is not supported.
Refer to the Zygote documentation for fixes.
https://fluxml.ai/Zygote.jl/latest/limitations

Stacktrace:
ERROR: LoadError: Compiling Tuple{var"#predict#6"{Vector{Float64}, Matrix{Float64}, ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, true, ComponentVector{Float64, Vector{Float64}, Tuple{Axis{(layer_1 = ViewAxis(1:25, Axis(weight = ViewAxis(1:20, ShapedAxis((5, 4), NamedTuple())), bias = ViewAxis(21:25, ShapedAxis((5, 1), NamedTuple())))), layer_2 = ViewAxis(26:49, Axis(weight = ViewAxis(1:20, ShapedAxis((4, 5), NamedTuple())), bias = ViewAxis(21:24, ShapedAxis((4, 1), NamedTuple())))))}}}, ODEFunction{true, SciMLBase.AutoSpecialize, var"#nn_dynamics#5"{Int64, var"#ude#4"{Lux.Chain{NamedTuple{(:layer_1, :layer_2), Tuple{Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, typeof(knownDynamics)}}, UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}}, ComponentVector{Float64, Vector{Float64}, Tuple{Axis{(layer_1 = ViewAxis(1:25, Axis(weight = ViewAxis(1:20, ShapedAxis((5, 4), NamedTuple())), bias = ViewAxis(21:25, ShapedAxis((5, 1), NamedTuple())))), layer_2 = ViewAxis(26:49, Axis(weight = ViewAxis(1:20, ShapedAxis((4, 5), NamedTuple())), bias = ViewAxis(21:24, ShapedAxis((4, 1), NamedTuple())))))}}}, Vector{Float64}}: try/catch is not supported.
Refer to the Zygote documentation for fixes.
https://fluxml.ai/Zygote.jl/latest/limitations

Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:35
  [2] instrument(ir::IRTools.Inner.IR)
    @ Zygote ~/.julia/packages/Zygote/oGI57/src/compiler/reverse.jl:128
  [3] #Primal#31
    @ ~/.julia/packages/Zygote/oGI57/src/compiler/reverse.jl:227 [inlined]
  [1] error(s::String)
    @ Base ./error.jl:35
  [2] instrument(ir::IRTools.Inner.IR)
    @ Zygote ~/.julia/packages/Zygote/oGI57/src/compiler/reverse.jl:128
  [3] #Primal#31
    @ ~/.julia/packages/Zygote/oGI57/src/compiler/reverse.jl:227 [inlined]
  [4] Zygote.Adjoint(ir::IRTools.Inner.IR; varargs::Nothing, normalise::Bool)
    @ Zygote ~/.julia/packages/Zygote/oGI57/src/compiler/reverse.jl:352
  [5] _generate_pullback_via_decomposition(T::Type)
    @ Zygote ~/.julia/packages/Zygote/oGI57/src/compiler/emit.jl:101
  [6] #s3181#1581
    @ ~/.julia/packages/Zygote/oGI57/src/compiler/interface2.jl:28 [inlined]
  [7] var"#s3181#1581"(::Any, ctx::Any, f::Any, args::Any)
    @ Zygote ./none:0
  [4] Zygote.Adjoint(ir::IRTools.Inner.IR; varargs::Nothing, normalise::Bool)
    @ Zygote ~/.julia/packages/Zygote/oGI57/src/compiler/reverse.jl:352
  [5] _generate_pullback_via_decomposition(T::Type)
    @ Zygote ~/.julia/packages/Zygote/oGI57/src/compiler/emit.jl:101
  [6] #s3181#1581
    @ ~/.julia/packages/Zygote/oGI57/src/compiler/interface2.jl:28 [inlined]
  [7] var"#s3181#1581"(::Any, ctx::Any, f::Any, args::Any)
    @ Zygote ./none:0
  [8] (::Core.GeneratedFunctionStub)(::Any, ::Vararg{Any})
    @ Core ./boot.jl:582
  [9] _pullback
    @ ~/NODE_Community_Forecast/test.jl:27 [inlined]
  [8] (::Core.GeneratedFunctionStub)(::Any, ::Vararg{Any})
    @ Core ./boot.jl:582
  [9] _pullback
    @ ~/NODE_Community_Forecast/test.jl:27 [inlined]
 [10] _pullback(ctx::Zygote.Context{false}, f::var"#predict#6"{Vector{Float64}, Matrix{Float64}, ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, true, ComponentVector{Float64, Vector{Float64}, Tuple{Axis{(layer_1 = ViewAxis(1:25, Axis(weight = ViewAxis(1:20, ShapedAxis((5, 4), NamedTuple())), bias = ViewAxis(21:25, ShapedAxis((5, 1), NamedTuple())))), layer_2 = ViewAxis(26:49, Axis(weight = ViewAxis(1:20, ShapedAxis((4, 5), NamedTuple())), bias = ViewAxis(21:24, ShapedAxis((4, 1), NamedTuple())))))}}}, ODEFunction{true, SciMLBase.AutoSpecialize, var"#nn_dynamics#5"{Int64, var"#ude#4"{Lux.Chain{NamedTuple{(:layer_1, :layer_2), Tuple{Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, typeof(knownDynamics)}}, UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}}, args::ComponentVector{Float64, Vector{Float64}, Tuple{Axis{(layer_1 = ViewAxis(1:25, Axis(weight = ViewAxis(1:20, ShapedAxis((5, 4), NamedTuple())), bias = ViewAxis(21:25, ShapedAxis((5, 1), NamedTuple())))), layer_2 = ViewAxis(26:49, Axis(weight = ViewAxis(1:20, ShapedAxis((4, 5), NamedTuple())), bias = ViewAxis(21:24, ShapedAxis((4, 1), NamedTuple())))))}}})
    @ Zygote ~/.julia/packages/Zygote/oGI57/src/compiler/interface2.jl:0
 [11] _pullback
    @ ~/NODE_Community_Forecast/test.jl:39 [inlined]
 [12] _pullback(ctx::Zygote.Context{false}, f::var"#loss_function#7"{Matrix{Float64}, Float64}, args::ComponentVector{Float64, Vector{Float64}, Tuple{Axis{(layer_1 = ViewAxis(1:25, Axis(weight = ViewAxis(1:20, ShapedAxis((5, 4), NamedTuple())), bias = ViewAxis(21:25, ShapedAxis((5, 1), NamedTuple())))), layer_2 = ViewAxis(26:49, Axis(weight = ViewAxis(1:20, ShapedAxis((4, 5), NamedTuple())), bias = ViewAxis(21:24, ShapedAxis((4, 1), NamedTuple())))))}}})
    @ Zygote ~/.julia/packages/Zygote/oGI57/src/compiler/interface2.jl:0
 [13] _pullback
    @ ~/NODE_Community_Forecast/test.jl:58 [inlined]
 [14] _apply
    @ ./boot.jl:816 [inlined]
 [15] adjoint
    @ ~/.julia/packages/Zygote/oGI57/src/lib/lib.jl:203 [inlined]
 [16] _pullback
    @ ~/.julia/packages/ZygoteRules/4nXuu/src/adjoint.jl:66 [inlined]
 [17] _pullback
    @ ~/.julia/packages/SciMLBase/l4PVV/src/scimlfunctions.jl:3772 [inlined]
 [18] _pullback(::Zygote.Context{false}, ::OptimizationFunction{true, Optimization.AutoZygote, var"#3#9"{var"#loss_function#7"{Matrix{Float64}, Float64}}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, ::ComponentVector{Float64, Vector{Float64}, Tuple{Axis{(layer_1 = ViewAxis(1:25, Axis(weight = ViewAxis(1:20, ShapedAxis((5, 4), NamedTuple())), bias = ViewAxis(21:25, ShapedAxis((5, 1), NamedTuple())))), layer_2 = ViewAxis(26:49, Axis(weight = ViewAxis(1:20, ShapedAxis((4, 5), NamedTuple())), bias = ViewAxis(21:24, ShapedAxis((4, 1), NamedTuple())))))}}}, ::SciMLBase.NullParameters)
    @ Zygote ~/.julia/packages/Zygote/oGI57/src/compiler/interface2.jl:0
 [19] _apply(::Function, ::Vararg{Any})
    @ Core ./boot.jl:816
 [20] adjoint
    @ ~/.julia/packages/Zygote/oGI57/src/lib/lib.jl:203 [inlined]
 [21] _pullback
    @ ~/.julia/packages/ZygoteRules/4nXuu/src/adjoint.jl:66 [inlined]
 [22] _pullback
    @ ~/.julia/packages/Optimization/GEo8L/src/function/zygote.jl:30 [inlined]
 [23] _pullback(ctx::Zygote.Context{false}, f::Optimization.var"#156#165"{OptimizationFunction{true, Optimization.AutoZygote, var"#3#9"{var"#loss_function#7"{Matrix{Float64}, Float64}}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, SciMLBase.NullParameters}, args::ComponentVector{Float64, Vector{Float64}, Tuple{Axis{(layer_1 = ViewAxis(1:25, Axis(weight = ViewAxis(1:20, ShapedAxis((5, 4), NamedTuple())), bias = ViewAxis(21:25, ShapedAxis((5, 1), NamedTuple())))), layer_2 = ViewAxis(26:49, Axis(weight = ViewAxis(1:20, ShapedAxis((4, 5), NamedTuple())), bias = ViewAxis(21:24, ShapedAxis((4, 1), NamedTuple())))))}}})
    @ Zygote ~/.julia/packages/Zygote/oGI57/src/compiler/interface2.jl:0
 [24] _apply(::Function, ::Vararg{Any})
    @ Core ./boot.jl:816
 [25] adjoint
    @ ~/.julia/packages/Zygote/oGI57/src/lib/lib.jl:203 [inlined]
 [26] _pullback
    @ ~/.julia/packages/ZygoteRules/4nXuu/src/adjoint.jl:66 [inlined]
 [27] _pullback
    @ ~/.julia/packages/Optimization/GEo8L/src/function/zygote.jl:34 [inlined]
 [28] _pullback(ctx::Zygote.Context{false}, f::Optimization.var"#158#167"{Tuple{}, Optimization.var"#156#165"{OptimizationFunction{true, Optimization.AutoZygote, var"#3#9"{var"#loss_function#7"{Matrix{Float64}, Float64}}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, SciMLBase.NullParameters}}, args::ComponentVector{Float64, Vector{Float64}, Tuple{Axis{(layer_1 = ViewAxis(1:25, Axis(weight = ViewAxis(1:20, ShapedAxis((5, 4), NamedTuple())), bias = ViewAxis(21:25, ShapedAxis((5, 1), NamedTuple())))), layer_2 = ViewAxis(26:49, Axis(weight = ViewAxis(1:20, ShapedAxis((4, 5), NamedTuple())), bias = ViewAxis(21:24, ShapedAxis((4, 1), NamedTuple())))))}}})
    @ Zygote ~/.julia/packages/Zygote/oGI57/src/compiler/interface2.jl:0
 [29] pullback(f::Function, cx::Zygote.Context{false}, args::ComponentVector{Float64, Vector{Float64}, Tuple{Axis{(layer_1 = ViewAxis(1:25, Axis(weight = ViewAxis(1:20, ShapedAxis((5, 4), NamedTuple())), bias = ViewAxis(21:25, ShapedAxis((5, 1), NamedTuple())))), layer_2 = ViewAxis(26:49, Axis(weight = ViewAxis(1:20, ShapedAxis((4, 5), NamedTuple())), bias = ViewAxis(21:24, ShapedAxis((4, 1), NamedTuple())))))}}})
    @ Zygote ~/.julia/packages/Zygote/oGI57/src/compiler/interface.jl:44
 [30] pullback
    @ ~/.julia/packages/Zygote/oGI57/src/compiler/interface.jl:42 [inlined]
 [10] _pullback(ctx::Zygote.Context{false}, f::var"#predict#6"{Vector{Float64}, Matrix{Float64}, ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, true, ComponentVector{Float64, Vector{Float64}, Tuple{Axis{(layer_1 = ViewAxis(1:25, Axis(weight = ViewAxis(1:20, ShapedAxis((5, 4), NamedTuple())), bias = ViewAxis(21:25, ShapedAxis((5, 1), NamedTuple())))), layer_2 = ViewAxis(26:49, Axis(weight = ViewAxis(1:20, ShapedAxis((4, 5), NamedTuple())), bias = ViewAxis(21:24, ShapedAxis((4, 1), NamedTuple())))))}}}, ODEFunction{true, SciMLBase.AutoSpecialize, var"#nn_dynamics#5"{Int64, var"#ude#4"{Lux.Chain{NamedTuple{(:layer_1, :layer_2), Tuple{Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, typeof(knownDynamics)}}, UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}}, args::ComponentVector{Float64, Vector{Float64}, Tuple{Axis{(layer_1 = ViewAxis(1:25, Axis(weight = ViewAxis(1:20, ShapedAxis((5, 4), NamedTuple())), bias = ViewAxis(21:25, ShapedAxis((5, 1), NamedTuple())))), layer_2 = ViewAxis(26:49, Axis(weight = ViewAxis(1:20, ShapedAxis((4, 5), NamedTuple())), bias = ViewAxis(21:24, ShapedAxis((4, 1), NamedTuple())))))}}})
 [31] gradient(f::Function, args::ComponentVector{Float64, Vector{Float64}, Tuple{Axis{(layer_1 = ViewAxis(1:25, Axis(weight = ViewAxis(1:20, ShapedAxis((5, 4), NamedTuple())), bias = ViewAxis(21:25, ShapedAxis((5, 1), NamedTuple())))), layer_2 = ViewAxis(26:49, Axis(weight = ViewAxis(1:20, ShapedAxis((4, 5), NamedTuple())), bias = ViewAxis(21:24, ShapedAxis((4, 1), NamedTuple())))))}}})
    @ Zygote ~/.julia/packages/Zygote/oGI57/src/compiler/interface.jl:96
    @ Zygote ~/.julia/packages/Zygote/oGI57/src/compiler/interface2.jl:0
 [11] _pullback
    @ ~/NODE_Community_Forecast/test.jl:39 [inlined]
 [12] _pullback(ctx::Zygote.Context{false}, f::var"#loss_function#7"{Matrix{Float64}, Float64}, args::ComponentVector{Float64, Vector{Float64}, Tuple{Axis{(layer_1 = ViewAxis(1:25, Axis(weight = ViewAxis(1:20, ShapedAxis((5, 4), NamedTuple())), bias = ViewAxis(21:25, ShapedAxis((5, 1), NamedTuple())))), layer_2 = ViewAxis(26:49, Axis(weight = ViewAxis(1:20, ShapedAxis((4, 5), NamedTuple())), bias = ViewAxis(21:24, ShapedAxis((4, 1), NamedTuple())))))}}})
    @ Zygote ~/.julia/packages/Zygote/oGI57/src/compiler/interface2.jl:0
 [32] (::Optimization.var"#157#166"{Optimization.var"#156#165"{OptimizationFunction{true, Optimization.AutoZygote, var"#3#9"{var"#loss_function#7"{Matrix{Float64}, Float64}}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, SciMLBase.NullParameters}})(::ComponentVector{Float64, Vector{Float64}, Tuple{Axis{(layer_1 = ViewAxis(1:25, Axis(weight = ViewAxis(1:20, ShapedAxis((5, 4), NamedTuple())), bias = ViewAxis(21:25, ShapedAxis((5, 1), NamedTuple())))), layer_2 = ViewAxis(26:49, Axis(weight = ViewAxis(1:20, ShapedAxis((4, 5), NamedTuple())), bias = ViewAxis(21:24, ShapedAxis((4, 1), NamedTuple())))))}}}, ::ComponentVector{Float64, Vector{Float64}, Tuple{Axis{(layer_1 = ViewAxis(1:25, Axis(weight = ViewAxis(1:20, ShapedAxis((5, 4), NamedTuple())), bias = ViewAxis(21:25, ShapedAxis((5, 1), NamedTuple())))), layer_2 = ViewAxis(26:49, Axis(weight = ViewAxis(1:20, ShapedAxis((4, 5), NamedTuple())), bias = ViewAxis(21:24, ShapedAxis((4, 1), NamedTuple())))))}}})
    @ Optimization ~/.julia/packages/Optimization/GEo8L/src/function/zygote.jl:32
 [33] macro expansion
    @ ~/.julia/packages/OptimizationOptimisers/FWIuf/src/OptimizationOptimisers.jl:31 [inlined]
 [34] macro expansion
    @ ~/.julia/packages/Optimization/GEo8L/src/utils.jl:37 [inlined]
 [13] _pullback
    @ ~/NODE_Community_Forecast/test.jl:58 [inlined]
 [14] _apply
    @ ./boot.jl:816 [inlined]
 [15] adjoint
    @ ~/.julia/packages/Zygote/oGI57/src/lib/lib.jl:203 [inlined]
 [16] _pullback
    @ ~/.julia/packages/ZygoteRules/4nXuu/src/adjoint.jl:66 [inlined]
 [17] _pullback
    @ ~/.julia/packages/SciMLBase/l4PVV/src/scimlfunctions.jl:3772 [inlined]
 [18] _pullback(::Zygote.Context{false}, ::OptimizationFunction{true, Optimization.AutoZygote, var"#3#9"{var"#loss_function#7"{Matrix{Float64}, Float64}}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, ::ComponentVector{Float64, Vector{Float64}, Tuple{Axis{(layer_1 = ViewAxis(1:25, Axis(weight = ViewAxis(1:20, ShapedAxis((5, 4), NamedTuple())), bias = ViewAxis(21:25, ShapedAxis((5, 1), NamedTuple())))), layer_2 = ViewAxis(26:49, Axis(weight = ViewAxis(1:20, ShapedAxis((4, 5), NamedTuple())), bias = ViewAxis(21:24, ShapedAxis((4, 1), NamedTuple())))))}}}, ::SciMLBase.NullParameters)
    @ Zygote ~/.julia/packages/Zygote/oGI57/src/compiler/interface2.jl:0
 [35] __solve(prob::OptimizationProblem{true, OptimizationFunction{true, Optimization.AutoZygote, var"#3#9"{var"#loss_function#7"{Matrix{Float64}, Float64}}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, ComponentVector{Float64, Vector{Float64}, Tuple{Axis{(layer_1 = ViewAxis(1:25, Axis(weight = ViewAxis(1:20, ShapedAxis((5, 4), NamedTuple())), bias = ViewAxis(21:25, ShapedAxis((5, 1), NamedTuple())))), layer_2 = ViewAxis(26:49, Axis(weight = ViewAxis(1:20, ShapedAxis((4, 5), NamedTuple())), bias = ViewAxis(21:24, ShapedAxis((4, 1), NamedTuple())))))}}}, SciMLBase.NullParameters, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, opt::Optimisers.Adam{Float32}, data::Base.Iterators.Cycle{Tuple{Optimization.NullData}}; maxiters::Int64, callback::Function, progress::Bool, save_best::Bool, kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ OptimizationOptimisers ~/.julia/packages/OptimizationOptimisers/FWIuf/src/OptimizationOptimisers.jl:30
 [36] #solve#595
    @ ~/.julia/packages/SciMLBase/l4PVV/src/solve.jl:86 [inlined]
 [19] _apply(::Function, ::Vararg{Any})
    @ Core ./boot.jl:816
 [20] adjoint
    @ ~/.julia/packages/Zygote/oGI57/src/lib/lib.jl:203 [inlined]
 [21] _pullback
    @ ~/.julia/packages/ZygoteRules/4nXuu/src/adjoint.jl:66 [inlined]
 [22] _pullback
    @ ~/.julia/packages/Optimization/GEo8L/src/function/zygote.jl:30 [inlined]
 [37] trainUDEModel(neuralNetwork::Lux.Chain{NamedTuple{(:layer_1, :layer_2), Tuple{Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, knownDynamics::typeof(knownDynamics), training_data::Matrix{Float64}; needed_ps::Vector{Float64}, p_true::Int64)
    @ Main ~/NODE_Community_Forecast/test.jl:61
 [38] top-level scope
    @ ~/NODE_Community_Forecast/test.jl:91
in expression starting at /home/jarroyoesquivel/NODE_Community_Forecast/test.jl:91
 [23] _pullback(ctx::Zygote.Context{false}, f::Optimization.var"#156#165"{OptimizationFunction{true, Optimization.AutoZygote, var"#3#9"{var"#loss_function#7"{Matrix{Float64}, Float64}}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, SciMLBase.NullParameters}, args::ComponentVector{Float64, Vector{Float64}, Tuple{Axis{(layer_1 = ViewAxis(1:25, Axis(weight = ViewAxis(1:20, ShapedAxis((5, 4), NamedTuple())), bias = ViewAxis(21:25, ShapedAxis((5, 1), NamedTuple())))), layer_2 = ViewAxis(26:49, Axis(weight = ViewAxis(1:20, ShapedAxis((4, 5), NamedTuple())), bias = ViewAxis(21:24, ShapedAxis((4, 1), NamedTuple())))))}}})
    @ Zygote ~/.julia/packages/Zygote/oGI57/src/compiler/interface2.jl:0
 [24] _apply(::Function, ::Vararg{Any})
    @ Core ./boot.jl:816
 [25] adjoint
    @ ~/.julia/packages/Zygote/oGI57/src/lib/lib.jl:203 [inlined]
 [26] _pullback
    @ ~/.julia/packages/ZygoteRules/4nXuu/src/adjoint.jl:66 [inlined]
 [27] _pullback
    @ ~/.julia/packages/Optimization/GEo8L/src/function/zygote.jl:34 [inlined]
 [28] _pullback(ctx::Zygote.Context{false}, f::Optimization.var"#158#167"{Tuple{}, Optimization.var"#156#165"{OptimizationFunction{true, Optimization.AutoZygote, var"#3#9"{var"#loss_function#7"{Matrix{Float64}, Float64}}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, SciMLBase.NullParameters}}, args::ComponentVector{Float64, Vector{Float64}, Tuple{Axis{(layer_1 = ViewAxis(1:25, Axis(weight = ViewAxis(1:20, ShapedAxis((5, 4), NamedTuple())), bias = ViewAxis(21:25, ShapedAxis((5, 1), NamedTuple())))), layer_2 = ViewAxis(26:49, Axis(weight = ViewAxis(1:20, ShapedAxis((4, 5), NamedTuple())), bias = ViewAxis(21:24, ShapedAxis((4, 1), NamedTuple())))))}}})
    @ Zygote ~/.julia/packages/Zygote/oGI57/src/compiler/interface2.jl:0
 [29] pullback(f::Function, cx::Zygote.Context{false}, args::ComponentVector{Float64, Vector{Float64}, Tuple{Axis{(layer_1 = ViewAxis(1:25, Axis(weight = ViewAxis(1:20, ShapedAxis((5, 4), NamedTuple())), bias = ViewAxis(21:25, ShapedAxis((5, 1), NamedTuple())))), layer_2 = ViewAxis(26:49, Axis(weight = ViewAxis(1:20, ShapedAxis((4, 5), NamedTuple())), bias = ViewAxis(21:24, ShapedAxis((4, 1), NamedTuple())))))}}})
    @ Zygote ~/.julia/packages/Zygote/oGI57/src/compiler/interface.jl:44
 [30] pullback
    @ ~/.julia/packages/Zygote/oGI57/src/compiler/interface.jl:42 [inlined]
 [31] gradient(f::Function, args::ComponentVector{Float64, Vector{Float64}, Tuple{Axis{(layer_1 = ViewAxis(1:25, Axis(weight = ViewAxis(1:20, ShapedAxis((5, 4), NamedTuple())), bias = ViewAxis(21:25, ShapedAxis((5, 1), NamedTuple())))), layer_2 = ViewAxis(26:49, Axis(weight = ViewAxis(1:20, ShapedAxis((4, 5), NamedTuple())), bias = ViewAxis(21:24, ShapedAxis((4, 1), NamedTuple())))))}}})
    @ Zygote ~/.julia/packages/Zygote/oGI57/src/compiler/interface.jl:96
 [32] (::Optimization.var"#157#166"{Optimization.var"#156#165"{OptimizationFunction{true, Optimization.AutoZygote, var"#3#9"{var"#loss_function#7"{Matrix{Float64}, Float64}}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, SciMLBase.NullParameters}})(::ComponentVector{Float64, Vector{Float64}, Tuple{Axis{(layer_1 = ViewAxis(1:25, Axis(weight = ViewAxis(1:20, ShapedAxis((5, 4), NamedTuple())), bias = ViewAxis(21:25, ShapedAxis((5, 1), NamedTuple())))), layer_2 = ViewAxis(26:49, Axis(weight = ViewAxis(1:20, ShapedAxis((4, 5), NamedTuple())), bias = ViewAxis(21:24, ShapedAxis((4, 1), NamedTuple())))))}}}, ::ComponentVector{Float64, Vector{Float64}, Tuple{Axis{(layer_1 = ViewAxis(1:25, Axis(weight = ViewAxis(1:20, ShapedAxis((5, 4), NamedTuple())), bias = ViewAxis(21:25, ShapedAxis((5, 1), NamedTuple())))), layer_2 = ViewAxis(26:49, Axis(weight = ViewAxis(1:20, ShapedAxis((4, 5), NamedTuple())), bias = ViewAxis(21:24, ShapedAxis((4, 1), NamedTuple())))))}}})
    @ Optimization ~/.julia/packages/Optimization/GEo8L/src/function/zygote.jl:32
 [33] macro expansion
    @ ~/.julia/packages/OptimizationOptimisers/FWIuf/src/OptimizationOptimisers.jl:31 [inlined]
 [34] macro expansion
    @ ~/.julia/packages/Optimization/GEo8L/src/utils.jl:37 [inlined]
 [35] __solve(prob::OptimizationProblem{true, OptimizationFunction{true, Optimization.AutoZygote, var"#3#9"{var"#loss_function#7"{Matrix{Float64}, Float64}}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, ComponentVector{Float64, Vector{Float64}, Tuple{Axis{(layer_1 = ViewAxis(1:25, Axis(weight = ViewAxis(1:20, ShapedAxis((5, 4), NamedTuple())), bias = ViewAxis(21:25, ShapedAxis((5, 1), NamedTuple())))), layer_2 = ViewAxis(26:49, Axis(weight = ViewAxis(1:20, ShapedAxis((4, 5), NamedTuple())), bias = ViewAxis(21:24, ShapedAxis((4, 1), NamedTuple())))))}}}, SciMLBase.NullParameters, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, opt::Optimisers.Adam{Float32}, data::Base.Iterators.Cycle{Tuple{Optimization.NullData}}; maxiters::Int64, callback::Function, progress::Bool, save_best::Bool, kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ OptimizationOptimisers ~/.julia/packages/OptimizationOptimisers/FWIuf/src/OptimizationOptimisers.jl:30
 [36] #solve#595
    @ ~/.julia/packages/SciMLBase/l4PVV/src/solve.jl:86 [inlined]
 [37] trainUDEModel(neuralNetwork::Lux.Chain{NamedTuple{(:layer_1, :layer_2), Tuple{Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, knownDynamics::typeof(knownDynamics), training_data::Matrix{Float64}; needed_ps::Vector{Float64}, p_true::Int64)
    @ Main ~/NODE_Community_Forecast/test.jl:61
 [38] top-level scope
    @ ~/NODE_Community_Forecast/test.jl:91
in expression starting at /home/jarroyoesquivel/NODE_Community_Forecast/test.jl:91

My current Pkg.status():

⌅ [052768ef] CUDA v3.12.2
⌅ [b0b7db55] ComponentArrays v0.13.8
⌅ [aae7a2af] DiffEqFlux v2.4.0
⌃ [b2108857] Lux v0.4.37
⌃ [7f7a1694] Optimization v3.12.1
⌃ [36348300] OptimizationOptimJL v0.1.5
⌃ [42dfb2eb] OptimizationOptimisers v0.1.2
⌅ [1dea7af3] OrdinaryDiffEq v6.49.1
⌃ [e88e6eb3] Zygote v0.6.58
  [37e2e46d] LinearAlgebra
  [9a3f8284] Random
ToucheSir commented 11 months ago

Hmm, remake and the rest of the SciML code there are doing a lot under the hood that we don't know about. Have you tried raising this with them first?