jonniedie / ComponentArrays.jl

Arrays with arbitrarily nested named components.
MIT License

Implementation of a `merge` function #186

Open vboussange opened 1 year ago

vboussange commented 1 year ago

Hey there, I thought it'd be very cool to have a specialisation of the `merge` function for ComponentArrays, behaving similarly to `merge` on Dicts or NamedTuples.

For instance, consider

N = 2
p = (r = ones(N), b = 3 * ones(N))
p2 = (r = 2 * ones(N),) # the trailing comma is needed to make this a NamedTuple
merge(p, p2) # outputs (r = [2.0, 2.0], b = [3.0, 3.0])

It could be useful to have similar functionality when p and p2 are ComponentArrays. I am not sure whether this could be implemented efficiently, though.

jonniedie commented 1 year ago

Yeah, that would be nice. In the meantime, you can accomplish this by doing:

julia> p = ComponentArray(r = ones(N), b = 3 * ones(N))
ComponentVector{Float64}(r = [1.0, 1.0], b = [3.0, 3.0])

julia> p2 = ComponentArray(r = 2 * ones(N))
ComponentVector{Float64}(r = [2.0, 2.0])

julia> ComponentArray(p; p2...)
ComponentVector{Float64}(r = [2.0, 2.0], b = [3.0, 3.0])

jonniedie commented 1 year ago

I don't know that it is especially efficient, though.

vboussange commented 1 year ago

The problem with the code that you propose is that it is not compatible with Zygote. Here is a version that works with Zygote:

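using ComponentArrays

# Merge two flat ComponentVectors: for each key of `ca`, take the value from `ca2`
# if `ca2` also has that key, otherwise keep the value from `ca`.
# Keys that appear only in `ca2` are ignored, and the result reuses the axes of `ca`,
# so a replacement value must have the same length as the value it replaces.
function Base.merge(ca::ComponentArray{T}, ca2::ComponentArray{T}) where T
    ax = getaxes(ca)
    ax2 = getaxes(ca2)
    vks = valkeys(ax[1])
    vks2 = valkeys(ax2[1])
    _p = Vector{T}()
    for vk in vks
        # vcat builds a new vector each time, avoiding in-place mutation,
        # which Zygote does not support
        if vk in vks2
            _p = vcat(_p, ca2[vk])
        else
            _p = vcat(_p, ca[vk])
        end
    end
    ComponentArray(_p, ax)
end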

However, this assumes that ca and ca2 are ComponentVectors without nested fields. The code needs to be improved so that it works in the general case, where ca and ca2 are arbitrary ComponentArrays with nested fields.
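To make the nested case concrete, here is a minimal sketch of the merge semantics one might want, written for plain NamedTuples using only Base Julia. The name `recursive_merge` is hypothetical and not part of ComponentArrays.jl, and I have not tested this with Zygote; it only illustrates what "merge with nested fields" could mean.

```julia
# Hypothetical helper (not part of ComponentArrays.jl): merge two NamedTuples,
# descending into fields that are NamedTuples on both sides.
function recursive_merge(a::NamedTuple, b::NamedTuple)
    merged = merge(a, b)          # shallow merge: values from `b` win
    ks = keys(merged)
    vals = map(ks) do k
        if haskey(a, k) && haskey(b, k) && a[k] isa NamedTuple && b[k] isa NamedTuple
            recursive_merge(a[k], b[k])   # both sides are nested: merge field by field
        else
            merged[k]                     # otherwise keep the shallow-merge result
        end
    end
    return NamedTuple{ks}(vals)
end

a = (r = [1.0, 1.0], nested = (x = 1.0, y = 2.0))
b = (nested = (y = 5.0,),)
recursive_merge(a, b)  # (r = [1.0, 1.0], nested = (x = 1.0, y = 5.0))
```

Unlike the flat version above, keys that appear only in the second argument are kept, as with `merge` on NamedTuples. The result could then be passed to the `ComponentArray` constructor, which accepts a NamedTuple, but a version that operates directly on ComponentArrays and stays Zygote-compatible would still need to be written.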

scheidan commented 1 year ago

See also #69.