jonniedie / ComponentArrays.jl

Arrays with arbitrarily nested named components.
MIT License

`map` returns a plain array when iterating multiple inputs #149

Closed YichengDWu closed 2 years ago

YichengDWu commented 2 years ago

Related to #57


julia> using ComponentArrays

julia> ca = ComponentArray(a=1, b=2);

julia> map(identity, ca)
ComponentVector{Int64}(a = 1, b = 2)

julia> map((x,y)->x+y, ca,1:2)
2-element Vector{Int64}:
 2
 4
YichengDWu commented 2 years ago

If I overload `map`, I can get around it:


julia> Base.map(f, x::ComponentArray, args...) = ComponentArray(map(f, getdata(x), args...), getaxes(x))

julia> map((x,y)->x+y, ca,[3,4])
ComponentVector{Int64}(a = 4, b = 6)
YichengDWu commented 2 years ago

This could be dangerous, though...

YichengDWu commented 2 years ago

OK, this is a bad idea. But the following idea is a good one:

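using ComponentArrays
import Zygote            # brings the module name into scope so `Zygote.seed` can be extended
using ForwardDiff: Dual

# Seed each element with a one-hot partials tuple, then restore the component axes.
function Zygote.seed(x::ComponentArray, ::Val{N}, offset = 0) where N
    dual = map(x, reshape(1:length(x), size(x))) do xi, i
        Dual(xi, ntuple(j -> j+offset == i, Val(N)))
    end
    return ComponentArray(dual, getaxes(x))
end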

This is closely related to #148
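
To make the seeding step concrete, here is a minimal sketch of the one-hot tuple that the `ntuple` call builds for each element. The helper name `one_hot` is hypothetical and used only for illustration; it depends on neither Zygote nor ForwardDiff.

```julia
# Hypothetical helper: isolates the `ntuple` expression used inside `seed`.
# Returns an N-tuple of Bools that is `true` only at position `i - offset`.
one_hot(i, ::Val{N}, offset = 0) where {N} = ntuple(j -> j + offset == i, Val(N))

second_of_three = one_hot(2, Val(3))      # element 2 of 3, no offset
shifted         = one_hot(2, Val(3), 1)   # same element, chunk offset of 1

println(second_of_three)
println(shifted)
```

Each element `i` gets a partials tuple with a single active slot, and `offset` shifts which slot that is, presumably so inputs longer than `N` can be handled in chunks.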

jonniedie commented 2 years ago

The behavior of `map` should fall out of the broadcast rules rather than be overloaded explicitly, because an explicit overload wouldn't cover the case where the 2nd, 3rd, etc. argument is a ComponentArray and the first isn't. I could probably change this so the ComponentArray “wins” here.
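
The dispatch limitation described above can be sketched without touching `Base.map`. In the following, `Labeled` and `labeled_map` are hypothetical stand-ins, not ComponentArrays' types or API; the sketch only shows that a method specialized on the first array argument is skipped when the wrapper appears in a later position.

```julia
# Hypothetical wrapper type standing in for a named-component array.
struct Labeled{T} <: AbstractVector{T}
    data::Vector{T}
    names::Vector{Symbol}
end
Base.size(x::Labeled) = size(x.data)
Base.getindex(x::Labeled, i::Int) = x.data[i]
Base.IndexStyle(::Type{<:Labeled}) = IndexLinear()

# Generic fallback: behaves like plain `map` and returns a plain array.
labeled_map(f, args...) = map(f, args...)

# Specialization on the FIRST array argument only, mirroring the overload in the thread.
labeled_map(f, x::Labeled, args...) = Labeled(map(f, x.data, args...), x.names)

la = Labeled([1, 2], [:a, :b])

first_pos  = labeled_map(+, la, [3, 4])   # wrapper first: specialized method applies
second_pos = labeled_map(+, [3, 4], la)   # wrapper second: falls back to plain `map`

println(typeof(first_pos))
println(typeof(second_pos))
```

Under these assumptions, the wrapper is preserved only in the first call. Covering every argument position by explicit overloads would need a method per position, which is why deriving the result type from a shared rule is the more scalable design.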

YichengDWu commented 2 years ago

The current behaviour of `map` is fine, as it is the same as with named tuples.

julia> ca = (a=1,b=2)
(a = 1, b = 2)

julia> map((x,y)->x+y, ca,1:2)
2-element Vector{Int64}:
 2
 4
YichengDWu commented 2 years ago

I'm closing this since it's more of a Zygote issue.