jonniedie / ComponentArrays.jl

Arrays with arbitrarily nested named components.
MIT License

Add utility function to reorder ComponentArrays #171

Closed nrontsis closed 1 year ago

nrontsis commented 1 year ago

Summary

This PR adds a utility function that reorders the components of a ComponentArray to match the component order of a prototype.

Example Output

a = ComponentVector(a=1, b=2)
b = ComponentVector(b=10, a=20)
reorder_as(a, b)

gives

ComponentVector(a=20, b=10)

Motivation and Example Usage

Suppose we want to compute the dot product of two ComponentVectors while pairing elements by the symbols of their axes rather than by their position.

That is, we want my_dot(ComponentVector(a=1, b=0), ComponentVector(b=0, a=1)) to give a1*a2 + b1*b2 = 1*1 + 0*0 = 1 instead of the 0 that LinearAlgebra.dot currently returns.
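
The mismatch described above can be reproduced with a short check. This is a minimal sketch that assumes ComponentArrays and LinearAlgebra are installed; the variable names `positional` and `named` are only illustrative.

```julia
using ComponentArrays, LinearAlgebra

v1 = ComponentVector(a=1, b=0)
v2 = ComponentVector(b=0, a=1)

# dot pairs elements by position in the underlying data: 1*0 + 0*1
positional = dot(v1, v2)            # 0

# Pairing by component name gives the intended result: 1*1 + 0*0
named = v1.a * v2.a + v1.b * v2.b   # 1
```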

This utility function makes it possible to do this generically, as follows:

my_dot(v1::ComponentArray, v2::ComponentArray) = dot(v1, reorder_as(v1, v2))

nrontsis commented 1 year ago

Closing, as I realised this function was returning incorrect results.

In my use cases, I resorted to constraining the types of the relevant component arrays to be identical, which guarantees matching axes. For example:

my_dot(v1::T, v2::T) where {T<:ComponentArray} = dot(v1, v2)
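
As a rough illustration of how this constraint behaves: the axes are part of a ComponentArray's type, so two vectors whose components are ordered differently have different types and the call fails with a MethodError instead of silently returning a positional result. The names `same_layout`, `other_layout`, and `threw` below are hypothetical and only used for this sketch.

```julia
using ComponentArrays, LinearAlgebra

my_dot(v1::T, v2::T) where {T<:ComponentArray} = dot(v1, v2)

v1 = ComponentVector(a=1, b=0)
same_layout = ComponentVector(a=1, b=0)
other_layout = ComponentVector(b=0, a=1)

# Identical axes, hence identical types: the method applies.
matched = my_dot(v1, same_layout)   # 1*1 + 0*0 = 1

# Differently ordered components give a different type, so no method matches.
threw = try
    my_dot(v1, other_layout)
    false
catch err
    err isa MethodError
end
```

The trade-off is that mismatched layouts are rejected outright rather than reordered, which is stricter than what reorder_as was meant to provide.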

For future reference, my latest attempt at reorder_as was:

function reorder_as(prototype::ComponentArray, input::ComponentArray)::ComponentArray
    @assert size(prototype) == size(input) "Incompatible sizes: $(size(prototype)) and $(size(input))"
    for (ax1, ax2) in zip(getaxes(prototype), getaxes(input))
        @assert (ax2 == ax1 || Set(keys(ax1)) == Set(keys(ax2))) "Incompatible axis: $ax1 and $ax2"
    end
    result = ComponentArray(similar(getdata(input)), getaxes(prototype)...)
    @views result[keys.(getaxes(prototype))...] = input[keys.(getaxes(prototype))...]
    return reorder_as.(prototype, result)
end
reorder_as(_, input) = input # No reordering to be done for non-ComponentArrays

but I gave up on testing it thoroughly. Furthermore, it relies on mutation (the in-place assignment into result), so it would not work with e.g. Zygote.
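
For the flat, one-level case only, a non-mutating variant might look like the sketch below. This is not the attempt above and not part of ComponentArrays: `reorder_flat` is a hypothetical name, it does not recurse into nested components, it only handles ComponentVectors, and it has not been tested with Zygote.

```julia
using ComponentArrays, LinearAlgebra

# Hypothetical helper: rebuild `input` with its components in the order
# given by `prototype`, without writing into any existing array.
function reorder_flat(prototype::ComponentVector, input::ComponentVector)
    names = keys(getaxes(prototype)[1])            # e.g. (:a, :b)
    vals = map(k -> getproperty(input, k), names)  # look up each component by name
    return ComponentVector(NamedTuple{names}(vals))
end

a = ComponentVector(a=1, b=2)
b = ComponentVector(b=10, a=20)
reordered = reorder_flat(a, b)   # components in the order a, b

v1 = ComponentVector(a=1, b=0)
v2 = ComponentVector(b=0, a=1)
by_name = dot(v1, reorder_flat(v1, v2))   # 1*1 + 0*0 = 1
```

If `input` lacks one of the prototype's component names, the lookup throws, so mismatched key sets fail loudly instead of producing a wrong ordering.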