JuliaDiff / ChainRulesTestUtils.jl

Utilities for testing custom AD primitives.
MIT License

`test_add!!_behaviour` has strong assumptions on fields #267

Closed theogf closed 1 year ago

theogf commented 1 year ago

When calling `test_rrule` on a struct that has a `Tuple` as one of its fields (e.g. `StructArray` or `SArray`), `_test_cotangent` fails because two `Tuple`s cannot be added together.

Using broadcast on this line: https://github.com/JuliaDiff/ChainRulesCore.jl/blob/ed9a0073ff83cb3b1f4619303e41f4dd5d8c4825/src/tangent_types/tangent.jl#L301 would solve the issue I think.

Here is an MWE.

using ChainRulesCore, ChainRulesTestUtils, StaticArrays  # SArray is provided by StaticArrays
function ChainRulesCore.rrule(::typeof(collect), x::X) where {S, T, N, L, X<:SArray{S, T, N, L}}
    collect_rrule(Δ::AbstractArray) = NoTangent(), Tangent{X}(data = ntuple(i -> Δ[i], Val(L)))
    return collect(x), collect_rrule
end
A = SArray{Tuple{3, 1, 2}}(ntuple(i -> 3.5i, 6))
test_rrule(collect, A)

oxinabox commented 1 year ago

While broadcast would fix that specific case, I think a call to `elementwise_add` is better: it does the same thing in the `Tuple` case, and I can imagine that for a field holding another type, like a `NamedTuple`, broadcast would not work. Furthermore, if one side has a scalar and the other a `Tuple`, then broadcast would "work", which it shouldn't.
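To make the difference concrete, here is a minimal sketch in plain Julia, with no AD packages loaded. It only illustrates how `+` and broadcasting behave on tuples. The use of `map` as a stand-in for an element-wise add is an assumption for illustration, not the ChainRulesCore implementation, and the variable names are hypothetical.

```julia
a = (1.0, 2.0, 3.0)
b = (0.5, 0.5, 0.5)

# Base defines no `+` method for two tuples, so this throws a MethodError.
plus_fails = try
    a + b
    false
catch err
    err isa MethodError
end

# Broadcasting adds element by element and returns a tuple.
broadcast_sum = a .+ b

# `map` over both tuples gives the same result (illustrative stand-in only).
mapped_sum = map(+, a, b)

# Broadcasting also accepts a scalar on one side without complaint,
# which is the case that arguably should be an error for tangents.
scalar_sum = a .+ 1.0
```

The last line is the concern raised above: broadcasting silently combines a scalar with a tuple, whereas a dedicated element-wise add can insist that both sides have the same structure.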

However, I am not 100% sure this is actually a problem, since the tangent in the example is not a correct tangent type: it has a field that is not itself a valid tangent, because a plain `Tuple` does not implement a vector space (which requires overloading `+`). I believe the correct code is:

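using ChainRulesCore, ChainRulesTestUtils, StaticArrays  # SArray is provided by StaticArrays
function ChainRulesCore.rrule(::typeof(collect), x::X) where {S, T, N, L, X<:SArray{S, T, N, L}}
    # Wrap the tuple in a Tangent so the `data` field is itself a valid tangent (NTuple takes the length first, then the element type)
    collect_rrule(Δ::AbstractArray) = NoTangent(), Tangent{X}(data = Tangent{NTuple{L, T}}(ntuple(i -> Δ[i], Val(L))...))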
    return collect(x), collect_rrule
end
A = SArray{Tuple{3, 1, 2}}(ntuple(i -> 3.5i, 6))
test_rrule(collect, A)

Or, much more simply, use a natural tangent:

using ChainRulesCore, ChainRulesTestUtils, StaticArrays  # SArray is provided by StaticArrays
function ChainRulesCore.rrule(::typeof(collect), x::X) where {S, T, N, L, X<:SArray{S, T, N, L}}
    collect_rrule(Δ::AbstractArray) = NoTangent(), Δ
    return collect(x), collect_rrule
end
A = SArray{Tuple{3, 1, 2}}(ntuple(i -> 3.5i, 6))
test_rrule(collect, A)

Strictly speaking one should use projection:

using ChainRulesCore, ChainRulesTestUtils, StaticArrays  # SArray is provided by StaticArrays
function ChainRulesCore.rrule(::typeof(collect), x::X) where {S, T, N, L, X<:SArray{S, T, N, L}}
    y = collect(x)
    proj = ProjectTo(y)
    collect_rrule(Δ::AbstractArray) = NoTangent(), proj(Δ)
    return y, collect_rrule
end
A = SArray{Tuple{3, 1, 2}}(ntuple(i -> 3.5i, 6))
test_rrule(collect, A)

theogf commented 1 year ago

Thanks for the investigation, it is really helpful!

I have indeed started to look more into `ProjectTo`. I tend to forget that StructArrays are, well... arrays!