Open scheidan opened 3 years ago
Zygote is rather easy to fix for this. It just needs literal getproperty rules like
https://github.com/SciML/RecursiveArrayTools.jl/blob/master/src/zygote.jl#L36-L41
which are actually just identity.
Just to update this issue: the example given above now works for ForwardDiff, but Zygote still fails with the same error.
Great to see some progress!
[f6369f11] ForwardDiff v0.10.22
[2ee39098] LabelledArrays v1.6.5
[e88e6eb3] Zygote v0.6.29
To address the particular error from above, I believe the following should do it:
using ChainRulesCore
using LabelledArrays  # provides LArray

function ChainRulesCore.rrule(::typeof(getproperty), A::LArray, s::Symbol)
    function getproperty_LArray_adjoint(d)
        # NOTE: I hope this reference to `A` is optimized away.
        # Zero cotangent with the same labels as `A`, then fill the accessed entry.
        Δ = similar(A) .= 0
        setproperty!(Δ, s, d)
        return (NoTangent(), Δ, NoTangent())
    end
    return getproperty(A, s), getproperty_LArray_adjoint
end
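One way to sanity-check a rule like this without going through Zygote is to call `rrule` directly and inspect what the pullback returns. The following is only a sketch: it assumes the definition above has been evaluated in the session (or that the installed LabelledArrays version ships an equivalent rule), and the vector `x` with labels `a` and `b` is made up for illustration. It relies on ChainRulesCore's generic fallback, which returns `nothing` when no rule applies.

```julia
using ChainRulesCore, LabelledArrays

# Hypothetical test vector with two labelled entries
x = LVector(a = 1.0, b = 2.0)

# `rrule` falls back to `nothing` if no matching rule has been defined
res = ChainRulesCore.rrule(getproperty, x, :a)

if res === nothing
    println("no rrule for getproperty on LArray is defined in this session")
else
    y, pullback = res
    # Pull back a unit cotangent for the accessed entry; with the rule above,
    # the middle element should be a labelled array that is zero except at `a`
    _, Δ, _ = pullback(1.0)
    println(y)
    println(Δ)
end
```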
You might also run into issues with missing adjoints for the constructor when using Zygote with something like @LArray ..., in which case the following should fix your issue:
function ChainRulesCore.rrule(::Type{LArray{S}}, x::AbstractArray) where {S}
    # Sometimes we're pulling back gradients which are not `LArray`.
    constructor_LArray_adjoint(Δx::AbstractArray) = NoTangent(), Δx
    constructor_LArray_adjoint(Δlx::LArray) = NoTangent(), Δlx.__x
    return LArray{S}(x), constructor_LArray_adjoint
end
You'd also need a similar one for SLArray.
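The constructor rule can be exercised the same way. Again, this is just a sketch under the assumption that the constructor rule above (or an equivalent one from the installed LabelledArrays version) is defined; the labels `(:a, :b)` and the array `raw` are hypothetical.

```julia
using ChainRulesCore, LabelledArrays

# Hypothetical plain array to be wrapped with labels
raw = [1.0, 2.0]

# `rrule` falls back to `nothing` if no matching rule has been defined
res = ChainRulesCore.rrule(LArray{(:a, :b)}, raw)

if res === nothing
    println("no constructor rrule for LArray is defined in this session")
else
    lx, pullback = res
    # With the rule above, a plain-array cotangent is passed through unchanged
    _, Δraw = pullback([1.0, 1.0])
    println(lx)
    println(Δraw)
end
```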
@torfjelde fantastic. Could you open a PR with those?
Sure can :+1:
LabelledArrays looks very promising to simplify model definitions. Unfortunately it seems largely incompatible with AD. Is there a good workaround?
(In the case of ForwardDiff this is probably related to #68.)