EnzymeAD / Enzyme.jl

Julia bindings for the Enzyme automatic differentiator
https://enzyme.mit.edu
MIT License

autodiff fails when solving ODE if ODEFunction is used #1460

Closed: m-bossart closed this issue 4 months ago

m-bossart commented 4 months ago

This example of differentiating a function that builds and solves an ODEProblem works:

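using OrdinaryDiffEq, SciMLSensitivity, Zygote, Enzyme, Test

# In-place ODE right-hand side: du/dt = u * p
# (f(du, u, p, t) and f(u0p) are two methods of the same generic function f)
f(du, u, p, t) = du .= u .* p
# Scalar loss: u0 = u0p[1], p = u0p[2]; solve, then sum the saved states
function f(u0p)
    prob = ODEProblem{true}(f, u0p[1:1], (0.0, 1.0), u0p[2:2])
    sum(solve(prob, Rodas4(), abstol = 1e-12, reltol = 1e-12, saveat = 0.1))
end
u0p = [2.0, 3.0]
du0p = zeros(2)                            # shadow buffer; Enzyme accumulates the gradient here
f(u0p)                                     # primal evaluation
du0p_zygote = Zygote.gradient(f, u0p)[1]   # reference gradient from Zygote
Enzyme.autodiff(Reverse, f, Active, Duplicated(u0p, du0p))
@test isapprox(du0p_zygote, du0p)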
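Because du/dt = p·u has the closed-form solution u(t) = u0·exp(p·t), the gradient that both Zygote and Enzyme should return can be cross-checked by hand. The sketch below is plain Julia with no packages. It assumes that saveat = 0.1 on (0.0, 1.0) saves the 11 points t = 0.0, 0.1, ..., 1.0 (endpoints included); the names ts, analytic_sum, d_du0, d_dp and ref_grad are hypothetical helpers for illustration.

```julia
# Assumed save times: saveat = 0.1 on (0.0, 1.0), endpoints included (11 points)
ts = range(0.0, 1.0; length = 11)

# Closed-form value of the loss: sum of u(t) = u0 * exp(p * t) over the save times
analytic_sum(u0, p) = sum(u0 * exp(p * t) for t in ts)

# Partial derivatives of that sum with respect to u0 and p
d_du0(u0, p) = sum(exp(p * t) for t in ts)
d_dp(u0, p) = sum(u0 * t * exp(p * t) for t in ts)

u0, p = 2.0, 3.0
ref_grad = [d_du0(u0, p), d_dp(u0, p)]
```

With abstol and reltol at 1e-12, du0p_zygote and the working Enzyme result du0p should be close to ref_grad, up to the solver's integration error.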

If I instead build an ODEFunction from f (with no additional inputs), autodiff fails:

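using OrdinaryDiffEq, SciMLSensitivity, Zygote, Enzyme, Test

# In-place ODE right-hand side: du/dt = u * p
f(du, u, p, t) = du .= u .* p
function f(u0p)
    odef = ODEFunction{true}(f)    # wrapping the RHS in an ODEFunction is the only change
    prob = ODEProblem{true}(odef, u0p[1:1], (0.0, 1.0), u0p[2:2])
    sum(solve(prob, Rodas4(), abstol = 1e-12, reltol = 1e-12, saveat = 0.1))
end
u0p = [2.0, 3.0]
du0p = zeros(2)                            # shadow buffer for the Enzyme gradient
f(u0p)                                     # primal evaluation still works
du0p_zygote = Zygote.gradient(f, u0p)[1]
Enzyme.autodiff(Reverse, f, Active, Duplicated(u0p, du0p))   # fails: "Enzyme cannot deduce type"
@test isapprox(du0p_zygote, du0p)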

The start of the error shows that Enzyme cannot deduce a type once the ODEFunction wrapper is added:

ERROR: Enzyme execution failed.
Enzyme cannot deduce type
Current scope:
; Function Attrs: mustprogress willreturn
...

This issue is similar to #1459 in that both try to cover more ways of solving ODEs within the SciML ecosystem. Please let me know if these issues belong in SciMLSensitivity instead.

m-bossart commented 4 months ago

As is probably expected, enabling looseTypeAnalysis!(true) makes this work; however, increasing maxtypedepth or maxtypeoffset did not.
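A minimal sketch of that workaround, assuming the Enzyme.API.looseTypeAnalysis! setter. The flag is global to the session, so the sketch turns it back off afterwards, assuming false is the default. rhs! and loss are hypothetical renames of the two f methods from the failing example, used only to keep the snippet self-contained.

```julia
using OrdinaryDiffEq, SciMLSensitivity, Enzyme

# Hypothetical renames of the two `f` methods from the failing example
rhs!(du, u, p, t) = du .= u .* p
function loss(u0p)
    odef = ODEFunction{true}(rhs!)
    prob = ODEProblem{true}(odef, u0p[1:1], (0.0, 1.0), u0p[2:2])
    sum(solve(prob, Rodas4(), abstol = 1e-12, reltol = 1e-12, saveat = 0.1))
end

# Relax Enzyme's type analysis before the first autodiff call on `loss`
Enzyme.API.looseTypeAnalysis!(true)

u0p = [2.0, 3.0]
du0p = zeros(2)  # shadow buffer receiving the gradient
Enzyme.autodiff(Reverse, loss, Active, Duplicated(u0p, du0p))

# Assumption: `false` is the default; restore it so later code is unaffected
Enzyme.API.looseTypeAnalysis!(false)
```

As we understand it, loose type analysis lets Enzyme fall back to a guess when it cannot deduce a type, so it works around the type-deduction failure rather than fixing it and is better treated as a stopgap.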

m-bossart commented 4 months ago

This can be mitigated by explicitly specifying the specialization level when building the ODEFunction (e.g. ODEFunction{true, SciMLBase.AutoSpecialize}).
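The failing example with the specialization level stated explicitly might look like the sketch below. It assumes SciMLBase is installed in the active environment so that SciMLBase.AutoSpecialize can be referenced; rhs! and loss_specialized are hypothetical names.

```julia
using OrdinaryDiffEq, SciMLSensitivity, SciMLBase, Enzyme

# In-place ODE right-hand side: du/dt = u * p
rhs!(du, u, p, t) = du .= u .* p

function loss_specialized(u0p)
    # State the specialization level explicitly instead of relying on the default
    odef = ODEFunction{true, SciMLBase.AutoSpecialize}(rhs!)
    prob = ODEProblem{true}(odef, u0p[1:1], (0.0, 1.0), u0p[2:2])
    sum(solve(prob, Rodas4(), abstol = 1e-12, reltol = 1e-12, saveat = 0.1))
end

u0p = [2.0, 3.0]
du0p = zeros(2)  # shadow buffer receiving the gradient
Enzyme.autodiff(Reverse, loss_specialized, Active, Duplicated(u0p, du0p))
```

The only change from the failing reproducer is the second type parameter on ODEFunction.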