EnzymeAD / Enzyme.jl

Julia bindings for the Enzyme automatic differentiator
https://enzyme.mit.edu
MIT License

Autodiff for function that builds ODEProblem #1458

Closed. m-bossart closed this issue 1 month ago

m-bossart commented 1 month ago

I'd like to differentiate a function that builds and solves an ODEProblem internally. In the simplest example I can come up with, Enzyme gives nothing for the derivative. Am I making a mistake, or is this unexpected behavior? Minimal example:

using OrdinaryDiffEq
using SciMLSensitivity
using Enzyme
using Zygote

odef(du, u, p, t) = du .= u .* p
function f(u0p)
    prob = ODEProblem{true}(odef, u0p[1:1], (0.0, 1.0), u0p[2:2])
    sum(solve(prob, Tsit5(), abstol = 1e-12, reltol = 1e-12, saveat = 0.1))
end
u0p = [2.0, 3.0]
du0p = zeros(2)

Zygote.gradient(f, u0p ./ 2)[1] #Zygote works 
##
Enzyme.autodiff(Reverse, f, Active, Duplicated(u0p, du0p))  #Enzyme gives nothing 
Enzyme.autodiff(Forward, f, Duplicated(u0p, du0p))          #Enzyme warns: You may be using a constant variable as temporary storage for active memory

m-bossart commented 1 month ago

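My mistake. I was checking the return value when I should have been checking du0p: in reverse mode Enzyme accumulates the gradient into the shadow array of a Duplicated argument and returns nothing in that argument's slot. (I'm new to Enzyme and used to the Zygote API, where the gradient is the return value.)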
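For anyone landing here with the same confusion, here is a minimal sketch of where the gradient ends up. It leaves out the ODE solve and uses a hypothetical stand-in function `g` (not from the original example), so it only illustrates the calling convention, not the SciMLSensitivity interaction. The exact shape of the returned tuple may differ between Enzyme versions.

```julia
using Enzyme

# Hypothetical stand-in for f: a plain scalar-valued function of a vector.
function g(x)
    s = 0.0
    for xi in x
        s += xi^2
    end
    return s
end

x  = [2.0, 3.0]
dx = zeros(2)   # shadow array; zero it before each reverse-mode call

# Reverse mode: the gradient is accumulated into dx, not returned.
ret = Enzyme.autodiff(Reverse, g, Active, Duplicated(x, dx))

println(ret)    # the slot for the Duplicated argument holds `nothing`
println(dx)     # gradient of sum(x.^2) is 2x, so this should be [4.0, 6.0]

# Forward mode: the shadow is an input tangent (seed) instead.
# A zero seed yields a zero derivative, so seed the direction of interest.
seed = [1.0, 0.0]
fwd = Enzyme.autodiff(Forward, g, Duplicated(x, seed))
println(fwd)    # contains the directional derivative along `seed`
```

Because reverse mode adds into the shadow rather than overwriting it, calling `autodiff` twice without resetting `dx` to zero gives twice the gradient.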