ChrisRackauckas opened this issue 3 years ago.
@oxinabox for comfort.
The current blocking issues:
- `literal_getproperty`. Keeping those as `ZygoteRules.@adjoint` for now.
- `map` rules to not require `__context__`.
- `ZygoteRules.@adjoint function VectorOfArray(u)` -> `function ChainRulesCore.rrule(::VectorOfArray,u)` isn't caught? https://github.com/SciML/RecursiveArrayTools.jl/pull/144/checks?check_run_id=2878240138#step:6:241
- `ZygoteRules.@adjoint function VectorOfArray(u)` -> `function ChainRulesCore.rrule(::Type{<:VectorOfArray},u)`
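The difference between the two `rrule` signatures above is ordinary Julia dispatch: when a constructor call such as `VectorOfArray(u)` is differentiated, the callee handed to `rrule` is the type `VectorOfArray` itself, not an instance of it, so only a `::Type{<:VectorOfArray}` signature can match. A minimal pure-Julia sketch, where `ToyVOA`, `rule_instance` and `rule_type` are hypothetical stand-ins for `VectorOfArray` and `ChainRulesCore.rrule`:

```julia
# Hypothetical stand-in for RecursiveArrayTools.VectorOfArray (illustration only).
struct ToyVOA
    u::Vector{Vector{Float64}}
end

# Like `rrule(::VectorOfArray, u)`: the first argument must be an *instance*.
rule_instance(::ToyVOA, u) = :matched
rule_instance(f, u) = :no_rule            # fallback, mimics "no rule found"

# Like `rrule(::Type{<:VectorOfArray}, u)`: the first argument is the type itself.
rule_type(::Type{<:ToyVOA}, u) = :matched
rule_type(f, u) = :no_rule

u = [[1.0, 2.0], [3.0, 4.0]]
# For the constructor call ToyVOA(u), the callee is the type ToyVOA:
println(rule_instance(ToyVOA, u))  # no_rule
println(rule_type(ToyVOA, u))      # matched
```

The instance signature only fires when an already-constructed object is itself called, which is not what happens when differentiating the constructor.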
Downstream testing revealed that Zygote actually ignored these rules 🤦, so those need to be reverted.
The end result is that, in order to pass tests, all packages needed to keep a few ZygoteRules adjoints, so it doesn't seem possible at this time to use ChainRules exclusively. This should be double-checked in the future (with Diffractor).
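```julia
using OrdinaryDiffEq, DiffEqSensitivity, DiffEqFlux, LinearAlgebra, Flux

# Small neural network with one input (t) and two outputs
nn = FastChain(FastDense(1,16),FastDense(16,16,tanh),FastDense(16,2))
initial = initial_params(nn)

# Out-of-place right-hand side (despite the `!`) driven by the network outputs
function ode2!(u, p, t)
    f1, f2 = nn([t],p)
    [-f1^2; f2]
end

tspan = (0.0, 10.0)
# Complex-valued initial condition
prob = ODEProblem(ode2!, Complex{Float64}[0;0], tspan, initial)

# Loss: norm of the final state, differentiated with BacksolveAdjoint and Zygote VJPs
function loss(p)
    sol = last(solve(prob, Tsit5(), p=p, sensealg=BacksolveAdjoint(autojacvec=ZygoteVJP())))
    return norm(sol)
end

result_ode = DiffEqFlux.sciml_train(loss, initial, ADAM(0.1), maxiters = 100)
```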
This example showed that Zygote skips the following rules:
```julia
# Single index: VA[i] returns the i-th inner array
function ChainRulesCore.rrule(::typeof(getindex),VA::AbstractVectorOfArray, i)
    function AbstractVectorOfArray_getindex_adjoint(Δ)
        # Cotangent is Δ for the array that was read and zero for all others
        Δ′ = [ (i == j ? Δ : zero(x)) for (x,j) in zip(VA.u, 1:length(VA))]
        (NoTangent(),Δ′,NoTangent())
    end
    VA[i],AbstractVectorOfArray_getindex_adjoint
end

# Multiple indices: VA[i,j...]
function ChainRulesCore.rrule(::typeof(getindex),VA::AbstractVectorOfArray, i, j...)
    function AbstractVectorOfArray_getindex_adjoint(Δ)
        Δ′ = zero(VA)
        Δ′[i,j...] = Δ
        # Indices are non-differentiable: one NoTangent() for i and for each of j...
        (NoTangent(), Δ′, NoTangent(), map(_ -> NoTangent(), j)...)
    end
    VA[i,j...],AbstractVectorOfArray_getindex_adjoint
end
```
This is because Zygote has its own `getindex` overloads.
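The single-index rule above only has to scatter the incoming cotangent `Δ` back into the slot that was read and put zeros everywhere else. A pure-Julia sketch of that pullback, using a plain `Vector{Vector{Float64}}` as an assumed stand-in for `AbstractVectorOfArray` and a hypothetical `getindex_pullback` helper (not part of any package):

```julia
# Pullback of `u -> u[i]` for a vector of arrays `u` (hypothetical helper).
# A plain Vector{Vector{Float64}} stands in for AbstractVectorOfArray.
function getindex_pullback(u::AbstractVector{<:AbstractArray}, i::Integer, Δ)
    # Zero cotangent for every inner array except the one that was read
    return [j == i ? Δ : zero(x) for (j, x) in enumerate(u)]
end

u = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
Δ = [10.0, 20.0]                 # incoming cotangent for u[2]
ubar = getindex_pullback(u, 2, Δ)
println(ubar)                    # [[0.0, 0.0], [10.0, 20.0], [0.0, 0.0]]
```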
Zygote ignores:
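```julia
function ChainRulesCore.rrule(::Type{<:EnsembleGPUArray})
    # Zero-argument constructor: the pullback returns one tangent, for the callee itself
    EnsembleGPUArray(0.0), _ -> (NoTangent(),)
end
```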
so I also needed to keep:
```julia
ZygoteRules.@adjoint function EnsembleGPUArray()
    EnsembleGPUArray(0.0), _ -> nothing
end
```
Similarly, this rule:
```julia
function ChainRulesCore.rrule(f::ODEFunction,u,p,t)
    if f.vjp === nothing
        ChainRulesCore.rrule(f.f,u,p,t)
    else
        f.vjp(u,p,t)
    end
end
```
was skipped, so I added back:
```julia
# Out-of-place call f(u,p,t)
ZygoteRules.@adjoint function (f::ODEFunction)(u,p,t)
    if f.vjp === nothing
        # No user-supplied VJP: differentiate the wrapped function f.f
        ZygoteRules._pullback(f.f,u,p,t)
    else
        f.vjp(u,p,t)
    end
end

# In-place call f(du,u,p,t)
ZygoteRules.@adjoint! function (f::ODEFunction)(du,u,p,t)
    if f.vjp === nothing
        ZygoteRules._pullback(f.f,du,u,p,t)
    else
        f.vjp(du,u,p,t)
    end
end
```
The list of things to handle is:
The `solve.jl` one is a bit nasty because it calls all of the adjoints in DiffEqSensitivity.jl, so changing it is somewhat breaking; we can... just do it really quickly.
@frankschae @YingboMa