SciML / SciMLBase.jl

The Base interface of the SciML ecosystem
https://docs.sciml.ai/SciMLBase/stable
MIT License

Updating SciML to ChainRules #69

Open ChrisRackauckas opened 3 years ago

ChrisRackauckas commented 3 years ago

The list of things to handle is:

The solve.jl one is a bit nasty: it calls all of the adjoints in DiffEqSensitivity.jl, so moving it is somewhat breaking. So we can... just do it really quickly.

@frankschae @YingboMa

ChrisRackauckas commented 3 years ago

@oxinabox for comfort.

ChrisRackauckas commented 3 years ago

https://github.com/SciML/DiffEqBase.jl/pull/674 https://github.com/SciML/Quadrature.jl/pull/66

ChrisRackauckas commented 3 years ago

https://github.com/SciML/RecursiveArrayTools.jl/pull/144 https://github.com/SciML/DiffEqGPU.jl/pull/105

ChrisRackauckas commented 3 years ago

The current blocking issues:

ChrisRackauckas commented 3 years ago

https://github.com/SciML/DiffEqGPU.jl/blob/master/src/DiffEqGPU.jl#L160-L162 doesn't get called, so https://github.com/SciML/DiffEqGPU.jl/blob/master/src/DiffEqGPU.jl#L164-L166 is required.

ChrisRackauckas commented 3 years ago

ZygoteRules.@adjoint function VectorOfArray(u) -> function ChainRulesCore.rrule(::Type{<:VectorOfArray},u)

Downstream testing revealed that Zygote actually ignored these rules 🤦, so those changes need to be reverted.

ChrisRackauckas commented 3 years ago

The end result is that, to pass tests, every package needed to keep a few ZygoteRules adjoints, so it doesn't seem possible at this time to use ChainRules exclusively. This should be double-checked in the future (with Diffractor).

ChrisRackauckas commented 3 years ago
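
using OrdinaryDiffEq, DiffEqSensitivity, DiffEqFlux, LinearAlgebra, Flux

# Small neural network whose parameters `p` drive the ODE right-hand side
nn = FastChain(FastDense(1,16),FastDense(16,16,tanh),FastDense(16,2))
initial = initial_params(nn)

# Out-of-place RHS (3 arguments, returns du) despite the `!` in its name
function ode2!(u, p, t)
    f1, f2 = nn([t],p)
    [-f1^2; f2]
end

tspan = (0.0, 10.0)
# Complex-valued initial condition
prob = ODEProblem(ode2!, Complex{Float64}[0;0], tspan, initial)

# Loss: norm of the final state, differentiated with BacksolveAdjoint
# using Zygote-computed vector-Jacobian products
function loss(p)
  sol = last(solve(prob, Tsit5(), p=p, sensealg=BacksolveAdjoint(autojacvec=ZygoteVJP())))
  return norm(sol)
end

result_ode = DiffEqFlux.sciml_train(loss, initial, ADAM(0.1), maxiters = 100)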

showed that Zygote skips

function ChainRulesCore.rrule(::typeof(getindex),VA::AbstractVectorOfArray, i)
  function AbstractVectorOfArray_getindex_adjoint(Δ)
    Δ′ = [ (i == j ? Δ : zero(x)) for (x,j) in zip(VA.u, 1:length(VA))]
    (NoTangent(),Δ′,NoTangent())
  end
  VA[i],AbstractVectorOfArray_getindex_adjoint
end

function ChainRulesCore.rrule(::typeof(getindex),VA::AbstractVectorOfArray, i, j...)
  function AbstractVectorOfArray_getindex_adjoint(Δ)
    Δ′ = zero(VA)
    Δ′[i,j...] = Δ
    # integer indices are not differentiable, so `i` gets NoTangent() like `j...`
    (NoTangent(), Δ′, NoTangent(), map(_ -> NoTangent(), j)...)
  end
  VA[i,j...],AbstractVectorOfArray_getindex_adjoint
end

because it has its own getindex overloads.
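For reference, the cotangents that those two pullbacks construct can be sketched in plain Julia with no AD package. This is a minimal sketch, assuming a plain `Vector` of arrays and a dense `Array` in place of `AbstractVectorOfArray`; the helper names `onehot_cotangent` and `dense_getindex_cotangent` are hypothetical and are not part of RecursiveArrayTools or ChainRulesCore.

```julia
# Hypothetical helpers (not RecursiveArrayTools code) mirroring the bodies
# of the two getindex pullbacks above.

# `u[i]` on a vector of arrays: the cotangent has Δ in slot i and a zero
# array of the matching shape in every other slot.
function onehot_cotangent(u::AbstractVector{<:AbstractArray}, i::Integer, Δ)
    return [i == j ? Δ : zero(x) for (x, j) in zip(u, eachindex(u))]
end

# `A[I...]` on a dense array: start from zero(A) and write Δ into the
# indexed position(s). The indices themselves receive no tangent.
function dense_getindex_cotangent(A::AbstractArray, Δ, I...)
    Δ′ = zero(A)
    Δ′[I...] = Δ
    return Δ′
end
```

For example, `onehot_cotangent([[1.0, 2.0], [3.0, 4.0]], 2, [10.0, 20.0])` places `[10.0, 20.0]` in the second slot and `[0.0, 0.0]` in the first. The design point is the same as in the rules: only the indexed entries receive the incoming cotangent, everything else is zero.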

ChrisRackauckas commented 3 years ago

Zygote ignores:

function ChainRulesCore.rrule(::Type{<:EnsembleGPUArray})
    # a pullback returns one tangent per input; here only the constructor itself
    EnsembleGPUArray(0.0), _ -> (NoTangent(),)
end

so I needed to also keep:

ZygoteRules.@adjoint function EnsembleGPUArray()
    EnsembleGPUArray(0.0), _ -> nothing
end

ChrisRackauckas commented 3 years ago

function ChainRulesCore.rrule(f::ODEFunction,u,p,t)
  if f.vjp === nothing
    ChainRulesCore.rrule(f.f,u,p,t)
  else
    f.vjp(u,p,t)
  end
end

was skipped, so I added back:

ZygoteRules.@adjoint function (f::ODEFunction)(u,p,t)
  if f.vjp === nothing
    ZygoteRules._pullback(f.f,u,p,t)
  else
    f.vjp(u,p,t)
  end
end

ZygoteRules.@adjoint! function (f::ODEFunction)(du,u,p,t)
  if f.vjp === nothing
    ZygoteRules._pullback(f.f,du,u,p,t)
  else
    f.vjp(du,u,p,t)
  end
end
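The branch these rules share (use the hand-written `f.vjp` when one is supplied, otherwise differentiate `f.f`) can be sketched in plain Julia without any AD package. This is a minimal sketch under stated assumptions: `RHSWithVJP`, `rhs_pullback`, and the `fallback` keyword are hypothetical names, `fallback` stands in for whatever AD call the real rule makes, and none of this is the SciMLBase, ZygoteRules, or ChainRulesCore API.

```julia
# Hypothetical stand-in for ODEFunction (not the SciMLBase type): holds the
# RHS `f` and an optional hand-written `vjp(u, p, t) -> (y, pullback)`.
struct RHSWithVJP{F,V}
    f::F
    vjp::V  # `nothing` means "no hand-written VJP supplied"
end

# Calling the wrapper forwards to the RHS, like an out-of-place ODEFunction
(r::RHSWithVJP)(u, p, t) = r.f(u, p, t)

# Same branch as the rules above: prefer the hand-written VJP, otherwise
# defer to `fallback`, a placeholder for the generic AD pullback.
function rhs_pullback(r::RHSWithVJP, u, p, t; fallback)
    if r.vjp === nothing
        return fallback(r.f, u, p, t)
    else
        return r.vjp(u, p, t)
    end
end

# Example RHS du = p .* u with its exact VJP: the cotangent Δ maps to
# Δ .* p for u, Δ .* u for p, and `nothing` for t.
scale_rhs(u, p, t) = p .* u
scale_rhs_vjp(u, p, t) = (scale_rhs(u, p, t), Δ -> (Δ .* p, Δ .* u, nothing))
```

With `RHSWithVJP(scale_rhs, scale_rhs_vjp)` the hand-written pullback is used and `fallback` is never called; with `RHSWithVJP(scale_rhs, nothing)` the call is routed to `fallback`. Keeping the choice inside a single rule means callers never need to know whether a custom VJP exists.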