EnzymeAD / Enzyme.jl

Julia bindings for the Enzyme automatic differentiator
https://enzyme.mit.edu
MIT License

`sensealg=InterpolatingAdjoint(autojacvec=EnzymeVJP(), checkpointing=true)` failed #1749

Closed · yunan-l closed this 1 month ago

yunan-l commented 1 month ago

Hi, when running the following:

```julia
struct NeuralODE{M <: Lux.AbstractExplicitLayer, So, T, D, K} <:
       Lux.AbstractExplicitContainerLayer{(:model,)}
    model::M
    solver::So
    tspan::T
    device::D
    kwargs::K
end

function NeuralODE(model::Lux.AbstractExplicitLayer; solver=Tsit5(), tspan=(0.0f0, 1.0f0), gpu=nothing, kwargs...)
    device = DetermineDevice(gpu=gpu)
    NeuralODE{typeof(model), typeof(solver), typeof(tspan), typeof(device), typeof(kwargs)}(model, solver, tspan, device, kwargs)
end

function (n::NeuralODE)(u0, ps, st, cb, pp)

    function dudt(u, p, t; st=st)
        u_, st = Lux.apply(n.model, u, p, st)
        return u_
    end

    prob = ODEProblem{false}(ODEFunction{false}(dudt), u0, n.tspan, ps)

    sensealg = get(n.kwargs, :sensealg, InterpolatingAdjoint(autojacvec=ZygoteVJP(), checkpointing=true))

    tsteps = n.tspan[1]:n.tspan[2]

    sol = solve(prob, n.solver, saveat=tsteps, callback = cb, sensealg = sensealg)

    return DeviceArray(n.device, Array(sol)), st
end

function train_neuralode!(model, u0, p, st, cb, pp, loss_func, opt_state, η_schedule; N_epochs=1, verbose=true, compute_initial_error::Bool=true, scheduler_offset::Int=0)

    best_p = copy(p)
    results = (i_epoch = Int[], train_loss=Float32[], learning_rate=Float32[], duration=Float32[], valid_loss=Float32[], test_loss=Float32[], loss_min=[Inf32], i_epoch_min=[1])

    progress = Progress(N_epochs, 1)

    # initial error 
    lowest_train_err = compute_initial_error ? loss_func(model, u0, p, st, cb, pp) : Inf

    for i_epoch in 1:N_epochs

        Optimisers.adjust!(opt_state, η_schedule(i_epoch + scheduler_offset)) 

        epoch_start_time = time()

        loss_p(p) = loss_func(model, u0, p, st, cb, pp)

        l, gs = Zygote.withgradient(loss_p, p)
        opt_state, p = Optimisers.update(opt_state, p, gs[1])

        train_err = l
        epoch_time = time() - epoch_start_time

        push!(results[:i_epoch], i_epoch)
        push!(results[:train_loss], train_err)
        push!(results[:learning_rate], η_schedule(i_epoch))
        push!(results[:duration], epoch_time)

        if i_epoch % N_epochs == 0
            monitor(model, u0, p, st, cb, pp)
        end

    end
    return model, best_p, st, results

end

using ComponentArrays, Lux, DiffEqFlux, OrdinaryDiffEq, Optimization, Optimisers, Random, Plots, XLSX, DataFrames, SciMLSensitivity, DiffEqCallbacks, Enzyme, CUDA, LuxCUDA, LuxDeviceUtils
Enzyme.API.runtimeActivity!(true)
CUDA.allowscalar(false)
gdev = gpu_device()   # assumed definition: `gdev` is used below but was missing from the snippet

nn = Chain(
    Dense(4, 16, tanh),
    Dense(16, 16, tanh),
    Dense(16, 4),
)
rng = Xoshiro(0)
p, st = Lux.setup(rng, nn)
p = ComponentArray(p) |> gdev
st = st |> gdev

u0 = Float32[0.0, 8.0, 0.0, 12.0] |> gdev
tspan = (0.0f0, 365.0f0) 

neural_ode = NeuralODE(nn; solver=Tsit5(), tspan = tspan, sensealg=InterpolatingAdjoint(autojacvec=EnzymeVJP(), checkpointing=true))

loss = loss_neuralode   # loss_neuralode, cb, pp, and monitor are defined elsewhere (not shown)

opt = Optimisers.AdamW(1f-3, (9f-1, 9.99f-1), 1f-6)
opt_state = Optimisers.setup(opt, p)
η_schedule = SinExp(λ0=1f-3, λ1=1f-5, period=20, decay=0.975f0)   # SinExp schedule defined elsewhere (not shown)

println("starting training...")
neural_de, ps, st, results_ad = train_neuralode!(neural_ode, u0, p, st, cb, pp, loss, opt_state, η_schedule; N_epochs=5, verbose=true)
```

I get:

```
Enzyme execution failed.
Enzyme: unhandled augmented forward for jl_f_finalizer
Stacktrace:
  [1] finalizer
    @ ./gcutils.jl:87 [inlined]
  [2] _
    @ ~/.julia/packages/CUDA/Tl08O/src/array.jl:83 [inlined]
  [3] CuArray
    @ ~/.julia/packages/CUDA/Tl08O/src/array.jl:79 [inlined]
  [4] derive
    @ ~/.julia/packages/CUDA/Tl08O/src/array.jl:799 [inlined]
  [5] unsafe_contiguous_view
    @ ~/.julia/packages/GPUArrays/qt4ax/src/host/base.jl:319 [inlined]
  [6] unsafe_view
    @ ~/.julia/packages/GPUArrays/qt4ax/src/host/base.jl:314 [inlined]
  [7] view
    @ ~/.julia/packages/GPUArrays/qt4ax/src/host/base.jl:310 [inlined]
  [8] maybeview
    @ ./views.jl:148 [inlined]
  [9] macro expansion
    @ ~/.julia/packages/ComponentArrays/xO4hy/src/array_interface.jl:0 [inlined]
 [10] _getindex
    @ ~/.julia/packages/ComponentArrays/xO4hy/src/array_interface.jl:119 [inlined]
 [11] getproperty
    @ ~/.julia/packages/ComponentArrays/xO4hy/src/namedtuple_interface.jl:14 [inlined]
 [12] macro expansion
    @ ~/.julia/packages/Lux/PsW4M/src/layers/containers.jl:0 [inlined]
 [13] applychain
    @ ~/.julia/packages/Lux/PsW4M/src/layers/containers.jl:520
 [14] Chain
    @ ~/.julia/packages/Lux/PsW4M/src/layers/containers.jl:518 [inlined]
 [15] apply
    @ ~/.julia/packages/LuxCore/yzx6E/src/LuxCore.jl:171 [inlined]
 [16] dudt
    @ ./In[92]:24 [inlined]
 [17] dudt
    @ ./In[92]:20 [inlined]
 [18] ODEFunction
    @ ~/.julia/packages/SciMLBase/Q1klk/src/scimlfunctions.jl:2335 [inlined]
 [19] #138
    @ ~/.julia/packages/SciMLSensitivity/se3y4/src/adjoint_common.jl:490 [inlined]
 [20] diffejulia__138_128700_inner_1wrap
    @ ~/.julia/packages/SciMLSensitivity/se3y4/src/adjoint_common.jl:0
 [21] macro expansion
    @ ~/.julia/packages/Enzyme/XGb4o/src/compiler.jl:7049 [inlined]
 [22] enzyme_call
    @ ~/.julia/packages/Enzyme/XGb4o/src/compiler.jl:6658 [inlined]
 [23] CombinedAdjointThunk
    @ ~/.julia/packages/Enzyme/XGb4o/src/compiler.jl:6535 [inlined]
 [24] autodiff
    @ ~/.julia/packages/Enzyme/XGb4o/src/Enzyme.jl:320 [inlined]
    ...
```

For the whole log, please see the attached EnzymeVJP.failed.txt.
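If I'm reading the trace right, the ODE solve isn't essential to the failure: the frames go from ComponentArrays' `getproperty` through `view` into the `CuArray` constructor, which registers a GC finalizer, and Enzyme has no augmented forward rule for `jl_f_finalizer`. A minimal sketch of that path, distilled from the trace rather than taken from my actual code:

```julia
using ComponentArrays, CUDA, Enzyme

# GPU-backed ComponentArray; `p.w` below goes getproperty -> view ->
# derive -> CuArray constructor, which registers the GC finalizer that
# Enzyme's augmented forward pass rejects.
p  = ComponentArray(w = CUDA.rand(Float32, 4))
dp = Enzyme.make_zero(p)

f(p) = sum(p.w)

Enzyme.autodiff(Reverse, f, Active, Duplicated(p, dp))
```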

wsmoses commented 1 month ago

Closed in favor of https://github.com/JuliaGPU/CUDA.jl/issues/2478
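For anyone hitting the same error before the upstream CUDA.jl issue is resolved: one possible workaround, not confirmed in this thread, is to fall back to the Zygote-based VJP that the `NeuralODE` functor above already uses as its default sensealg:

```julia
# Possible workaround (untested here): keep the interpolating adjoint but
# let Zygote, which handles CuArray views, compute the vector-Jacobian
# products instead of Enzyme.
neural_ode = NeuralODE(nn; solver=Tsit5(), tspan=tspan,
    sensealg=InterpolatingAdjoint(autojacvec=ZygoteVJP(), checkpointing=true))
```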