jonniedie / ComponentArrays.jl

Arrays with arbitrarily nested named components.
MIT License
287 stars, 34 forks

Missing ChainRule for `ComponentVector(a; b...)` #207

Open scheidan opened 1 year ago

scheidan commented 1 year ago

I've tried to use Zygote with ComponentArrays, but it cannot get through this code:

c = ComponentVector(a; b...)

I think it's just a missing ChainRules rule. I've tried, but unfortunately writing rules is black magic to me...

scheidan commented 1 year ago

Here is a (failing) attempt of mine:

using ComponentArrays
import ChainRulesCore
import Zygote

# -----------
# rule

function ChainRulesCore.rrule(::typeof(ComponentArrays.ComponentVector),
                              x::ComponentVector; kwargs...)
    res = ComponentVector(x; kwargs...)
    function pullback(Δ)
        one_x = zero(similar(x, eltype(Δ))) .+ 1
        one_y = zero(ComponentVector{eltype(Δ)}(kwargs)) .+ 1
        return ChainRulesCore.NoTangent(), one_x, one_y
    end
    return res, pullback
end

# -----------
# test

function mymerge(x::ComponentVector, y::ComponentVector)
    z = ComponentVector(x; y...)
    z
end

x = ComponentVector(a=1.0, b=2, c=(e=3, f=4))
y = ComponentVector(a = 11, e=4.0, d=5.0)
mymerge(x, y)

Zygote.gradient(a -> sum(mymerge(a, y)), x)[1] # fails with StackOverflowError
Zygote.gradient(a -> sum(mymerge(x, a)), y)[1] # fails with StackOverflowError

Not sure why this is causing a StackOverflowError. A test version without the kwargs seemed to work.

jonniediegelman commented 1 year ago

It looks like you were super close. You just needed to splat out the keyword arguments in the pullback.

function ChainRulesCore.rrule(::typeof(ComponentArrays.ComponentVector),
                              x::ComponentVector; kwargs...)
    res = ComponentVector(x; kwargs...)
    function pullback(Δ)
        one_x = zero(similar(x, eltype(Δ))) .+ 1
        one_y = zero(ComponentVector{eltype(Δ)}(; kwargs...)) .+ 1
        return ChainRulesCore.NoTangent(), one_x, one_y
    end
    return res, pullback
end
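The difference between passing `kwargs` positionally and splatting it back out as keywords can be seen in plain Julia. The functions `g` and `outer` below are hypothetical and only illustrate how the two call forms bind:

```julia
# Hypothetical function that reports how many positional and keyword arguments it received
g(args...; kw...) = (length(args), length(kw))

function outer(; kwargs...)
    # g(kwargs) passes the collected keywords as ONE positional argument (a Base.Pairs object);
    # g(; kwargs...) forwards them as actual keyword arguments
    return g(kwargs), g(; kwargs...)
end

outer(a = 1, b = 2)  # returns ((1, 0), (0, 2))
```

So `ComponentVector{eltype(Δ)}(kwargs)` in the first attempt hands the constructor a single `Base.Pairs` object rather than keyword arguments.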

Thanks, though! I'll add it as soon as I get a chance.

jonniediegelman commented 1 year ago

Wait, no, that gives the wrong answer.

jonniedie commented 12 months ago

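Interesting: ChainRules doesn't differentiate with respect to keyword arguments; an `rrule`'s pullback only returns tangents for the function and its positional arguments. We may want to instead define the behavior in a `merge` method so it's compatible with ChainRules.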
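To see what a correct pullback for a positional merge has to do, here is a minimal sketch of the gradient bookkeeping, written with plain `NamedTuple`s and `Base.merge` instead of ComponentArrays so it runs without packages. It assumes (not confirmed in the thread) that `ComponentVector(x; y...)` behaves like `merge`, i.e. entries of `y` replace same-named entries of `x`, and it only handles flat, scalar-valued entries; `merge_pullback` is a hypothetical name. Each output entry's cotangent flows back to whichever input it came from, so entries of `x` that `y` overrides receive zero. The all-ones pullbacks above instead give every entry of `x` a gradient of one and ignore the values in `Δ`, which is why they produce wrong answers.

```julia
# Hypothetical pullback for z = merge(x, y) on flat NamedTuples with scalar entries.
# Assumes entries of y override same-named entries of x, as Base.merge does.
function merge_pullback(x::NamedTuple, y::NamedTuple, Δ::NamedTuple)
    # An entry of x reaches z only if y does not override it; otherwise its cotangent is zero
    Δx = NamedTuple{keys(x)}(map(k -> haskey(y, k) ? zero(getfield(Δ, k)) : getfield(Δ, k), keys(x)))
    # Every entry of y reaches z unchanged, so it receives the matching cotangent
    Δy = NamedTuple{keys(y)}(map(k -> getfield(Δ, k), keys(y)))
    return Δx, Δy
end

x = (a = 1.0, b = 2.0)
y = (a = 11.0, d = 5.0)
z = merge(x, y)              # (a = 11.0, b = 2.0, d = 5.0)

# The cotangent of f(z) = sum(values(z)) is one for every entry of z
Δ = map(_ -> 1.0, z)
Δx, Δy = merge_pullback(x, y, Δ)
# Δx == (a = 0.0, b = 1.0), Δy == (a = 1.0, d = 1.0)
```

In a `ChainRulesCore.rrule` for a positional merge method, this bookkeeping would form the body of the pullback, which returns `NoTangent()` for the function itself followed by the tangents for `x` and `y`.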