scheidan opened this issue 1 year ago
Here is a (failing) attempt of mine:
```julia
using ComponentArrays
import ChainRulesCore
import Zygote

# -----------
# rule
function ChainRulesCore.rrule(::typeof(ComponentArrays.ComponentVector),
                              x::ComponentVector; kwargs...)
    res = ComponentVector(x; kwargs...)
    function pullback(Δ)
        one_x = zero(similar(x, eltype(Δ))) .+ 1
        one_y = zero(ComponentVector{eltype(Δ)}(kwargs)) .+ 1
        return ChainRulesCore.NoTangent(), one_x, one_y
    end
    return res, pullback
end

# -----------
# test
function mymerge(x::ComponentVector, y::ComponentVector)
    z = ComponentVector(x; y...)
    z
end

x = ComponentVector(a=1.0, b=2, c=(e=3, f=4))
y = ComponentVector(a = 11, e=4.0, d=5.0)
mymerge(x, y)
Zygote.gradient(a -> sum(mymerge(a, y)), x)[1] # fails with StackOverflowError
Zygote.gradient(a -> sum(mymerge(x, a)), y)[1] # fails with StackOverflowError
```
Not sure why this is causing a StackOverflowError. A test version without the kwargs
seemed to work.
It looks like you were super close. You just needed to splat out the keyword arguments in the pullback.
```julia
function ChainRulesCore.rrule(::typeof(ComponentArrays.ComponentVector),
                              x::ComponentVector; kwargs...)
    res = ComponentVector(x; kwargs...)
    function pullback(Δ)
        one_x = zero(similar(x, eltype(Δ))) .+ 1
        one_y = zero(ComponentVector{eltype(Δ)}(; kwargs...)) .+ 1
        return ChainRulesCore.NoTangent(), one_x, one_y
    end
    return res, pullback
end
```
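The one-line change matters because `ComponentVector{eltype(Δ)}(kwargs)` hands the whole collection of keyword arguments over as a single positional argument, whereas `(; kwargs...)` forwards each pair as its own keyword. A minimal Base-only illustration, using the hypothetical helpers `g` and `h` (not part of ComponentArrays):

```julia
# Hypothetical helper: report how many positional and keyword arguments it received.
g(args...; kwargs...) = (length(args), length(kwargs))

# Forward h's keyword arguments to g in two different ways.
function h(; kwargs...)
    positional = g(kwargs)       # the whole collection arrives as ONE positional argument
    splatted   = g(; kwargs...)  # each pair arrives as its own keyword argument
    return positional, splatted
end

h(a = 1, e = 2)  # → ((1, 0), (0, 2))
```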
Thanks, though! I'll add it as soon as I get a chance.
Wait no, that gives the wrong answer.
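Interesting: ChainRules doesn't differentiate with respect to keyword arguments; a pullback only returns tangents for the function itself and its positional arguments. We may want to instead define this behavior in a `merge` method that takes both vectors as positional arguments, so it's compatible with ChainRules.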
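One reason an all-ones pullback can give the wrong answer: with `sum` as the loss, it claims every entry of `x` reaches the output, but fields that `y` overrides (here `a`) never do, so their gradient should be zero. Below is a minimal sketch of what a positional merge pullback would need to compute. It assumes flat `NamedTuple`s with scalar fields instead of (possibly nested) `ComponentVector`s, does not use ChainRulesCore, and `merge_with_pullback` is a hypothetical name, not an existing API:

```julia
# Hypothetical sketch (plain Julia, no ChainRulesCore): z = merge(x, y) for
# flat NamedTuples with scalar fields. Fields of y override same-named fields of x.
function merge_with_pullback(x::NamedTuple, y::NamedTuple)
    z = merge(x, y)
    # Δ holds the cotangent of the loss with respect to each field of z.
    function merge_pullback(Δ::NamedTuple)
        # A field of x reaches z only if y does not override it;
        # overridden fields get a zero cotangent.
        dx = NamedTuple{keys(x)}(map(k -> haskey(y, k) ? zero(Δ[k]) : Δ[k], keys(x)))
        # Every field of y appears in z unchanged, so its cotangent passes through.
        dy = NamedTuple{keys(y)}(map(k -> Δ[k], keys(y)))
        return dx, dy
    end
    return z, merge_pullback
end

x = (a = 1.0, b = 2.0)
y = (a = 11.0, d = 5.0)
z, back = merge_with_pullback(x, y)
# For loss = sum(values(z)), the cotangent of every field of z is 1.
dx, dy = back(map(one, z))
```

Here `dx` is `(a = 0.0, b = 1.0)` and `dy` is `(a = 1.0, d = 1.0)`: `x.a` gets zero because `y.a` replaced it. An actual rrule for `ComponentVector`s would additionally have to return ChainRulesCore tangent types and handle nested components such as `c`.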
I've tried to use Zygote with ComponentArrays, but it cannot get through this code:
I think it is just a missing ChainRules rule. I've tried, but unfortunately writing rules is black magic for me...