EnzymeAD / Enzyme.jl

Julia bindings for the Enzyme automatic differentiator
https://enzyme.mit.edu
MIT License
438 stars 62 forks source link

make_zero! can fail on some immutable functions #1661

Open ChrisRackauckas opened 1 month ago

ChrisRackauckas commented 1 month ago

Example from https://github.com/SciML/SciMLSensitivity.jl/pull/1067:

using OrdinaryDiffEq, Zygote, SciMLSensitivity

# Problem definition: one state N relaxing toward p[1], i.e. dN/dt = p[1] - N.
N0 = [0.0]              # initial population
p = [100.0, 50.0]       # p[1]: steady-state population; p[2]: injected amount M
tspan = (0.0, 10.0)     # integration time span

# In-place right-hand side: writes the derivative into D[1].
# Like the original one-liner `(D[1] = ...)`, it also returns the written value.
function f(D, u, p, t)
    dN = p[1] - u[1]
    D[1] = dN
    return dN
end

prob = ODEProblem(f, N0, tspan, p)

# At time `tinject`, inject p[2] (the "M" parameter) cells into the population.
tinject = 8.0
# Fires only when t equals tinject exactly. `loss` passes tstops = [tinject],
# presumably so the integrator lands on that time and this check can trigger.
condition(u, t, integrator) = t == tinject
# Add the injected amount p[2] to the single state component.
affect(integrator) = integrator.u[1] += integrator.p[2]
cb = DiscreteCallback(condition, affect)

"""
    loss(p)

Re-solve the global `prob` with parameters `p` and the injection callback `cb`.
Return the first state component at the last saved time point.

The reverse pass of `Zygote.gradient` through this solve (BacksolveAdjoint with
an EnzymeVJP vector-Jacobian product) is where the reported `make_zero!` error
is raised.
"""
function loss(p)
    newprob = remake(prob; p = p)
    adjoint_alg = BacksolveAdjoint(autojacvec = EnzymeVJP())
    sol = solve(newprob, Tsit5();
        callback = cb,
        abstol = 1e-14,
        reltol = 1e-14,
        tstops = [tinject],  # make the exact-time callback condition reachable
        sensealg = adjoint_alg)
    final_state = sol.u[end]
    return final_state[1]
end

# Zygote.gradient returns one gradient per argument; keep the one w.r.t. p.
gZy = first(Zygote.gradient(loss, p))

Throws:

