ODINN-SciML / ODINN.jl

Global glacier model using Universal Differential Equations for climate-glacier interactions
MIT License

Investigate fixing AD issues #151

Open ChrisRackauckas opened 4 months ago

ChrisRackauckas commented 4 months ago

AutoZygote is definitely required on the outside instead of AutoReverseDiff. Investigating what's going on with Enzyme internally.

using PyCall
certifi = PyCall.pyimport("certifi")
ENV["SSL_CERT_FILE"] = certifi.where()

using ODINN

working_dir = joinpath(homedir(), "OGGM/ODINN_tests")
MB = true
fast = true
atol = 2.0

params = Parameters(OGGM = OGGMparameters(working_dir=working_dir,
                                              multiprocessing=true),
                        simulation = SimulationParameters(working_dir=working_dir,
                                                        use_MB=MB,
                                                        velocities=true,
                                                        tspan=(2010.0, 2015.0),
                                                        multiprocessing=false,
                                                        workers=5,
                                                        test_mode=true),
                        hyper = Hyperparameters(batch_size=4,
                                                epochs=4,
                                                optimizer=ODINN.ADAM(0.01)),
                        UDE = UDEparameters(target = "A",
                                            sensealg = ODINN.GaussAdjoint(autojacvec=ODINN.EnzymeVJP()))
    )

rgi_ids = ["RGI60-11.03638"]

model = Model(iceflow = SIA2Dmodel(params),
                mass_balance = TImodel1(params; DDF=6.0/1000.0, acc_factor=1.2/1000.0),
                machine_learning = NN(params))
glaciers = initialize_glaciers(rgi_ids, params)
functional_inversion = FunctionalInversion(model, glaciers, params)

@time run!(functional_inversion)

ChrisRackauckas commented 4 months ago

https://github.com/SciML/SciMLSensitivity.jl/pull/1060 is required to use GaussAdjoint for this.

ChrisRackauckas commented 4 months ago

Requires https://github.com/ODINN-SciML/Sleipnir.jl/pull/55, https://github.com/ODINN-SciML/Huginn.jl/pull/53, and Enzyme#main (which in turn requires https://github.com/SciML/OptimizationBase.jl/pull/53). With that the callback differentiates, but it hits:

ERROR: AssertionError: Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{2}, Nothing, typeof(/), Tuple{Matrix{Float64}, Float64}} has mixed internal activity types. See https://enzyme.mit.edu/julia/stable/faq/#Mixed-activity for more information

which is an Enzyme bug in broadcast against a scalar. So we're close now.

ChrisRackauckas commented 4 months ago

Upstreaming issues found here:

The last issue is the real showstopper for now: it's not clear to me that there is a good workaround, and it seems to me to just be a straight Enzyme bug to fix. The other two look a bit bespoke to PyCall and TimerOutputs. The latter is easy to work around; the former is weird because the workaround currently doesn't work, seemingly due to some internals of PyCall.

JordiBolibar commented 4 months ago

OK great, thanks for pushing this through @ChrisRackauckas! 👍🏻 Any idea how far off the solution to the last Enzyme issue is?

What is the issue with PyCall? As far as I understood, we were bypassing the differentiation of all Python code for now, which is only present in the callback.

wsmoses commented 4 months ago

As discussed in the thread, that issue really isn't an Enzyme issue in its own right. Rather, a type instability exists around the broadcast, which makes it illegal for certain code to be represented properly in Julia (if the type instability weren't there, Enzyme could write the derivative update totally fine).

Recopying the backtrace here, since I think the better solution would just be fixing the type instability.

ERROR: AssertionError: Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{2}, Nothing, typeof(/), Tuple{Matrix{Float64}, Float64}} has mixed internal activity types. See https://enzyme.mit.edu/julia/stable/faq/#Mixed-activity for more information
Stacktrace:
  [1] active_reg
    @ ~/.julia/packages/Enzyme/F71IJ/src/compiler.jl:525 [inlined]
  [2] active_reg
    @ ~/.julia/packages/Enzyme/F71IJ/src/compiler.jl:516 [inlined]
  [3] runtime_generic_augfwd(activity::Type{…}, width::Val{…}, ModifiedBetween::Val{…}, RT::Val{…}, f::typeof(Base.Broadcast.materialize), df::Nothing, primal_1::Base.Broadcast.Broadcasted{…}, shadow_1_1::Base.Broadcast.Broadcasted{…})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/F71IJ/src/rules/jitrules.jl:66
  [4] #SIA2D#9
    @ ~/.julia/dev/Huginn/src/models/iceflow/SIA2D/SIA2D_utils.jl:120
  [5] SIA2D
    @ ~/.julia/dev/Huginn/src/models/iceflow/SIA2D/SIA2D_utils.jl:89 [inlined]
  [6] SIA2D_UDE
    @ ~/.julia/dev/ODINN/src/simulations/functional_inversions/functional_inversion_utils.jl:231 [inlined]
  [7] SIA2D_UDE_closure
    @ ~/.julia/dev/ODINN/src/simulations/functional_inversions/functional_inversion_utils.jl:182 [inlined]
  [8] ODEFunction
    @ ~/.julia/packages/SciMLBase/JUp1I/src/scimlfunctions.jl:2296 [inlined]
  [9] #138
    @ ~/.julia/dev/SciMLSensitivity/src/adjoint_common.jl:450 [inlined]
 [10] diffejulia__138_21210_inner_1wrap
    @ ~/.julia/dev/SciMLSensitivity/src/adjoint_common.jl:0
 [11] macro expansion
    @ ~/.julia/packages/Enzyme/F71IJ/src/compiler.jl:5916 [inlined]
 [12] enzyme_call
    @ ~/.julia/packages/Enzyme/F71IJ/src/compiler.jl:5566 [inlined]
 [13] CombinedAdjointThunk
    @ ~/.julia/packages/Enzyme/F71IJ/src/compiler.jl:5443 [inlined]
 [14] autodiff
    @ ~/.julia/packages/Enzyme/F71IJ/src/Enzyme.jl:291 [inlined]
 [15] _vecjacobian!(dλ::Vector{…}, y::Matrix{…}, λ::Vector{…}, p::ComponentArrays.ComponentVector{…}, t::Float64, S::SciMLSensitivity.ODEGaussAdjointSensitivityFunction{…}, isautojacvec::SciMLSensitivity.EnzymeVJP, dgrad::Nothing, dy::Nothing, W::Nothing)
    @ SciMLSensitivity ~/.julia/dev/SciMLSensitivity/src/derivative_wrappers.jl:714
 [16] #vecjacobian!#18
    @ ~/.julia/dev/SciMLSensitivity/src/derivative_wrappers.jl:231 [inlined]
 [17] vecjacobian!
    @ ~/.julia/dev/SciMLSensitivity/src/derivative_wrappers.jl:228 [inlined]
 [18] (::SciMLSensitivity.ODEGaussAdjointSensitivityFunction{…})(du::Vector{…}, u::Vector{…}, p::ComponentArrays.ComponentVector{…}, t::Float64)
    @ SciMLSensitivity ~/.julia/dev/SciMLSensitivity/src/gauss_adjoint.jl:102
 [19] ODEFunction
    @ ~/.julia/packages/SciMLBase/JUp1I/src/scimlfunctions.jl:2296 [inlined]
 [20] ode_determine_initdt(u0::Vector{…}, t::Float64, tdir::Float64, dtmax::Float64, abstol::Float64, reltol::Float64, internalnorm::typeof(DiffEqBase.ODE_DEFAULT_NORM), prob::SciMLBase.ODEProblem{…}, integrator::OrdinaryDiffEq.ODEIntegrator{…})
    @ OrdinaryDiffEq ~/.julia/packages/OrdinaryDiffEq/tAI61/src/initdt.jl:53
 [21] auto_dt_reset!(integrator::OrdinaryDiffEq.ODEIntegrator)
    @ OrdinaryDiffEq ~/.julia/packages/OrdinaryDiffEq/tAI61/src/integrators/integrator_interface.jl:474 [inlined]
 [22] handle_dt!(integrator::OrdinaryDiffEq.ODEIntegrator{…})
    @ OrdinaryDiffEq ~/.julia/packages/OrdinaryDiffEq/tAI61/src/solve.jl:580
 [23] __init(prob::SciMLBase.ODEProblem{…}, alg::OrdinaryDiffEq.RDPK3Sp35{…}, timeseries_init::Tuple{}, ts_init::Tuple{}, ks_init::Tuple{}, recompile::Type{…}; saveat::Vector{…}, tstops::Vector{…}, d_discontinuities::Tuple{}, save_idxs::Nothing, save_everystep::Bool, save_on::Bool, save_start::Bool, save_end::Bool, callback::SciMLBase.CallbackSet{…}, dense::Bool, calck::Bool, dt::Float64, dtmin::Float64, dtmax::Float64, force_dtmin::Bool, adaptive::Bool, gamma::Rational{…}, abstol::Float64, reltol::Float64, qmin::Rational{…}, qmax::Int64, qsteady_min::Int64, qsteady_max::Int64, beta1::Nothing, beta2::Nothing, qoldinit::Rational{…}, controller::Nothing, fullnormalize::Bool, failfactor::Int64, maxiters::Int64, internalnorm::typeof(DiffEqBase.ODE_DEFAULT_NORM), internalopnorm::typeof(LinearAlgebra.opnorm), isoutofdomain::typeof(DiffEqBase.ODE_DEFAULT_ISOUTOFDOMAIN), unstable_check::typeof(DiffEqBase.ODE_DEFAULT_UNSTABLE_CHECK), verbose::Bool, timeseries_errors::Bool, dense_errors::Bool, advance_to_tstop::Bool, stop_at_next_tstop::Bool, initialize_save::Bool, progress::Bool, progress_steps::Int64, progress_name::String, progress_message::typeof(DiffEqBase.ODE_DEFAULT_PROG_MESSAGE), progress_id::Symbol, userdata::Nothing, allow_extrapolation::Bool, initialize_integrator::Bool, alias_u0::Bool, alias_du0::Bool, initializealg::OrdinaryDiffEq.DefaultInit, kwargs::@Kwargs{})
    @ OrdinaryDiffEq ~/.julia/packages/OrdinaryDiffEq/tAI61/src/solve.jl:533
 [24] __init(prob::Union{…}, alg::Union{…}, timeseries_init::Any, ts_init::Any, ks_init::Any, recompile::Type{…}) where recompile_flag (repeats 5 times)
    @ OrdinaryDiffEq ~/.julia/packages/OrdinaryDiffEq/tAI61/src/solve.jl:11 [inlined]
 [25] #__solve#805
    @ ~/.julia/packages/OrdinaryDiffEq/tAI61/src/solve.jl:6 [inlined]
 [26] __solve
    @ ~/.julia/packages/OrdinaryDiffEq/tAI61/src/solve.jl:1 [inlined]
 [27] solve_call(_prob::SciMLBase.ODEProblem{…}, args::OrdinaryDiffEq.RDPK3Sp35{…}; merge_callbacks::Bool, kwargshandle::Nothing, kwargs::@Kwargs{…})
    @ DiffEqBase ~/.julia/packages/DiffEqBase/PBhFc/src/solve.jl:612
 [28] solve_call
    @ DiffEqBase ~/.julia/packages/DiffEqBase/PBhFc/src/solve.jl:569 [inlined]
 [29] #solve_up#53
    @ DiffEqBase ~/.julia/packages/DiffEqBase/PBhFc/src/solve.jl:1080 [inlined]
 [30] solve_up
    @ DiffEqBase ~/.julia/packages/DiffEqBase/PBhFc/src/solve.jl:1066 [inlined]
 [31] #solve#51
    @ DiffEqBase ~/.julia/packages/DiffEqBase/PBhFc/src/solve.jl:1003 [inlined]
 [32] _adjoint_sensitivities(sol::SciMLBase.ODESolution{…}, sensealg::SciMLSensitivity.GaussAdjoint{…}, alg::OrdinaryDiffEq.RDPK3Sp35{…}; t::Vector{…}, dgdu_discrete::Function, dgdp_discrete::Nothing, dgdu_continuous::Nothing, dgdp_continuous::Nothing, g::Nothing, abstol::Float64, reltol::Float64, checkpoints::Vector{…}, corfunc_analytical::Bool, callback::SciMLBase.CallbackSet{…}, kwargs::@Kwargs{…})
    @ SciMLSensitivity ~/.julia/dev/SciMLSensitivity/src/gauss_adjoint.jl:540
 [33] _adjoint_sensitivities
    @ SciMLSensitivity ~/.julia/dev/SciMLSensitivity/src/gauss_adjoint.jl:507 [inlined]
 [34] #adjoint_sensitivities#63
    @ SciMLSensitivity ~/.julia/dev/SciMLSensitivity/src/sensitivity_interface.jl:383 [inlined]
 [35] (::SciMLSensitivity.var"#adjoint_sensitivity_backpass#314"{…})(Δ::SciMLBase.ODESolution{…})
    @ SciMLSensitivity ~/.julia/dev/SciMLSensitivity/src/concrete_solve.jl:556
 [36] ZBack
    @ ~/.julia/packages/Zygote/nsBv0/src/compiler/chainrules.jl:211 [inlined]
 [37] (::Zygote.var"#kw_zpullback#53"{…})(dy::SciMLBase.ODESolution{…})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/chainrules.jl:237
 [38] #291
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:206 [inlined]
 [39] (::Zygote.var"#2169#back#293"{…})(Δ::SciMLBase.ODESolution{…})
    @ Zygote ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:72
 [40] #solve#51
    @ ~/.julia/packages/DiffEqBase/PBhFc/src/solve.jl:1003 [inlined]
 [41] (::Zygote.Pullback{…})(Δ::SciMLBase.ODESolution{…})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [42] #291
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:206 [inlined]
 [43] (::Zygote.var"#2169#back#293"{…})(Δ::SciMLBase.ODESolution{…})
    @ Zygote ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:72
 [44] solve
    @ ~/.julia/packages/DiffEqBase/PBhFc/src/solve.jl:993 [inlined]
 [45] (::Zygote.Pullback{…})(Δ::SciMLBase.ODESolution{…})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [46] #simulate_iceflow_UDE!#29
    @ ~/.julia/dev/ODINN/src/simulations/functional_inversions/functional_inversion_utils.jl:187 [inlined]
 [47] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [48] simulate_iceflow_UDE!
    @ ~/.julia/dev/ODINN/src/simulations/functional_inversions/functional_inversion_utils.jl:171 [inlined]
 [49] (::Zygote.Pullback{Tuple{…}, Any})(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [50] batch_iceflow_UDE
    @ ~/.julia/dev/ODINN/src/simulations/functional_inversions/functional_inversion_utils.jl:146 [inlined]
 [51] (::Zygote.Pullback{…})(Δ::@NamedTuple{…})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [52] #25
    @ ~/.julia/dev/ODINN/src/simulations/functional_inversions/functional_inversion_utils.jl:114 [inlined]
 [53] (::Zygote.Pullback{…})(Δ::@NamedTuple{…})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [54] #680
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/lib/array.jl:201 [inlined]
 [55] #235
    @ Distributed ~/.julia/juliaup/julia-1.10.0+0.aarch64.apple.darwin14/share/julia/stdlib/v1.10/Distributed/src/pmap.jl:157 [inlined]
 [56] (::Base.var"#1023#1028"{Distributed.var"#235#236"{…}})(r::Base.RefValue{Any}, args::Tuple{Tuple{…}})
    @ Base ./asyncmap.jl:94
Stacktrace:
  [1] (::Base.var"#1033#1035")(x::Task)
    @ Base ./asyncmap.jl:171
  [2] foreach(f::Base.var"#1033#1035", itr::Vector{Any})
    @ Base ./abstractarray.jl:3094
  [3] maptwice(wrapped_f::Function, chnl::Channel{Any}, worker_tasks::Vector{Any}, c::Base.Iterators.Zip{Tuple{…}})
    @ Base ./asyncmap.jl:171
  [4] wrap_n_exec_twice
    @ ./asyncmap.jl:147 [inlined]
  [5] #async_usemap#1018
    @ ./asyncmap.jl:97 [inlined]
  [6] kwcall(::NamedTuple, ::typeof(Base.async_usemap), f::Any, c::Vararg{Any})
    @ Base ./asyncmap.jl:78 [inlined]
  [7] #asyncmap#1017
    @ ./asyncmap.jl:75 [inlined]
  [8] asyncmap
    @ ./asyncmap.jl:74 [inlined]
  [9] pmap(f::Function, p::Distributed.WorkerPool, c::Base.Iterators.Zip{…}; distributed::Bool, batch_size::Int64, on_error::Nothing, retry_delays::Vector{…}, retry_check::Nothing)
    @ Distributed ~/.julia/juliaup/julia-1.10.0+0.aarch64.apple.darwin14/share/julia/stdlib/v1.10/Distributed/src/pmap.jl:126
 [10] pmap(f::Function, p::Distributed.WorkerPool, c::Base.Iterators.Zip{Tuple{Vector{Tuple{…}}, Vector{@NamedTuple{…}}}})
    @ Distributed ~/.julia/juliaup/julia-1.10.0+0.aarch64.apple.darwin14/share/julia/stdlib/v1.10/Distributed/src/pmap.jl:99
 [11] pmap(f::Function, c::Base.Iterators.Zip{Tuple{Vector{Tuple{…}}, Vector{@NamedTuple{…}}}}; kwargs::@Kwargs{})
    @ Distributed ~/.julia/juliaup/julia-1.10.0+0.aarch64.apple.darwin14/share/julia/stdlib/v1.10/Distributed/src/pmap.jl:156
 [12] pmap
    @ Distributed ~/.julia/juliaup/julia-1.10.0+0.aarch64.apple.darwin14/share/julia/stdlib/v1.10/Distributed/src/pmap.jl:156 [inlined]
 [13] pmap(f::Function, c1::Vector{Tuple{…}}, c::Vector{@NamedTuple{…}})
    @ Distributed ~/.julia/juliaup/julia-1.10.0+0.aarch64.apple.darwin14/share/julia/stdlib/v1.10/Distributed/src/pmap.jl:157
 [14] (::Zygote.var"#map_back#682"{ODINN.var"#25#26"{…}, 1, Tuple{…}, Tuple{…}, Vector{…}})(Δ::Vector{@NamedTuple{…}})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/lib/array.jl:201
 [15] (::Zygote.var"#2861#back#688"{Zygote.var"#map_back#682"{…}})(Δ::Vector{@NamedTuple{…}})
    @ Zygote ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:72
 [16] predict_iceflow!
    @ ~/.julia/dev/ODINN/src/simulations/functional_inversions/functional_inversion_utils.jl:114 [inlined]
 [17] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [18] loss_iceflow
    @ ~/.julia/dev/ODINN/src/simulations/functional_inversions/functional_inversion_utils.jl:58 [inlined]
 [19] (::Zygote.Pullback{Tuple{…}, Any})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [20] #22
    @ ~/.julia/dev/ODINN/src/simulations/functional_inversions/functional_inversion_utils.jl:31 [inlined]
 [21] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [22] #291
    @ ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:206 [inlined]
 [23] #2169#back
    @ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:72 [inlined]
 [24] OptimizationFunction
    @ ~/.julia/packages/SciMLBase/JUp1I/src/scimlfunctions.jl:3762 [inlined]
 [25] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [26] #291
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:206 [inlined]
 [27] (::Zygote.var"#2169#back#293"{Zygote.var"#291#292"{Tuple{…}, Zygote.Pullback{…}}})(Δ::Float64)
    @ Zygote ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:72
 [28] #37
    @ ~/.julia/dev/OptimizationBase/ext/OptimizationZygoteExt.jl:90 [inlined]
 [29] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [30] #291
    @ ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:206 [inlined]
 [31] #2169#back
    @ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:72 [inlined]
 [32] #39
    @ ~/.julia/dev/OptimizationBase/ext/OptimizationZygoteExt.jl:93 [inlined]
 [33] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [34] (::Zygote.var"#75#76"{Zygote.Pullback{Tuple{…}, Tuple{…}}})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface.jl:91
 [35] gradient(f::Function, args::ComponentArrays.ComponentVector{Float64, Vector{Float64}, Tuple{ComponentArrays.Axis{…}}})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface.jl:148
 [36] (::OptimizationZygoteExt.var"#38#56"{…})(::ComponentArrays.ComponentVector{…}, ::ComponentArrays.ComponentVector{…}, ::Vector{…}, ::Vararg{…})
    @ OptimizationZygoteExt ~/.julia/dev/OptimizationBase/ext/OptimizationZygoteExt.jl:93
 [37] macro expansion
    @ ~/.julia/packages/OptimizationOptimisers/AOkbT/src/OptimizationOptimisers.jl:68 [inlined]
 [38] macro expansion
    @ ~/.julia/packages/ProgressLogging/6KXlp/src/ProgressLogging.jl:328 [inlined]
 [39] (::OptimizationOptimisers.var"#12#13"{OptimizationBase.OptimizationCache{…}, ComponentArrays.ComponentVector{…}})()
    @ OptimizationOptimisers ~/.julia/packages/Optimization/jWtfU/src/utils.jl:29
 [40] with_logstate(f::Function, logstate::Any)
    @ Base.CoreLogging ./logging.jl:515
 [41] with_logger
    @ Base.CoreLogging ./logging.jl:627 [inlined]
 [42] maybe_with_logger(f::OptimizationOptimisers.var"#12#13"{…}, logger::LoggingExtras.TeeLogger{…})
    @ Optimization ~/.julia/packages/Optimization/jWtfU/src/utils.jl:7
 [43] macro expansion
    @ ~/.julia/packages/Optimization/jWtfU/src/utils.jl:28 [inlined]
 [44] __solve(cache::OptimizationBase.OptimizationCache{…})
    @ OptimizationOptimisers ~/.julia/packages/OptimizationOptimisers/AOkbT/src/OptimizationOptimisers.jl:66
 [45] solve!(cache::OptimizationBase.OptimizationCache{…})
    @ SciMLBase ~/.julia/packages/SciMLBase/JUp1I/src/solve.jl:188
 [46] solve(prob::SciMLBase.OptimizationProblem{…}, alg::Optimisers.Adam, args::IterTools.NCycle{…}; kwargs::@Kwargs{…})
    @ SciMLBase ~/.julia/packages/SciMLBase/JUp1I/src/solve.jl:96
 [47] train_UDE!(simulation::FunctionalInversion)
    @ ODINN ~/.julia/dev/ODINN/src/simulations/functional_inversions/functional_inversion_utils.jl:43
 [48] run!(simulation::FunctionalInversion)
    @ ODINN ~/.julia/dev/ODINN/src/simulations/functional_inversions/functional_inversion_utils.jl:11
 [49] top-level scope
    @ ./timing.jl:279
 [50] top-level scope
    @ none:1
Some type information was truncated. Use `show(err)` to see complete types.
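
To make the type-instability point concrete, here is a minimal sketch of how such an instability can arise around a broadcast against a scalar, and how to detect it. The struct and function names are hypothetical and are not the actual Huginn code; whether the instability at SIA2D_utils.jl:120 comes from an untyped field like this is an assumption, since the trace only gives the location.

```julia
# Hypothetical parameter containers for illustration; not ODINN/Huginn types.
struct LooseParams
    dx            # untyped field, so the compiler only knows it is `Any`
end

struct TightParams{T<:AbstractFloat}
    dx::T         # concrete as soon as T is known
end

# Same shape as the failing expression: a Matrix{Float64} broadcast against a scalar.
scale(H::Matrix{Float64}, p) = H ./ p.dx

# Ask inference what `scale` returns for each container type.
loose_T = only(Base.return_types(scale, (Matrix{Float64}, LooseParams)))
tight_T = only(Base.return_types(scale, (Matrix{Float64}, TightParams{Float64})))

println("untyped field: ", loose_T, " (concrete: ", isconcretetype(loose_T), ")")
println("typed field:   ", tight_T, " (concrete: ", isconcretetype(tight_T), ")")
```

With the untyped field, inference cannot pin down the result of the broadcast, which is the situation where a runtime-dispatched `materialize` call shows up as in frame [3] above. In the REPL, `@code_warntype` on the real call site gives the same kind of information.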

wsmoses commented 4 months ago

@ChrisRackauckas @JordiBolibar where is the inactive function you have for PyCall? Indeed, I would've expected an inactive marking to force that to be fine, so if you have an MWE of that part I can fix it.

We could also teach PyCall object construction how to properly deal with allocations in an Enzyme rule, but the fact that you say there's an inactive marking that's not triggering seems like the bigger issue.

ChrisRackauckas commented 4 months ago

Another error I ran into was Tullio support. @wsmoses is that known?

wsmoses commented 4 months ago

All Tullio examples that folks have sent us are functional, last we checked. Open an issue?

JordiBolibar commented 4 months ago

As I mentioned in the Huginn PR, Tullio should be easily avoidable here, since we only use it in the out-of-place version of the function that was needed for Zygote and ReverseDiff. Migrating to the in-place version of the function should avoid that.
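
For readers unfamiliar with the distinction, here is a minimal sketch of an out-of-place and an in-place version of the same operation. The function names and the averaging stencil are hypothetical stand-ins written with plain broadcasting, not the actual Huginn or Tullio code.

```julia
# Hypothetical helpers for illustration; not the actual Huginn.jl code.

# Out-of-place: allocates and returns a new array. There is no mutation,
# which is the style the out-of-place SIA2D was written in for Zygote.
avg_x(H::Matrix{Float64}) = 0.5 .* (H[1:end-1, :] .+ H[2:end, :])

# In-place: writes the result into a preallocated buffer, following the
# `!` naming convention used by SIA2D!.
function avg_x!(out::Matrix{Float64}, H::Matrix{Float64})
    @views out .= 0.5 .* (H[1:end-1, :] .+ H[2:end, :])
    return out
end

H = [1.0 2.0; 3.0 4.0; 5.0 6.0]
out = zeros(size(H, 1) - 1, size(H, 2))
avg_x!(out, H)

# Both versions compute the same values; only the memory handling differs.
@assert out == avg_x(H)
println(out)
```

The in-place form reuses `out` on every call, so it avoids per-step allocations inside the ODE right-hand side, and it relies on array mutation, which Zygote does not support but Enzyme is designed to handle.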

ChrisRackauckas commented 3 months ago

Training iceflow UDE...
┌ Warning: Number of observations less than batch-size, decreasing the batch-size to 1
└ @ MLUtils ~/.julia/packages/MLUtils/LmmaQ/src/batchview.jl:95
┌ Warning: Number of observations less than batch-size, decreasing the batch-size to 1
└ @ MLUtils ~/.julia/packages/MLUtils/LmmaQ/src/batchview.jl:95
Before solving ODE problem
ODE problem solved for 1
over here for 1
simulation finished for 1
Batch 1 finished!
All batches finished
Loss computed: 13.506396213311314
┌ Warning: Automatic dt set the starting dt as NaN, causing instability. Exiting.
└ @ OrdinaryDiffEq ~/.julia/packages/OrdinaryDiffEq/F67Rp/src/solve.jl:591
┌ Warning: NaN dt detected. Likely a NaN value in the state, parameters, or derivative value caused this outcome.
└ @ SciMLBase ~/.julia/packages/SciMLBase/sakPO/src/integrator_interface.jl:593
ERROR: DimensionMismatch: array could not be broadcast to match destination
Stacktrace:
  [1] check_broadcast_shape
    @ ./broadcast.jl:579 [inlined]
  [2] check_broadcast_axes
    @ ./broadcast.jl:582 [inlined]
  [3] instantiate
    @ ./broadcast.jl:309 [inlined]
  [4] materialize!
    @ ./broadcast.jl:914 [inlined]
  [5] materialize!
    @ ./broadcast.jl:911 [inlined]
  [6] vec_pjac!(out::ComponentArrays.ComponentVector{…}, λ::Vector{…}, y::Matrix{…}, t::Float64, S::SciMLSensitivity.AdjointSensitivityIntegrand{…})
    @ SciMLSensitivity ~/.julia/dev/SciMLSensitivity/src/quadrature_adjoint.jl:288
  [7] AdjointSensitivityIntegrand
    @ ~/.julia/dev/SciMLSensitivity/src/quadrature_adjoint.jl:309 [inlined]
  [8] (::SciMLSensitivity.AdjointSensitivityIntegrand{…})(t::Float64)
    @ SciMLSensitivity ~/.julia/dev/SciMLSensitivity/src/quadrature_adjoint.jl:321
  [9] evalrule(f::SciMLSensitivity.AdjointSensitivityIntegrand{…}, a::Float64, b::Float64, x::Vector{…}, w::Vector{…}, gw::Vector{…}, nrm::typeof(LinearAlgebra.norm))
    @ QuadGK ~/.julia/packages/QuadGK/OtnWt/src/evalrule.jl:0
 [10] #6
    @ ~/.julia/packages/QuadGK/OtnWt/src/adapt.jl:15 [inlined]
 [11] ntuple
    @ ./ntuple.jl:48 [inlined]
 [12] do_quadgk(f::SciMLSensitivity.AdjointSensitivityIntegrand{…}, s::Tuple{…}, n::Int64, atol::Float64, rtol::Float64, maxevals::Int64, nrm::typeof(LinearAlgebra.norm), segbuf::Nothing)
    @ QuadGK ~/.julia/packages/QuadGK/OtnWt/src/adapt.jl:13
 [13] #50
    @ ~/.julia/packages/QuadGK/OtnWt/src/adapt.jl:253 [inlined]
 [14] handle_infinities(workfunc::QuadGK.var"#50#51"{…}, f::SciMLSensitivity.AdjointSensitivityIntegrand{…}, s::Tuple{…})
    @ QuadGK ~/.julia/packages/QuadGK/OtnWt/src/adapt.jl:145
 [15] quadgk(::SciMLSensitivity.AdjointSensitivityIntegrand{…}, ::Float64, ::Vararg{…}; atol::Float64, rtol::Float64, maxevals::Int64, order::Int64, norm::Function, segbuf::Nothing)
    @ QuadGK ~/.julia/packages/QuadGK/OtnWt/src/adapt.jl:252
 [16] _adjoint_sensitivities(sol::SciMLBase.ODESolution{…}, sensealg::SciMLSensitivity.QuadratureAdjoint{…}, alg::OrdinaryDiffEq.RDPK3Sp35{…}; t::Vector{…}, dgdu_discrete::Function, dgdp_discrete::Nothing, dgdu_continuous::Nothing, dgdp_continuous::Nothing, g::Nothing, abstol::Float64, reltol::Float64, callback::SciMLBase.CallbackSet{…}, kwargs::@Kwargs{…})
    @ SciMLSensitivity ~/.julia/dev/SciMLSensitivity/src/quadrature_adjoint.jl:382
 [17] _adjoint_sensitivities
    @ ~/.julia/dev/SciMLSensitivity/src/quadrature_adjoint.jl:324 [inlined]
 [18] #adjoint_sensitivities#63
    @ ~/.julia/dev/SciMLSensitivity/src/sensitivity_interface.jl:383 [inlined]
 [19] (::SciMLSensitivity.var"#adjoint_sensitivity_backpass#314"{…})(Δ::SciMLBase.ODESolution{…})
    @ SciMLSensitivity ~/.julia/dev/SciMLSensitivity/src/concrete_solve.jl:556
 [20] ZBack
    @ ~/.julia/packages/Zygote/nsBv0/src/compiler/chainrules.jl:211 [inlined]
 [21] (::Zygote.var"#kw_zpullback#53"{…})(dy::SciMLBase.ODESolution{…})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/chainrules.jl:237
 [22] #291
    @ ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:206 [inlined]
 [23] (::Zygote.var"#2169#back#293"{…})(Δ::SciMLBase.ODESolution{…})
    @ Zygote ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:72
 [24] #solve#51
    @ ~/.julia/packages/DiffEqBase/DS1sd/src/solve.jl:1003 [inlined]
 [25] (::Zygote.Pullback{…})(Δ::SciMLBase.ODESolution{…})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [26] #291
    @ ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:206 [inlined]
 [27] (::Zygote.var"#2169#back#293"{…})(Δ::SciMLBase.ODESolution{…})
    @ Zygote ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:72
 [28] solve
    @ ~/.julia/packages/DiffEqBase/DS1sd/src/solve.jl:993 [inlined]
 [29] (::Zygote.Pullback{…})(Δ::SciMLBase.ODESolution{…})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [30] #simulate_iceflow_UDE!#32
    @ ~/.julia/dev/ODINN/src/simulations/functional_inversions/functional_inversion_utils.jl:187 [inlined]
 [31] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [32] simulate_iceflow_UDE!
    @ ~/.julia/dev/ODINN/src/simulations/functional_inversions/functional_inversion_utils.jl:171 [inlined]
 [33] (::Zygote.Pullback{Tuple{…}, Any})(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [34] batch_iceflow_UDE
    @ ~/.julia/dev/ODINN/src/simulations/functional_inversions/functional_inversion_utils.jl:146 [inlined]
 [35] (::Zygote.Pullback{…})(Δ::@NamedTuple{…})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [36] #28
    @ ~/.julia/dev/ODINN/src/simulations/functional_inversions/functional_inversion_utils.jl:114 [inlined]
 [37] (::Zygote.Pullback{…})(Δ::@NamedTuple{…})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [38] #680
    @ ~/.julia/packages/Zygote/nsBv0/src/lib/array.jl:201 [inlined]
 [39] #235
    @ ~/.julia/juliaup/julia-1.10.3+0.aarch64.apple.darwin14/share/julia/stdlib/v1.10/Distributed/src/pmap.jl:157 [inlined]
 [40] (::Base.var"#1023#1028"{Distributed.var"#235#236"{…}})(r::Base.RefValue{Any}, args::Tuple{Tuple{…}})
    @ Base ./asyncmap.jl:94
 [41] (::Base.var"#1039#1040"{Base.var"#1023#1028"{Distributed.var"#235#236"{Zygote.var"#680#685"}}, Channel{Any}, Nothing})()
    @ Base ./asyncmap.jl:228
Stacktrace:
  [1] (::Base.var"#1033#1035")(x::Task)
    @ Base ./asyncmap.jl:171
  [2] foreach(f::Base.var"#1033#1035", itr::Vector{Any})
    @ Base ./abstractarray.jl:3097
  [3] maptwice(wrapped_f::Function, chnl::Channel{Any}, worker_tasks::Vector{Any}, c::Base.Iterators.Zip{Tuple{…}})
    @ Base ./asyncmap.jl:171
  [4] wrap_n_exec_twice
    @ ./asyncmap.jl:147 [inlined]
  [5] #async_usemap#1018
    @ ./asyncmap.jl:97 [inlined]
  [6] async_usemap
    @ ./asyncmap.jl:78 [inlined]
  [7] #asyncmap#1017
    @ ./asyncmap.jl:75 [inlined]
  [8] asyncmap
    @ ./asyncmap.jl:74 [inlined]
  [9] pmap(f::Function, p::Distributed.WorkerPool, c::Base.Iterators.Zip{…}; distributed::Bool, batch_size::Int64, on_error::Nothing, retry_delays::Vector{…}, retry_check::Nothing)
    @ Distributed ~/.julia/juliaup/julia-1.10.3+0.aarch64.apple.darwin14/share/julia/stdlib/v1.10/Distributed/src/pmap.jl:126
 [10] pmap(f::Function, p::Distributed.WorkerPool, c::Base.Iterators.Zip{Tuple{Vector{Tuple{…}}, Vector{@NamedTuple{…}}}})
    @ Distributed ~/.julia/juliaup/julia-1.10.3+0.aarch64.apple.darwin14/share/julia/stdlib/v1.10/Distributed/src/pmap.jl:99
 [11] pmap(f::Function, c::Base.Iterators.Zip{Tuple{Vector{Tuple{…}}, Vector{@NamedTuple{…}}}}; kwargs::@Kwargs{})
    @ Distributed ~/.julia/juliaup/julia-1.10.3+0.aarch64.apple.darwin14/share/julia/stdlib/v1.10/Distributed/src/pmap.jl:156
 [12] pmap
    @ ~/.julia/juliaup/julia-1.10.3+0.aarch64.apple.darwin14/share/julia/stdlib/v1.10/Distributed/src/pmap.jl:156 [inlined]
 [13] pmap(f::Function, c1::Vector{Tuple{…}}, c::Vector{@NamedTuple{…}})
    @ Distributed ~/.julia/juliaup/julia-1.10.3+0.aarch64.apple.darwin14/share/julia/stdlib/v1.10/Distributed/src/pmap.jl:157
 [14] (::Zygote.var"#map_back#682"{ODINN.var"#28#29"{…}, 1, Tuple{…}, Tuple{…}, Vector{…}})(Δ::Vector{@NamedTuple{…}})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/lib/array.jl:201
 [15] (::Zygote.var"#2861#back#688"{Zygote.var"#map_back#682"{…}})(Δ::Vector{@NamedTuple{…}})
    @ Zygote ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:72
 [16] predict_iceflow!
    @ ~/.julia/dev/ODINN/src/simulations/functional_inversions/functional_inversion_utils.jl:114 [inlined]
 [17] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [18] loss_iceflow
    @ ~/.julia/dev/ODINN/src/simulations/functional_inversions/functional_inversion_utils.jl:58 [inlined]
 [19] (::Zygote.Pullback{Tuple{…}, Any})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [20] #25
    @ ~/.julia/dev/ODINN/src/simulations/functional_inversions/functional_inversion_utils.jl:31 [inlined]
 [21] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [22] #291
    @ ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:206 [inlined]
 [23] #2169#back
    @ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:72 [inlined]
 [24] OptimizationFunction
    @ ~/.julia/packages/SciMLBase/sakPO/src/scimlfunctions.jl:3762 [inlined]
 [25] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [26] #291
    @ ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:206 [inlined]
 [27] (::Zygote.var"#2169#back#293"{Zygote.var"#291#292"{Tuple{…}, Zygote.Pullback{…}}})(Δ::Float64)
    @ Zygote ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:72
 [28] #37
    @ ~/.julia/dev/OptimizationBase/ext/OptimizationZygoteExt.jl:90 [inlined]
 [29] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [30] #291
    @ ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:206 [inlined]
 [31] #2169#back
    @ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:72 [inlined]
 [32] #39
    @ ~/.julia/dev/OptimizationBase/ext/OptimizationZygoteExt.jl:93 [inlined]
 [33] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [34] (::Zygote.var"#75#76"{Zygote.Pullback{Tuple{…}, Tuple{…}}})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface.jl:91
 [35] gradient(f::Function, args::ComponentArrays.ComponentVector{Float64, Vector{Float64}, Tuple{ComponentArrays.Axis{…}}})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface.jl:148
 [36] (::OptimizationZygoteExt.var"#38#56"{…})(::ComponentArrays.ComponentVector{…}, ::ComponentArrays.ComponentVector{…}, ::Vector{…}, ::Vararg{…})
    @ OptimizationZygoteExt ~/.julia/dev/OptimizationBase/ext/OptimizationZygoteExt.jl:93
 [37] macro expansion
    @ ~/.julia/packages/OptimizationOptimisers/AOkbT/src/OptimizationOptimisers.jl:68 [inlined]
 [38] macro expansion
    @ ~/.julia/packages/ProgressLogging/6KXlp/src/ProgressLogging.jl:328 [inlined]
 [39] (::OptimizationOptimisers.var"#12#13"{OptimizationBase.OptimizationCache{…}, ComponentArrays.ComponentVector{…}})()
    @ OptimizationOptimisers ~/.julia/packages/Optimization/jWtfU/src/utils.jl:29
 [40] maybe_with_logger(f::OptimizationOptimisers.var"#12#13"{…}, logger::Nothing)
    @ Optimization ~/.julia/packages/Optimization/jWtfU/src/utils.jl:7
 [41] macro expansion
    @ ~/.julia/packages/Optimization/jWtfU/src/utils.jl:28 [inlined]
 [42] __solve(cache::OptimizationBase.OptimizationCache{…})
    @ OptimizationOptimisers ~/.julia/packages/OptimizationOptimisers/AOkbT/src/OptimizationOptimisers.jl:66
 [43] solve!(cache::OptimizationBase.OptimizationCache{…})
    @ SciMLBase ~/.julia/packages/SciMLBase/sakPO/src/solve.jl:188
 [44] solve(prob::SciMLBase.OptimizationProblem{…}, alg::Optimisers.Adam, args::IterTools.NCycle{…}; kwargs::@Kwargs{…})
    @ SciMLBase ~/.julia/packages/SciMLBase/sakPO/src/solve.jl:96
 [45] train_UDE!(simulation::FunctionalInversion)
    @ ODINN ~/.julia/dev/ODINN/src/simulations/functional_inversions/functional_inversion_utils.jl:43
 [46] run!(simulation::FunctionalInversion)
    @ ODINN ~/.julia/dev/ODINN/src/simulations/functional_inversions/functional_inversion_utils.jl:11
 [47] macro expansion
    @ ./timing.jl:279 [inlined]
 [48] top-level scope
    @ ~/Desktop/test.jl:76
Some type information was truncated. Use `show(err)` to see complete types.

Hokay, the autodiff is now working using SIA2D!, but in the adjoint the dt goes to NaN, which crashes it. I'll need to investigate that.

wsmoses commented 1 month ago

Oh, if the issue is a NaN: Enzyme (on CPU only, presently) has a NaN checker which will throw with a backtrace the first time a NaN is generated for a derivative.

Set Enzyme.Compiler.CheckNan[] = true after importing Enzyme, and the instrumentation will be added to code compiled after that point.

Clearly this needs more docs (PRs welcome, of course!).