ChrisRackauckas opened this issue 1 month ago
Int inference here appears to be resolved by: https://github.com/EnzymeAD/Enzyme.jl/pull/1575
```julia
using Enzyme, OrdinaryDiffEq, StaticArrays

# Mark solver bookkeeping (stats counters, logging, error checks) as inactive for Enzyme
Enzyme.EnzymeCore.EnzymeRules.inactive_type(::Type{SciMLBase.DEStats}) = true
Enzyme.EnzymeCore.EnzymeRules.inactive(::typeof(OrdinaryDiffEq.increment_nf!), args...) = true
Enzyme.EnzymeCore.EnzymeRules.inactive(::typeof(OrdinaryDiffEq.increment_nf_from_initdt!), args...) = true
Enzyme.EnzymeCore.EnzymeRules.inactive(::typeof(OrdinaryDiffEq.fixed_t_for_floatingpoint_error!), args...) = true
Enzyme.EnzymeCore.EnzymeRules.inactive(::typeof(OrdinaryDiffEq.increment_reject!), args...) = true
Enzyme.EnzymeCore.EnzymeRules.inactive(::typeof(DiffEqBase.fastpow), args...) = true
Enzyme.EnzymeCore.EnzymeRules.inactive(::typeof(OrdinaryDiffEq.increment_nf_perform_step!), args...) = true
Enzyme.EnzymeCore.EnzymeRules.inactive(::typeof(OrdinaryDiffEq.check_error!), args...) = true
Enzyme.EnzymeCore.EnzymeRules.inactive(::typeof(OrdinaryDiffEq.log_step!), args...) = true

# Lorenz system in in-place form
function lorenz!(du, u, p, t)
    du[1] = 10.0(u[2] - u[1])
    du[2] = u[1] * (28.0 - u[3]) - u[2]
    du[3] = u[1] * u[2] - (8 / 3) * u[3]
end

# 13 save points: 0.0, 0.25, ..., 3.0
const _saveat = SA[0.0, 0.25, 0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0, 2.25, 2.5, 2.75, 3.0]

# Solve from u0 and write the first state component at each save point into y
function f(y::Array{Float64}, u0::Array{Float64})
    tspan = (0.0, 3.0)
    prob = ODEProblem{true, SciMLBase.FullSpecialize}(lorenz!, u0, tspan)
    sol = DiffEqBase.solve(prob, Tsit5(), saveat = _saveat, sensealg = DiffEqBase.SensitivityADPassThrough())
    y .= sol[1, :]
    return nothing
end;

u0 = [1.0; 0.0; 0.0]
d_u0 = zeros(3)
y = zeros(13)
dy = zeros(13)

# Reverse-mode differentiation through the whole solve
Enzyme.autodiff(Reverse, f, Duplicated(y, dy), Duplicated(u0, d_u0));
```
This reproduces the error on Enzyme main with the PR branch https://github.com/SciML/OrdinaryDiffEq.jl/pull/2282. Inactivating the remaining `+= 1` operation

```julia
Enzyme.EnzymeCore.EnzymeRules.inactive(::typeof(OrdinaryDiffEq.increment_accept!), args...) = true
```

fixes the differentiation, which shows that it's just that one operation (well, plus the other instances of the same issue, but those are already set to be inactive).
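For readers less familiar with the workaround used above: an `EnzymeRules.inactive` method tells Enzyme to treat calls to that function as contributing no derivative. The toy sketch below is not OrdinaryDiffEq code; `ToyStats`, `bump_counter!`, and `g` are hypothetical names standing in for `DEStats`, `increment_accept!`, and the solver. It assumes a recent Enzyme release in which `autodiff(Reverse, ...)` returns the argument derivatives as the first element of its result.

```julia
using Enzyme

# Hypothetical stand-in for SciMLBase.DEStats: a mutable struct of Int counters
mutable struct ToyStats
    naccept::Int
end

# Hypothetical stand-in for OrdinaryDiffEq.increment_accept!: pure bookkeeping
function bump_counter!(s::ToyStats)
    s.naccept += 1
    return nothing
end

# Same pattern as the rules above: Enzyme should not differentiate this call
Enzyme.EnzymeCore.EnzymeRules.inactive(::typeof(bump_counter!), args...) = true

# Toy "solver": bookkeeping first, then the differentiable work
function g(x::Float64, s::ToyStats)
    bump_counter!(s)
    return x^2
end

stats = ToyStats(0)
# Reverse mode; stats is passed as Const, so it carries no shadow
dx = Enzyme.autodiff(Reverse, g, Active, Active(3.0), Const(stats))[1][1]
println(dx)  # derivative of x^2 at x = 3.0
```

The counter update still runs in the primal computation; the rule only removes it from the derivative computation, which is the behavior the rules in the MWE rely on.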
Should this be reopened?
Done [and slightly retitled this, since ironically it's the other way round: active is our conservative state]
MWE:
The core of the error is:
which points to this part of the LLVM:
Full text: enzyme_error_2.txt
It's pointing to:
What's weirder is that `integrator.stats` is a … and notice I had added

```julia
Enzyme.EnzymeCore.EnzymeRules.inactive_type(::Type{SciMLBase.DEStats}) = true
```

so, presumably, by my understanding `integrator.stats.nf += 1` should be inactive both because it is an integer and because the type is marked inactive. Julia also infers the type here, so I don't know what's going on in Enzyme 😅
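One way to narrow this down is to check whether the same pattern differentiates in isolation. The sketch below is hypothetical: `CounterStats`, `ToyIntegrator`, `toy_step!`, and `loss` are made-up names, not SciMLBase or OrdinaryDiffEq API. It declares a counter type inactive with `inactive_type`, as done for `DEStats`, and increments its `Int` field inside the differentiated function. If a toy like this differentiates cleanly while the solver does not, that would suggest the problem is specific to the OrdinaryDiffEq code path rather than to the `+= 1` itself.

```julia
using Enzyme

# Hypothetical stand-in for SciMLBase.DEStats (Int counters only)
mutable struct CounterStats
    nf::Int
end

# Mirror the DEStats rule: declare the whole type inactive
Enzyme.EnzymeCore.EnzymeRules.inactive_type(::Type{CounterStats}) = true

# Hypothetical integrator-like container: one differentiable state, one stats field
mutable struct ToyIntegrator
    u::Float64
    stats::CounterStats
end

function toy_step!(integ::ToyIntegrator)
    integ.stats.nf += 1           # the integer update in question
    integ.u = integ.u * integ.u   # the differentiable work
    return nothing
end

function loss(u0::Float64)
    integ = ToyIntegrator(u0, CounterStats(0))
    toy_step!(integ)
    return integ.u
end

# Reverse-mode derivative of u0 -> u0^2 at u0 = 3.0
grad = Enzyme.autodiff(Reverse, loss, Active, Active(3.0))[1][1]
println(grad)
```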