ERROR: setfield!: immutable struct of type #136#140 cannot be changed
Stacktrace:
  [1] make_zero!
    @ ~/.julia/packages/Enzyme/SiyIj/src/compiler.jl:1601 [inlined]
  [2] make_zero!
    @ ~/.julia/packages/Enzyme/SiyIj/src/compiler.jl:1576 [inlined]
  [3] _vecjacobian!(dλ::SubArray{…}, y::Vector{…}, λ::SubArray{…}, p::Vector{…}, t::Float64, S::SciMLSensitivity.CallbackSensitivityFunction{…}, isautojacvec::EnzymeVJP, dgrad::SubArray{…}, dy::SubArray{…}, W::Nothing)
    @ SciMLSensitivity ~/.julia/dev/SciMLSensitivity/src/derivative_wrappers.jl:710
  [4] #vecjacobian!#18
    @ ~/.julia/dev/SciMLSensitivity/src/derivative_wrappers.jl:232 [inlined]
  [5] vecjacobian!
    @ ~/.julia/dev/SciMLSensitivity/src/derivative_wrappers.jl:229 [inlined]
  [6] (::SciMLSensitivity.var"#affect!#272"{…})(integrator::OrdinaryDiffEq.ODEIntegrator{…})
    @ SciMLSensitivity ~/.julia/dev/SciMLSensitivity/src/callback_tracking.jl:339
  [7] #111
    @ ~/.julia/packages/DiffEqCallbacks/9fKPq/src/preset_time.jl:58 [inlined]
  [8] apply_discrete_callback!
    @ ~/.julia/packages/DiffEqBase/c8MAQ/src/callbacks.jl:613 [inlined]
  [9] apply_discrete_callback! (repeats 2 times)
    @ ~/.julia/packages/DiffEqBase/c8MAQ/src/callbacks.jl:635 [inlined]
 [10] apply_discrete_callback!
    @ ~/.julia/packages/DiffEqBase/c8MAQ/src/callbacks.jl:628 [inlined]
 [11] handle_callbacks!(integrator::OrdinaryDiffEq.ODEIntegrator{…})
    @ OrdinaryDiffEq ~/.julia/packages/OrdinaryDiffEq/HQ92J/src/integrators/integrator_utils.jl:349
 [12] _loopfooter!(integrator::OrdinaryDiffEq.ODEIntegrator{…})
    @ OrdinaryDiffEq ~/.julia/packages/OrdinaryDiffEq/HQ92J/src/integrators/integrator_utils.jl:254
 [13] loopfooter!
    @ ~/.julia/packages/OrdinaryDiffEq/HQ92J/src/integrators/integrator_utils.jl:207 [inlined]
 [14] solve!(integrator::OrdinaryDiffEq.ODEIntegrator{…})
    @ OrdinaryDiffEq ~/.julia/packages/OrdinaryDiffEq/HQ92J/src/solve.jl:558
 [15] #__solve#560
    @ ~/.julia/packages/OrdinaryDiffEq/HQ92J/src/solve.jl:7 [inlined]
 [16] __solve
    @ ~/.julia/packages/OrdinaryDiffEq/HQ92J/src/solve.jl:1 [inlined]
 [17] solve_call(_prob::ODEProblem{…}, args::Tsit5{…}; merge_callbacks::Bool, kwargshandle::Nothing, kwargs::@Kwargs{…})
    @ DiffEqBase ~/.julia/packages/DiffEqBase/c8MAQ/src/solve.jl:612
 [18] solve_call
    @ ~/.julia/packages/DiffEqBase/c8MAQ/src/solve.jl:569 [inlined]
 [19] #solve_up#53
    @ ~/.julia/packages/DiffEqBase/c8MAQ/src/solve.jl:1080 [inlined]
 [20] solve_up
    @ ~/.julia/packages/DiffEqBase/c8MAQ/src/solve.jl:1066 [inlined]
 [21] #solve#51
    @ ~/.julia/packages/DiffEqBase/c8MAQ/src/solve.jl:1003 [inlined]
 [22] _adjoint_sensitivities(sol::ODESolution{…}, sensealg::BacksolveAdjoint{…}, alg::Tsit5{…}; t::Vector{…}, dgdu_discrete::Function, dgdp_discrete::Nothing, dgdu_continuous::Nothing, dgdp_continuous::Nothing, g::Nothing, abstol::Float64, reltol::Float64, checkpoints::Vector{…}, corfunc_analytical::Nothing, callback::CallbackSet{…}, kwargs::@Kwargs{…})
    @ SciMLSensitivity ~/.julia/dev/SciMLSensitivity/src/sensitivity_interface.jl:448
 [23] _adjoint_sensitivities
    @ ~/.julia/dev/SciMLSensitivity/src/sensitivity_interface.jl:405 [inlined]
 [24] #adjoint_sensitivities#63
    @ ~/.julia/dev/SciMLSensitivity/src/sensitivity_interface.jl:401 [inlined]
 [25] (::SciMLSensitivity.var"#adjoint_sensitivity_backpass#310"{…})(Δ::ODESolution{…})
    @ SciMLSensitivity ~/.julia/dev/SciMLSensitivity/src/concrete_solve.jl:619
 [26] ZBack
    @ ~/.julia/packages/Zygote/nsBv0/src/compiler/chainrules.jl:211 [inlined]
 [27] (::Zygote.var"#kw_zpullback#53"{…})(dy::ODESolution{…})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/chainrules.jl:237
 [28] #291
    @ ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:206 [inlined]
 [29] (::Zygote.var"#2169#back#293"{…})(Δ::ODESolution{…})
    @ Zygote ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:72
 [30] #solve#51
    @ ~/.julia/packages/DiffEqBase/c8MAQ/src/solve.jl:1003 [inlined]
 [31] (::Zygote.Pullback{…})(Δ::ODESolution{…})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [32] #291
    @ ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:206 [inlined]
 [33] (::Zygote.var"#2169#back#293"{…})(Δ::ODESolution{…})
    @ Zygote ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:72
 [34] solve
    @ ~/.julia/packages/DiffEqBase/c8MAQ/src/solve.jl:993 [inlined]
 [35] (::Zygote.Pullback{…})(Δ::ODESolution{…})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [36] loss
    @ ~/Desktop/test.jl:84 [inlined]
 [37] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [38] (::Zygote.var"#75#76"{Zygote.Pullback{Tuple{…}, Tuple{…}}})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface.jl:91
 [39] gradient(f::Function, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface.jl:148
 [40] top-level scope
    @ ~/Desktop/test.jl:91
Some type information was truncated. Use `show(err)` to see complete types.

But I haven't been able to reduce it any further.

wsmoses commented 1 month ago

Unfortunately I don't think this is an error we can resolve here (depending on the type).

You can't update an immutable type, so doing an in-place update doesn't make sense.

ChrisRackauckas commented 1 month ago

I would assume that make_zero! on a function with no enclosed data would just be a no-op.

wsmoses commented 1 month ago

It could have enclosed data though.

E.g., make_zero!(2.9) or something similar.

On Sun, Jul 21, 2024 at 9:40 PM Christopher Rackauckas < @.***> wrote:

I would assume the behavior of make_zero! on a function with no enclosed data would just be a no op.

— Reply to this email directly, view it on GitHub https://github.com/EnzymeAD/Enzyme.jl/issues/1661#issuecomment-2241887847, or unsubscribe https://github.com/notifications/unsubscribe-auth/AAJTUXGNAFUZF2MBLMODGULZNRPKLAVCNFSM6AAAAABLHLDQ3KVHI2DSMVQWIX3LMV43OSLTON2WKQ3PNVWWK3TUHMZDENBRHA4DOOBUG4 . You are receiving this because you commented.Message ID: <EnzymeAD/Enzyme. @.***>

ChrisRackauckas commented 1 month ago

If the data cannot be mutated, then it can safely be skipped, though. The use case here is that make_zero(f) is used to make the shadow function, and then Duplicated(f, make_zero!(df)) is used to reset the caches each time, for safety of the shadow data. But if the caches cannot be mutated, then they don't need to be zeroed.

wsmoses commented 1 month ago

make_zero! internally uses a routine that indeed only updates the mutable parts. Presently, however, we define the semantics of make_zero! to zero all differentiable data and to throw an error otherwise (as it does here). If you were to pass df as a Duplicated shadow here, you would get the wrong answer if it weren't fully zeroed.

ChrisRackauckas commented 1 month ago

Okay, then we're missing a utility for generically handling functions correctly, since duplicated functions have this behavior where I want to set things to zero before reusing them, unless the values are not writable (because, of course, that means the last pass hasn't changed them).