FluxML / Zygote.jl

21st century AD
https://fluxml.ai/Zygote.jl/

Zygote DiffEq Integration (Pun Intended) #37

Open ChrisRackauckas opened 5 years ago

ChrisRackauckas commented 5 years ago

MWE:

using OrdinaryDiffEq, ParameterizedFunctions, ForwardDiff

f = @ode_def begin
  dx = a*x - b*x*y
  dy = -c*y + x*y
end a b c

p = [1.5,1.0,3.0]
prob = ODEProblem(f,[1.0;1.0],(0.0,10.0),p)
t = 0.0:0.5:10.0

function G(p)
  tmp_prob = remake(prob,u0=convert.(eltype(p),prob.u0),p=p)
  sol = solve(tmp_prob,Vern9(),abstol=1e-14,reltol=1e-14,saveat=t)
  A = convert(Array,sol)
  sum(((1 .- A).^2)./2)
end
G([1.5,1.0,3.0])
res2 = ForwardDiff.gradient(G,[1.5,1.0,3.0])

using Zygote
Zygote.derivative(G,[1.5,1.0,3.0])
Krastanov commented 5 years ago

Does this still work?

I removed the ForwardDiff parts for this test.

On a brand-new install of Julia 1.0.3, after running add DifferentialEquations Zygote in the Pkg REPL, calling Zygote.derivative simply crashes. There is no error message: the notebook kernel dies without explanation, and in the terminal I get a segfault.

With add DifferentialEquations#master Zygote#master (the master branches of both packages), I instead get an error about dual numbers:

MethodError: no method matching Float64(::ForwardDiff.Dual{Nothing,Float64,2})
  ...
Stacktrace:
 [1] convert(::Type{Float64}, ::ForwardDiff.Dual{Nothing,Float64,2}) at ./number.jl:7
 [2] (::getfield(Zygote, Symbol("##819#822")){typeof(convert)})(::Type, ::Float64) at /home/stefan/.julia/packages/Zygote/Ohw1K/src/lib/broadcast.jl:113
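The MethodError above is the general failure mode whenever a dual number has to become a plain Float64 (here apparently inside Zygote's broadcast rule for convert). Float64 has no method that accepts a dual, because converting would silently drop the derivative. The sketch below illustrates this with a deliberately tiny, hypothetical MyDual type (not ForwardDiff's actual Dual) and a hypothetical euler_step! helper that mutates the state in place. It is an analogy for the failure, not a diagnosis of the exact call site. It also shows why G promotes the initial condition with convert.(eltype(p),prob.u0) before solving.

```julia
# Hypothetical, minimal forward-mode dual number (NOT ForwardDiff's implementation):
# `val` is the value and `der` the derivative with respect to one seeded input.
struct MyDual <: Real
    val::Float64
    der::Float64
end

# Only the arithmetic needed for one explicit Euler step u + dt*(a*u)
Base.:+(x::MyDual, y::MyDual) = MyDual(x.val + y.val, x.der + y.der)
Base.:+(x::Float64, y::MyDual) = MyDual(x + y.val, y.der)
Base.:*(x::MyDual, y::MyDual) = MyDual(x.val * y.val, x.val * y.der + x.der * y.val)
Base.:*(x::Float64, y::MyDual) = MyDual(x * y.val, x * y.der)
Base.:*(x::MyDual, y::Float64) = MyDual(x.val * y, x.der * y)
# Lift a plain Float64 into a dual with zero derivative
Base.convert(::Type{MyDual}, x::Float64) = MyDual(x, 0.0)

# One Euler step of du/dt = a*u, mutating the state vector in place (hypothetical helper)
function euler_step!(u, a, dt)
    u[1] = u[1] + dt * (a * u[1])
    return u
end

a = MyDual(1.5, 1.0)   # parameter a = 1.5, seeded with da/da = 1
dt = 0.01

# Float64 storage cannot hold the dual result: setindex! calls convert(Float64, ::MyDual),
# which falls back to Float64(::MyDual) and throws a MethodError.
u_float = [1.0]
failed = try
    euler_step!(u_float, a, dt)
    false
catch err
    err isa MethodError
end

# Promoting the state to the parameter's element type first lets the derivative
# flow through the step: u(dt) = 1 + a*dt, so du/da = dt.
u_dual = convert.(MyDual, [1.0])
euler_step!(u_dual, a, dt)
```

After the promoted step, u_dual[1] holds a value of about 1.015 and a derivative of about 0.01, while the Float64 version fails in the same way as the stack trace.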

I also tried a much simpler example, which fails as well:

f(x,p,t) = p*x
t = 0:0.1:1
function sol(a)
    pr = ODEProblem(f,1.,(0.,1.),a)
    s = solve(pr,Euler(),dt=0.01, saveat=t)
    s.u
end
Zygote.gradient(a->sum(sol(a)),1.)
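As a reference value for whatever Zygote eventually returns here, the gradient of this toy problem can be worked out by hand. The sketch below differentiates each explicit Euler update with respect to a (a discrete forward sensitivity). It assumes u(0) = 1, a fixed step dt = 0.01, and that the saveat points t = 0:0.1:1 coincide exactly with every 10th Euler step, including the initial condition at t = 0. The function name euler_sum_and_grad is hypothetical and not part of DifferentialEquations.jl.

```julia
# Hand-derived forward sensitivity for du/dt = a*u, u(0) = 1, integrated with
# fixed-step explicit Euler. Returns (sum of saved states, d(sum)/da).
# Assumes saves land exactly on every `save_every`-th step, plus t = 0.
function euler_sum_and_grad(a; dt = 0.01, nsteps = 100, save_every = 10)
    u = 1.0       # state u_n
    s = 0.0       # sensitivity s_n = du_n/da; zero because u(0) does not depend on a
    total = u     # running sum of saved states (t = 0 is saved)
    dtotal = s    # running sum of saved sensitivities
    for n in 1:nsteps
        # Differentiate u_{n+1} = u_n + dt*a*u_n with respect to a.
        # Update s first, since it needs the old u_n.
        s = s + dt * (a * s + u)
        u = u + dt * a * u
        if n % save_every == 0
            total += u
            dtotal += s
        end
    end
    return total, dtotal
end
```

Since u_n = (1 + a*dt)^n for this scheme, the second return value should equal the sum of n*dt*(1 + a*dt)^(n-1) over the saved steps. A central finite difference of the first return value gives a second cheap cross-check against Zygote.gradient(a->sum(sol(a)),1.).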

With add DifferentialEquations#master Zygote#master, I get:

Compiling Tuple{getfield(DiffEqBase, Symbol("##solve#442")),Base.Iterators.Pairs{Symbol,Any,Tuple{Symbol,Symbol},NamedTuple{(:dt, :saveat),Tuple{Float64,StepRangeLen{Float64,Base.TwicePrecision{Float64},Base.TwicePrecision{Float64}}}}},typeof(solve),ODEProblem{Float64,Tuple{Float64,Float64},false,Float64,ODEFunction{false,typeof(f),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Nothing,DiffEqBase.StandardODEProblem},Euler}: BoundsError: attempt to access 0-element UnitRange{Int64} at index [1]

Stacktrace:
 [1] throw_boundserror(::UnitRange{Int64}, ::Int64) at ./abstractarray.jl:484
 [2] getindex at ./range.jl:597 [inlined]
 [3] getindex at /home/stefan/.julia/packages/IRTools/Y9ACs/src/ir/wrap.jl:23 [inlined]
 [4] first at ./abstractarray.jl:270 [inlined]
Krastanov commented 5 years ago

I found some debug information in dmesg, of all places.

The release version of Zygote crashes with this segfault [111742.546675] julia[28661]: segfault at 10 ip 00007fe630caedc1 sp 00007ffefa6c5440 error 4 in libjulia.so.1.0[7fe630bd7000+236000].

The master branch of Zygote just reports the errors mentioned above.

ChrisRackauckas commented 5 years ago

No, it never worked. The MWE is a minimal piece of code that reproduces the error, to make debugging easier. The purpose of this issue is to figure out how to make it work. Sorry if that wasn't clear.

ChrisRackauckas commented 3 years ago

Updated MWE:

using DiffEqSensitivity, OrdinaryDiffEq, Zygote

function fiip(du,u,p,t)
  # in-place form: write the derivatives into du
  du[1] = p[1]*u[1] - p[2]*u[1]*u[2]
  du[2] = -p[3]*u[2] + p[4]*u[1]*u[2]
  return nothing
end
function foop(u,p,t)
  dx = p[1]*u[1] - p[2]*u[1]*u[2]
  dy = -p[3]*u[2] + p[4]*u[1]*u[2]
  [dx,dy]
end

p = [1.5,1.0,3.0,1.0]; u0 = [1.0;1.0]
prob = ODEProblem(fiip,u0,(0.0,10.0),p)
proboop = ODEProblem(foop,u0,(0.0,10.0),p)

Zygote.gradient((u0,p)->sum(solve(proboop,Tsit5(),u0=u0,p=p,abstol=1e-14,reltol=1e-14,saveat=0.1,sensealg=ZygoteAdjoint())),u0,p)
Zygote.gradient((u0,p)->sum(solve(proboop,Tsit5(),u0=u0,p=p,abstol=1e-14,reltol=1e-14,saveat=0.1,sensealg=SensitivityADPassThrough())),u0,p)

# Harder! The same gradients through the in-place problem prob, whose fiip mutates du
Zygote.gradient((u0,p)->sum(solve(prob,Tsit5(),u0=u0,p=p,abstol=1e-14,reltol=1e-14,saveat=0.1,sensealg=ZygoteAdjoint())),u0,p)
Zygote.gradient((u0,p)->sum(solve(prob,Tsit5(),u0=u0,p=p,abstol=1e-14,reltol=1e-14,saveat=0.1,sensealg=SensitivityADPassThrough())),u0,p)