theogf closed this issue.
While broadcast would fix that specific case, I think a better fix is a call to `elementwise_add`, which would do the same thing in the tuple case. But I can imagine that for something with other types inside it, like a `NamedTuple`, broadcast wouldn't work. Further, if one side has a scalar and the other a `Tuple`, then broadcast would "work", which it shouldn't.

However, I am not 100% sure this is actually a problem, since the example type is not a correct tangent type: it has a field that is not a valid tangent type, because that field's type does not implement a vector space (which requires overloading `+`).
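To make these cases concrete, here is a Base-only sketch. `tangent_add` is a hypothetical helper written for illustration; it is not ChainRulesCore's `elementwise_add` (whose exact methods are not reproduced here), but it follows the same idea of adding matching containers field by field and rejecting mismatched ones, whereas broadcast silently accepts a scalar and refuses a `NamedTuple`.

```julia
# Plain tuples do not implement `+`, which is why adding two raw `Tuple` cotangents fails.
function tuple_plus_fails()
    try
        (1.0, 2.0) + (3.0, 4.0)
        return false
    catch err
        return err isa MethodError
    end
end

# Hypothetical field-by-field addition, restricted to matching container types.
function tangent_add(a::Tuple, b::Tuple)
    length(a) == length(b) || throw(DimensionMismatch("tuples have different lengths"))
    return map(+, a, b)
end
# Both NamedTuples must share the same field names, in the same order.
tangent_add(a::NamedTuple{names}, b::NamedTuple{names}) where {names} = map(+, a, b)

# Broadcasting over NamedTuples is reserved in Julia and throws an ArgumentError.
function namedtuple_broadcast_fails()
    try
        (a = 1.0, b = 2.0) .+ (a = 0.5, b = 0.5)
        return false
    catch err
        return err isa ArgumentError
    end
end

println(tuple_plus_fails())                                  # true
println((1.0, 2.0) .+ (3.0, 4.0))                            # (4.0, 6.0)
println(tangent_add((1.0, 2.0), (3.0, 4.0)))                 # (4.0, 6.0)
println(1.0 .+ (1.0, 2.0))                                   # (2.0, 3.0): broadcast accepts a scalar
println(namedtuple_broadcast_fails())                        # true
println(tangent_add((a = 1.0, b = 2.0), (a = 0.5, b = 0.5))) # (a = 1.5, b = 2.5)
```

Because `tangent_add` only has `Tuple`/`Tuple` and `NamedTuple`/`NamedTuple` methods, a scalar paired with a tuple raises a `MethodError` instead of quietly producing a result.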
I believe the correct code is:
```julia
using ChainRulesCore, ChainRulesTestUtils, StaticArrays

function ChainRulesCore.rrule(::typeof(collect), x::X) where {S, T, N, L, X<:SArray{S, T, N, L}}
    # Structural tangent: SArray stores its entries in a `data::NTuple{L, T}` field,
    # so the cotangent of that field is itself a Tangent of the tuple type.
    collect_rrule(Δ::AbstractArray) = NoTangent(), Tangent{X}(data = Tangent{NTuple{L, T}}(ntuple(i -> Δ[i], Val(L))...))
    return collect(x), collect_rrule
end

A = SArray{Tuple{3, 1, 2}}(ntuple(i -> 3.5i, 6))
test_rrule(collect, A)
```
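For reference, `NTuple{N, T}` takes the element count first and the element type second, and linear indexing of an array runs in column-major order. The Base-only sketch below uses a plain `Array` as a hypothetical stand-in for the incoming cotangent `Δ` to show how `ntuple(i -> Δ[i], Val(L))` flattens it into a tuple of `L` entries.

```julia
# NTuple{N, T} means "N elements, all of type T": count first, type second.
@assert NTuple{3, Float64} == Tuple{Float64, Float64, Float64}

# A hypothetical incoming cotangent with the same 3×1×2 shape as the primal in the example.
Δ = reshape(collect(1.0:6.0), 3, 1, 2)
L = length(Δ)

# Linear indexing walks the array in column-major order, which fixes the tuple's element order.
data = ntuple(i -> Δ[i], Val(L))
println(data)                         # (1.0, 2.0, 3.0, 4.0, 5.0, 6.0)
println(data isa NTuple{L, Float64})  # true
```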
Or, much simpler, use a natural tangent:
```julia
using ChainRulesCore, ChainRulesTestUtils, StaticArrays

function ChainRulesCore.rrule(::typeof(collect), x::X) where {S, T, N, L, X<:SArray{S, T, N, L}}
    # Natural tangent: the cotangent of `collect(x)` is already an array shaped like `x`,
    # so it can be passed straight through.
    collect_rrule(Δ::AbstractArray) = NoTangent(), Δ
    return collect(x), collect_rrule
end

A = SArray{Tuple{3, 1, 2}}(ntuple(i -> 3.5i, 6))
test_rrule(collect, A)
```
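The structural and natural versions carry the same information; they differ only in packaging. In the Base-only sketch below, a `NamedTuple` with a single `data` field is a hypothetical stand-in for the structural `Tangent{X}`, while the natural tangent is simply an array with the primal's shape.

```julia
# Natural tangent: an ordinary array with the same shape as the primal.
natural = reshape(collect(1.0:6.0), 3, 1, 2)

# Hypothetical stand-in for the structural tangent: one field, `data`, holding a flat tuple.
structural = (data = Tuple(natural),)

# Going back to an array needs the primal's shape, which the natural form keeps for free.
roundtrip = reshape(collect(structural.data), size(natural))
println(roundtrip == natural)   # true
```

The round trip works, but only because the shape is supplied separately; the natural form avoids that bookkeeping entirely.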
Strictly speaking, one should use projection:
```julia
using ChainRulesCore, ChainRulesTestUtils, StaticArrays

function ChainRulesCore.rrule(::typeof(collect), x::X) where {S, T, N, L, X<:SArray{S, T, N, L}}
    y = collect(x)
    # Project the incoming cotangent onto the tangent space of the primal input `x`.
    proj = ProjectTo(x)
    collect_rrule(Δ::AbstractArray) = NoTangent(), proj(Δ)
    return y, collect_rrule
end

A = SArray{Tuple{3, 1, 2}}(ntuple(i -> 3.5i, 6))
test_rrule(collect, A)
```
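Projection maps an incoming cotangent onto the tangent space of a primal value. ChainRulesCore's `ProjectTo` covers many types and cases; the Base-only `project_sketch` below is a hypothetical, much-simplified illustration for real-valued array primals only, showing the kind of normalisation a projector performs: a shape check, dropping an imaginary part that a real primal cannot have, and converting to the primal's element type.

```julia
# Hypothetical, simplified projection for real-valued array primals (not ChainRulesCore's ProjectTo).
function project_sketch(x::AbstractArray{T}, dx::AbstractArray) where {T<:Real}
    # The cotangent must have the same shape as the primal.
    size(dx) == size(x) ||
        throw(DimensionMismatch("cotangent size $(size(dx)) does not match primal size $(size(x))"))
    # Drop any imaginary part and convert each entry to the primal's element type.
    return T.(real.(dx))
end

x = [1.0 2.0; 3.0 4.0]
dx = [1+2im 0; 3 4im]      # a Complex{Int} cotangent, e.g. from an upstream rule
p = project_sketch(x, dx)
println(p)                 # [1.0 0.0; 3.0 0.0]
println(eltype(p))         # Float64

mismatch_caught = try
    project_sketch(x, ones(3))
    false
catch err
    err isa DimensionMismatch
end
println(mismatch_caught)   # true
```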
Thanks for the investigation, it is really helpful! I have indeed started to look more into `ProjectTo`. I tended to forget that `StructArrays` are, well... arrays!
When calling `test_rrule` on a struct containing `Tuple`s as its fields (e.g. `StructArray` or `SArray`), `_test_cotangent` will fail because `Tuple`s cannot be added together. Using broadcast on this line would, I think, solve the issue: https://github.com/JuliaDiff/ChainRulesCore.jl/blob/ed9a0073ff83cb3b1f4619303e41f4dd5d8c4825/src/tangent_types/tangent.jl#L301

Here is an MWE.