SciML / LabelledArrays.jl

Arrays which also have a label for each element for easy scientific machine learning (SciML)
https://docs.sciml.ai/LabelledArrays/stable/

AD (Zygote, ForwardDiff) compatibility? #91

Open scheidan opened 3 years ago

scheidan commented 3 years ago

LabelledArrays looks very promising for simplifying model definitions. Unfortunately, it seems largely incompatible with AD. Is there a good workaround?

using LabelledArrays
import ForwardDiff
import Zygote

model(p) =  p.a + p.b^2 + p.c^3

p = LVector(a=1, b=2, c=3)
ps = SLVector(a=1, b=2, c=3)
model(ps)
model(p)

ForwardDiff.gradient(model, p) # works :)
Zygote.gradient(model, p)      # ERROR: ArgumentError: invalid index: Val{:c}() of type Val{:c}

ForwardDiff.gradient(model, ps) # ERROR: type SArray has no field a
Zygote.gradient(model, ps)      # ERROR: ArgumentError: invalid index: Val{:c}() of type Val{:c}

(In case of ForwardDiff this is probably related to #68)
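One possible stopgap, sketched here as an assumption rather than something proposed in this thread, is to differentiate with respect to a plain vector and rebuild the labels inside the differentiated function as a `NamedTuple`. The helper `to_labelled` is hypothetical, and this only helps if the labels are needed inside `model` and nowhere else:

```julia
import ForwardDiff
import Zygote

model(p) = p.a + p.b^2 + p.c^3

# Hypothetical helper: rebuild the labels from a plain vector inside the
# differentiated function, using a NamedTuple instead of a labelled array.
to_labelled(x) = (a = x[1], b = x[2], c = x[3])

x = [1.0, 2.0, 3.0]

# Both AD packages see only an ordinary Vector as the differentiated argument.
g_fd = ForwardDiff.gradient(x -> model(to_labelled(x)), x)
g_zy = Zygote.gradient(x -> model(to_labelled(x)), x)[1]
```

The gradient comes back as a plain vector in the order `a`, `b`, `c`, so any labels have to be reattached by hand afterwards.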

ChrisRackauckas commented 3 years ago

Zygote is rather easy to fix for this. It just needs literal getproperty rules like

https://github.com/SciML/RecursiveArrayTools.jl/blob/master/src/zygote.jl#L36-L41

which are actually just identity.

scheidan commented 2 years ago

Just to update this issue: the example given above now works for ForwardDiff, but Zygote still fails with the same error.

Great to see some progress!

  [f6369f11] ForwardDiff v0.10.22
  [2ee39098] LabelledArrays v1.6.5
  [e88e6eb3] Zygote v0.6.29

torfjelde commented 2 years ago

To address the particular error from above, I believe the following should do it:

using ChainRulesCore

function ChainRulesCore.rrule(::typeof(getproperty), A::LArray, s::Symbol)
    function getproperty_LArray_adjoint(d)
        # NOTE: I hope this reference to `A` is optimized away.
        Δ = similar(A) .= 0
        setproperty!(Δ, s, d)
        return (NoTangent(), Δ, NoTangent())
    end
    return getproperty(A, s), getproperty_LArray_adjoint
end
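To see what this rule returns, one can call it directly without going through Zygote. The snippet below is a minimal sketch that repeats the rule above so it runs on its own. The `Float64` entries and the cotangent value `0.5` are assumptions made for illustration.

```julia
using LabelledArrays
using ChainRulesCore

# Same rule as above, repeated so that this snippet is self-contained.
function ChainRulesCore.rrule(::typeof(getproperty), A::LArray, s::Symbol)
    function getproperty_LArray_adjoint(d)
        Δ = similar(A) .= 0
        setproperty!(Δ, s, d)
        return (NoTangent(), Δ, NoTangent())
    end
    return getproperty(A, s), getproperty_LArray_adjoint
end

# Float64 entries are an assumption: with integer entries, `similar(A)` is an
# integer array and cannot store a fractional cotangent.
p = LVector(a=1.0, b=2.0, c=3.0)

# Call the rule by hand: `val` is the primal value, `pullback` maps a cotangent
# of `p.b` to a cotangent of `p`.
val, pullback = ChainRulesCore.rrule(getproperty, p, :b)
_, Δ, _ = pullback(0.5)

@show val   # the primal value p.b
@show Δ     # zero everywhere except the `b` slot
```

The pullback places the incoming cotangent in the slot named by `s` and leaves the other entries at zero, which is the behaviour Zygote needs from `p.b`.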

You might also run into issues with missing adjoints for the constructor when using Zygote with something like @LArray ..., in which case the following should fix it:

function ChainRulesCore.rrule(::Type{LArray{S}}, x::AbstractArray) where {S}
    # Sometimes we're pulling back gradients which are not `LArray`.
    constructor_LArray_adjoint(Δx::AbstractArray) = NoTangent(), Δx
    constructor_LArray_adjoint(Δlx::LArray) = NoTangent(), Δlx.__x
    return LArray{S}(x), constructor_LArray_adjoint
end

You'd also need a similar one for SLArray.

ChrisRackauckas commented 2 years ago

@torfjelde fantastic. Could you open a PR with those?

torfjelde commented 2 years ago

Sure can :+1: