yunan-l closed this issue 1 month ago.
Hi, when running the following code:
```julia
struct NeuralODE{M <: Lux.AbstractExplicitLayer, So, T, D, K} <: Lux.AbstractExplicitContainerLayer{(:model,)}
    model::M
    solver::So
    tspan::T
    device::D
    kwargs::K
end

function NeuralODE(model::Lux.AbstractExplicitLayer; solver=Tsit5(), tspan=(0.0f0, 1.0f0), gpu=nothing, kwargs...)
    device = DetermineDevice(gpu=gpu)
    NeuralODE{typeof(model), typeof(solver), typeof(tspan), typeof(device), typeof(kwargs)}(model, solver, tspan, device, kwargs)
end

function (n::NeuralODE)(u0, ps, st, cb, pp)
    # Right-hand side: one forward pass through the Lux model.
    function dudt(u, p, t; st=st)
        u_, st = Lux.apply(n.model, u, p, st)
        return u_
    end
    prob = ODEProblem{false}(ODEFunction{false}(dudt), u0, n.tspan, ps)
    # Default sensealg uses a Zygote VJP; it is overridden to EnzymeVJP below.
    sensealg = get(n.kwargs, :sensealg, InterpolatingAdjoint(autojacvec=ZygoteVJP(), checkpointing=true))
    tsteps = n.tspan[1]:n.tspan[2]
    sol = solve(prob, n.solver, saveat=tsteps, callback=cb, sensealg=sensealg)
    return DeviceArray(n.device, Array(sol)), st
end

function train_neuralode!(model, u0, p, st, cb, pp, loss_func, opt_state, η_schedule;
                          N_epochs=1, verbose=true, compute_initial_error::Bool=true, scheduler_offset::Int=0)
    best_p = copy(p)
    results = (i_epoch=Int[], train_loss=Float32[], learning_rate=Float32[], duration=Float32[],
               valid_loss=Float32[], test_loss=Float32[], loss_min=[Inf32], i_epoch_min=[1])
    progress = Progress(N_epochs, 1)

    # initial error
    lowest_train_err = compute_initial_error ? loss_func(model, u0, p, st, cb, pp) : Inf

    for i_epoch in 1:N_epochs
        Optimisers.adjust!(opt_state, η_schedule(i_epoch + scheduler_offset))
        epoch_start_time = time()

        loss_p(p) = loss_func(model, u0, p, st, cb, pp)
        l, gs = Zygote.withgradient(loss_p, p)
        opt_state, p = Optimisers.update(opt_state, p, gs[1])
        train_err = l
        epoch_time = time() - epoch_start_time

        push!(results[:i_epoch], i_epoch)
        push!(results[:train_loss], train_err)
        push!(results[:learning_rate], η_schedule(i_epoch))
        push!(results[:duration], epoch_time)

        if i_epoch % N_epochs == 0
            monitor(model, u0, p, st, cb, pp)
        end
    end
    return model, best_p, st, results
end

using ComponentArrays, Lux, DiffEqFlux, OrdinaryDiffEq, Optimization, Optimisers, Random, Plots,
      XLSX, DataFrames, SciMLSensitivity, DiffEqCallbacks, Enzyme, CUDA, LuxCUDA, LuxDeviceUtils

# DetermineDevice, DeviceArray, cb, pp, loss_neuralode, monitor, Progress,
# and SinExp are defined elsewhere in my project.
Enzyme.API.runtimeActivity!(true)
CUDA.allowscalar(false)

gdev = gpu_device()  # LuxDeviceUtils; was implicit in the original snippet

nn = Chain(
    Dense(4, 16, tanh),
    Dense(16, 16, tanh),
    Dense(16, 4))

rng = Xoshiro(0)
p, st = Lux.setup(rng, nn)
p = ComponentArray(p) |> gdev
st = st |> gdev
u0 = Float32[0.0, 8.0, 0.0, 12.0] |> gdev
tspan = (0.0f0, 365.0f0)

neural_ode = NeuralODE(nn; solver=Tsit5(), tspan=tspan,
                       sensealg=InterpolatingAdjoint(autojacvec=EnzymeVJP(), checkpointing=true))

loss = loss_neuralode
opt = Optimisers.AdamW(1f-3, (9f-1, 9.99f-1), 1f-6)
opt_state = Optimisers.setup(opt, p)
η_schedule = SinExp(λ0=1f-3, λ1=1f-5, period=20, decay=0.975f0)

println("starting training...")
neural_de, ps, st, results_ad = train_neuralode!(neural_ode, u0, p, st, cb, pp, loss,
                                                 opt_state, η_schedule; N_epochs=5, verbose=true)
```
I get:
```
Enzyme execution failed.
Enzyme: unhandled augmented forward for jl_f_finalizer
Stacktrace:
  [1] finalizer @ ./gcutils.jl:87
  [2] _ @ ~/.julia/packages/CUDA/Tl08O/src/array.jl:83
  [3] CuArray @ ~/.julia/packages/CUDA/Tl08O/src/array.jl:79
  [4] derive @ ~/.julia/packages/CUDA/Tl08O/src/array.jl:799
  [5] unsafe_contiguous_view @ ~/.julia/packages/GPUArrays/qt4ax/src/host/base.jl:319
  [6] unsafe_view @ ~/.julia/packages/GPUArrays/qt4ax/src/host/base.jl:314
  [7] view @ ~/.julia/packages/GPUArrays/qt4ax/src/host/base.jl:310
  [8] maybeview @ ./views.jl:148
  [9] macro expansion @ ~/.julia/packages/ComponentArrays/xO4hy/src/array_interface.jl:0
 [10] _getindex @ ~/.julia/packages/ComponentArrays/xO4hy/src/array_interface.jl:119
 [11] getproperty @ ~/.julia/packages/ComponentArrays/xO4hy/src/namedtuple_interface.jl:14
 [12] macro expansion @ ~/.julia/packages/Lux/PsW4M/src/layers/containers.jl:0
 [13] applychain @ ~/.julia/packages/Lux/PsW4M/src/layers/containers.jl:520

Stacktrace:
  [1] finalizer @ ./gcutils.jl:87 [inlined]
  [2] _ @ ~/.julia/packages/CUDA/Tl08O/src/array.jl:83 [inlined]
  [3] CuArray @ ~/.julia/packages/CUDA/Tl08O/src/array.jl:79 [inlined]
  [4] derive @ ~/.julia/packages/CUDA/Tl08O/src/array.jl:799 [inlined]
  [5] unsafe_contiguous_view @ ~/.julia/packages/GPUArrays/qt4ax/src/host/base.jl:319 [inlined]
  [6] unsafe_view @ ~/.julia/packages/GPUArrays/qt4ax/src/host/base.jl:314 [inlined]
  [7] view @ ~/.julia/packages/GPUArrays/qt4ax/src/host/base.jl:310 [inlined]
  [8] maybeview @ ./views.jl:148 [inlined]
  [9] macro expansion @ ~/.julia/packages/ComponentArrays/xO4hy/src/array_interface.jl:0 [inlined]
 [10] _getindex @ ~/.julia/packages/ComponentArrays/xO4hy/src/array_interface.jl:119 [inlined]
 [11] getproperty @ ~/.julia/packages/ComponentArrays/xO4hy/src/namedtuple_interface.jl:14 [inlined]
 [12] macro expansion @ ~/.julia/packages/Lux/PsW4M/src/layers/containers.jl:0 [inlined]
 [13] applychain @ ~/.julia/packages/Lux/PsW4M/src/layers/containers.jl:520
 [14] Chain @ ~/.julia/packages/Lux/PsW4M/src/layers/containers.jl:518 [inlined]
 [15] apply @ ~/.julia/packages/LuxCore/yzx6E/src/LuxCore.jl:171 [inlined]
 [16] dudt @ ./In[92]:24 [inlined]
 [17] dudt @ ./In[92]:20 [inlined]
 [18] ODEFunction @ ~/.julia/packages/SciMLBase/Q1klk/src/scimlfunctions.jl:2335 [inlined]
 [19] #138 @ ~/.julia/packages/SciMLSensitivity/se3y4/src/adjoint_common.jl:490 [inlined]
 [20] diffejulia__138_128700_inner_1wrap @ ~/.julia/packages/SciMLSensitivity/se3y4/src/adjoint_common.jl:0
 [21] macro expansion @ ~/.julia/packages/Enzyme/XGb4o/src/compiler.jl:7049 [inlined]
 [22] enzyme_call @ ~/.julia/packages/Enzyme/XGb4o/src/compiler.jl:6658 [inlined]
 [23] CombinedAdjointThunk @ ~/.julia/packages/Enzyme/XGb4o/src/compiler.jl:6535 [inlined]
 [24] autodiff @ ~/.julia/packages/Enzyme/XGb4o/src/Enzyme.jl:320 [inlined]
...
```
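Frames [8]–[13] show where this comes from: pulling a sub-layer's parameters out of the GPU-resident `ComponentArray` takes a contiguous view of the underlying `CuArray`, and constructing that derived `CuArray` registers a GC finalizer, which is the `jl_f_finalizer` call Enzyme's augmented forward pass rejects. Roughly (an illustration only, with `layer_1` as Lux's default name for the first `Dense` layer):

```julia
# What happens inside Lux.applychain for each sub-layer, per frames [8]-[13]:
layer_ps = p.layer_1   # ComponentArrays getproperty -> Base.maybeview -> view
# view(::CuArray, ...) -> GPUArrays.unsafe_contiguous_view -> CUDA.derive,
# whose CuArray constructor calls `finalizer` (CUDA/src/array.jl:83).
```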
For the whole log, please see the attached EnzymeVJP.failed.txt.
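In case it helps triage, the failure should be isolatable without the ODE machinery at all. A minimal sketch, untested, assuming the same Enzyme/CUDA/ComponentArrays versions as above:

```julia
using Enzyme, CUDA, ComponentArrays

# A GPU-resident ComponentArray, as in the report above.
p  = ComponentArray(w = rand(Float32, 4)) |> cu
dp = zero(p)

# getproperty takes a CuArray view, registering a finalizer along the way.
f(p) = sum(abs2, p.w)

# Reverse-mode Enzyme over that view construction should hit the same
# "unhandled augmented forward for jl_f_finalizer" error.
Enzyme.autodiff(Reverse, f, Active, Duplicated(p, dp))
```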
Closed in favor of https://github.com/JuliaGPU/CUDA.jl/issues/2478
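For anyone landing here before the upstream issue is fixed: one workaround is to drop the explicit `EnzymeVJP` override, so the `NeuralODE` constructor falls back to its Zygote-based default (a sketch based on the code above):

```julia
# Uses the constructor's default sensealg:
# InterpolatingAdjoint(autojacvec=ZygoteVJP(), checkpointing=true)
neural_ode = NeuralODE(nn; solver=Tsit5(), tspan=tspan)
```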