EnzymeAD / Enzyme.jl

Julia bindings for the Enzyme automatic differentiator
https://enzyme.mit.edu
MIT License

Operation on Int not deduced as const #1636

Open ChrisRackauckas opened 1 month ago

ChrisRackauckas commented 1 month ago

MWE:

using Enzyme, OrdinaryDiffEq, StaticArrays

Enzyme.EnzymeCore.EnzymeRules.inactive_type(::Type{SciMLBase.DEStats}) = true

function lorenz!(du, u, p, t)
    du[1] = 10.0(u[2] - u[1])
    du[2] = u[1] * (28.0 - u[3]) - u[2]
    du[3] = u[1] * u[2] - (8 / 3) * u[3]
end

const _saveat =  SA[0.0,0.25,0.5,0.75,1.0,1.25,1.5,1.75,2.0,2.25,2.5,2.75,3.0]

function f(y::Array{Float64}, u0::Array{Float64})
    tspan = (0.0, 3.0)
    prob = ODEProblem{true, SciMLBase.FullSpecialize}(lorenz!, u0, tspan)
    sol = DiffEqBase.solve(prob, Tsit5(), adaptive=false, dt = 0.1, saveat = _saveat, sensealg = DiffEqBase.SensitivityADPassThrough())
    @show sol[1,:], sol.retcode
    y .= sol[1,:]
    return nothing
end;
u0 = [1.0; 0.0; 0.0]
d_u0 = zeros(3)
y  = zeros(13)
dy = zeros(13)
Enzyme.autodiff(Reverse, f,  Duplicated(y, dy), Duplicated(u0, d_u0));

The core of the error is:

cannot handle unknown binary operator:   %207 = add i64 %206, 1, !dbg !564

which points to this part of the LLVM IR:

 constantinst[  %204 = getelementptr inbounds i8, i8 addrspace(11)* %203, i64 80, !dbg !563] = 1 val:0 type: {[-1]:Pointer}
 constantinst[  %205 = bitcast i8 addrspace(11)* %204 to i64 addrspace(11)*, !dbg !563] = 1 val:0 type: {[-1]:Pointer}
 constantinst[  %206 = load i64, i64 addrspace(11)* %205, align 8, !dbg !563, !tbaa !202, !alias.scope !206, !noalias !224] = 0 val:0 type: {}
 constantinst[  %207 = add i64 %206, 1, !dbg !564] = 0 val:0 type: {}
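To see that the failing add i64 %206, 1 is an ordinary integer += 1, you can inspect the IR Julia emits for a toy version of the counter update. This is a sketch under stated assumptions: CounterMock and bump! are hypothetical names, and the exact IR text varies across Julia versions, so only the load i64 / add i64 ..., 1 pattern matters. In the dump above, the type: {} on %206 and %207 suggests Enzyme's type analysis assigned those values no type, even though they are Int at the Julia level.

```julia
using InteractiveUtils  # provides code_llvm outside the REPL

# Hypothetical one-field stand-in for SciMLBase.DEStats.
mutable struct CounterMock
    nf::Int
end

# Same operation as `integrator.stats.nf += 1` in initialize!.
bump!(c::CounterMock) = (c.nf += 1; nothing)

# Capture the optimized LLVM IR as a string (debug metadata stripped).
ir = sprint(io -> code_llvm(io, bump!, (CounterMock,); debuginfo=:none))
print(ir)  # expect a `load i64`, an `add i64 ..., 1`, and a store back
```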

Full text: enzyme_error_2.txt

It's pointing to:

function initialize!(integrator, cache::Tsit5Cache)
    integrator.kshortsize = 7
    integrator.fsalfirst = cache.k1
    integrator.fsallast = cache.k7 # setup pointers
    resize!(integrator.k, integrator.kshortsize)
    # Setup k pointers
    integrator.k[1] = cache.k1
    integrator.k[2] = cache.k2
    integrator.k[3] = cache.k3
    integrator.k[4] = cache.k4
    integrator.k[5] = cache.k5
    integrator.k[6] = cache.k6
    integrator.k[7] = cache.k7
    integrator.f(integrator.fsalfirst, integrator.uprev, integrator.p, integrator.t) # Pre-start fsal
    integrator.stats.nf += 1 # points right here
    return nothing
end

What's weirder is that integrator.stats is a

mutable struct DEStats
    nf::Int
    nf2::Int
    nw::Int
    nsolve::Int
    njacs::Int
    nnonliniter::Int
    nnonlinconvfail::Int
    nfpiter::Int
    nfpconvfail::Int
    ncondition::Int
    naccept::Int
    nreject::Int
    maxeig::Float64
end

and notice I had added Enzyme.EnzymeCore.EnzymeRules.inactive_type(::Type{SciMLBase.DEStats}) = true, so by my understanding integrator.stats.nf += 1 should be inactive both because it is an integer and because the type is marked inactive. And Julia infers the type here, so I don't know what's going on in Enzyme 😅
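The claim that Julia infers the type here can be checked on a toy version of the counter update. This sketch uses only Base; StatsMock and bump_nf! are hypothetical names. It checks Julia-level inference only, which is not the same thing as Enzyme's LLVM-level type analysis (where the LLVM dump above shows type: {} for the loaded value and the add).

```julia
# Hypothetical stand-in for SciMLBase.DEStats, trimmed to two fields.
mutable struct StatsMock
    nf::Int
    maxeig::Float64
end

# Same shape as `integrator.stats.nf += 1` in initialize!;
# the assignment expression evaluates to the new Int value.
bump_nf!(stats::StatsMock) = (stats.nf += 1)

stats = StatsMock(0, 0.0)
bump_nf!(stats)

# The field is declared Int, and inference sees an Int result.
println(fieldtype(StatsMock, :nf))
println(Base.return_types(bump_nf!, (StatsMock,)))
```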

wsmoses commented 1 month ago

Int inference here appears to be resolved by: https://github.com/EnzymeAD/Enzyme.jl/pull/1575

ChrisRackauckas commented 1 month ago
using Enzyme, OrdinaryDiffEq, StaticArrays

Enzyme.EnzymeCore.EnzymeRules.inactive_type(::Type{SciMLBase.DEStats}) = true
Enzyme.EnzymeCore.EnzymeRules.inactive(::typeof(OrdinaryDiffEq.increment_nf!), args...) = true
Enzyme.EnzymeCore.EnzymeRules.inactive(::typeof(OrdinaryDiffEq.increment_nf_from_initdt!), args...) = true
Enzyme.EnzymeCore.EnzymeRules.inactive(::typeof(OrdinaryDiffEq.fixed_t_for_floatingpoint_error!), args...) = true
Enzyme.EnzymeCore.EnzymeRules.inactive(::typeof(OrdinaryDiffEq.increment_reject!), args...) = true
Enzyme.EnzymeCore.EnzymeRules.inactive(::typeof(DiffEqBase.fastpow), args...) = true
Enzyme.EnzymeCore.EnzymeRules.inactive(::typeof(OrdinaryDiffEq.increment_nf_perform_step!), args...) = true
Enzyme.EnzymeCore.EnzymeRules.inactive(::typeof(OrdinaryDiffEq.check_error!), args...) = true
Enzyme.EnzymeCore.EnzymeRules.inactive(::typeof(OrdinaryDiffEq.log_step!), args...) = true

function lorenz!(du, u, p, t)
    du[1] = 10.0(u[2] - u[1])
    du[2] = u[1] * (28.0 - u[3]) - u[2]
    du[3] = u[1] * u[2] - (8 / 3) * u[3]
end

const _saveat =  SA[0.0,0.25,0.5,0.75,1.0,1.25,1.5,1.75,2.0,2.25,2.5,2.75,3.0]

function f(y::Array{Float64}, u0::Array{Float64})
    tspan = (0.0, 3.0)
    prob = ODEProblem{true, SciMLBase.FullSpecialize}(lorenz!, u0, tspan)
    sol = DiffEqBase.solve(prob, Tsit5(), saveat = _saveat, sensealg = DiffEqBase.SensitivityADPassThrough())
    y .= sol[1,:]
    return nothing
end;
u0 = [1.0; 0.0; 0.0]
d_u0 = zeros(3)
y  = zeros(13)
dy = zeros(13)

Enzyme.autodiff(Reverse, f,  Duplicated(y, dy), Duplicated(u0, d_u0));

reproduces the error on Enzyme main with the PR branch https://github.com/SciML/OrdinaryDiffEq.jl/pull/2282. Inactivating the remaining += 1 operation

Enzyme.EnzymeCore.EnzymeRules.inactive(::typeof(OrdinaryDiffEq.increment_accept!), args...) = true

fixes the differentiation, so this shows that it's just that one operation (more precisely, other instances of the same issue, but those are already marked inactive above).
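The workaround in this comment follows one pattern: move each counter update into its own small function and declare that function inactive with EnzymeRules.inactive. Below is a minimal sketch of that pattern on a toy problem. CounterMock, bump_nf! and square_and_count! are hypothetical names, not OrdinaryDiffEq functions, and the toy is not expected to reproduce the original failure; it only shows the shape of the fix.

```julia
using Enzyme

# Hypothetical stand-in for SciMLBase.DEStats (one counter field only).
mutable struct CounterMock
    nf::Int
end

# Hypothetical helper that isolates the integer bookkeeping, in the spirit
# of OrdinaryDiffEq.increment_nf! / increment_accept! from the thread.
function bump_nf!(c::CounterMock)
    c.nf += 1
    return nothing
end

# Declare the helper inactive: Enzyme treats calls to it as carrying no derivative.
Enzyme.EnzymeCore.EnzymeRules.inactive(::typeof(bump_nf!), args...) = true

# Toy "solver step": writes x^2 into y and bumps the counter.
function square_and_count!(y, x, c)
    y[1] = x[1] * x[1]
    bump_nf!(c)
    return nothing
end

x, dx = [3.0], [0.0]
y, dy = [0.0], [1.0]   # seed the reverse pass with dL/dy = 1
c = CounterMock(0)

Enzyme.autodiff(Reverse, square_and_count!, Const,
                Duplicated(y, dy), Duplicated(x, dx), Const(c))
println(dx[1])  # d(x^2)/dx evaluated at x = 3
```

The counter struct is passed as Const because it never carries derivatives. Marking the helper inactive additionally guarantees that Enzyme does not try to differentiate the integer add inside it.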

ChrisRackauckas commented 1 month ago

Should this be reopened?

wsmoses commented 1 month ago

Done [and slightly retitled this, since ironically it's the other way round: active is our conservative state]