JuliaDiff / ReverseDiff.jl

Reverse Mode Automatic Differentiation for Julia

TrackedReal vs. TrackedArray in ReverseDiffAdjoint() for 1d oop ODEProblems #178

Open frankschae opened 3 years ago

frankschae commented 3 years ago

`ReverseDiffAdjoint()` works fine for out-of-place (oop) `ODEProblem`s whose state vector has at least two dimensions:

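```julia
using OrdinaryDiffEq, DiffEqSensitivity, Zygote  # DiffEqSensitivity is now named SciMLSensitivity

tspan = (0.0,10.0)
p = [1.0,0.0]
function f(u,p,t)
  dx = p[1]*u[1] + p[2]
  dy = 0*u[2]
  [dx,dy]
end
u0 = [0.2,0.0]
proboop = ODEProblem{false}(f,u0,tspan,p)
# gradient of the summed solution with respect to u0 and p
Zygote.gradient((u0,p)->sum(solve(proboop,Tsit5(),u0=u0,p=p,saveat=0.1,sensealg=ReverseDiffAdjoint())),u0,p)
```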

When the same syntax is used to simulate a one-dimensional state, e.g.

```julia
# failing example (reuses tspan and p from the example above)
function f(u,p,t)
  dx = p[1]*u[1] + p[2]
  [dx]
end
u0 = [0.2]
proboop = ODEProblem{false}(f,u0,tspan,p)
Zygote.gradient((u0,p)->sum(solve(proboop,Tsit5(),u0=u0,p=p,saveat=0.1,sensealg=ReverseDiffAdjoint())),u0,p)
```

a MethodError is thrown:

```
ERROR: MethodError: Cannot `convert` an object of type ReverseDiff.TrackedReal{Float64, Float64, Nothing} to an object of type ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}
Closest candidates are:
  convert(::Type{T}, ::T) where T<:ReverseDiff.TrackedArray at /Users/frank/.julia/packages/ReverseDiff/iHmB4/src/tracked.jl:270
  convert(::Type{T}, ::LinearAlgebra.Factorization) where T<:AbstractArray at /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.6/LinearAlgebra/src/factorization.jl:58
  convert(::Type{T}, ::T) where T<:AbstractArray at abstractarray.jl:14
  ...
Stacktrace:
  [1] setproperty!
...
```

@ChrisRackauckas traced it down to:

```
typeof(reduce(vcat, [dx])) = ReverseDiff.TrackedReal{Float64, Float64, Nothing}
typeof(reduce(vcat, [dx, dy])) = ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}
```

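`reduce(vcat, ...)` is used for the array of structs -> struct of arrays conversion, so with a one-dimensional state the conversion yields a `TrackedReal` where a `TrackedArray` is expected, which triggers the `convert` error above.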
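The behaviour traced above matches how `reduce` works in Base Julia: for a one-element collection it returns that element without ever calling the operator, so `reduce(vcat, [x])` gives back `x` rather than `[x]`. The minimal sketch below reproduces this with plain `Float64` values standing in for `ReverseDiff.TrackedReal`. This assumes the single-element shortcut does not depend on the element type, which is consistent with the `typeof` output above. It also shows two hypothetical ways to always get a vector back. Whether either one produces the `TrackedArray` that ReverseDiff expects has not been checked here.

```julia
# Minimal sketch: plain Float64 values stand in for ReverseDiff.TrackedReal
# (assumption: reduce's single-element shortcut does not depend on eltype).

one_state = [0.2]        # 1-d state, as in the failing example
two_state = [0.2, 0.0]   # 2-d state, as in the working example

# One element: reduce returns it unchanged and never calls vcat.
r1 = reduce(vcat, one_state)
# Two elements: vcat is called and builds a vector.
r2 = reduce(vcat, two_state)

println(typeof(r1))  # Float64
println(typeof(r2))  # Vector{Float64}

# Hypothetical workarounds that always return a vector:
# seed the reduction with an empty vector of the element type...
s1 = reduce(vcat, one_state; init = eltype(one_state)[])
# ...or splat all elements into a single vcat call.
s2 = vcat(one_state...)

println(typeof(s1))  # Vector{Float64}
println(typeof(s2))  # Vector{Float64}
```

Seeding with `init` forces at least one `vcat` call, while splatting hands every element to a single `vcat`. Both keep the one-element and n-element cases on the same code path.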