jonniedie / ComponentArrays.jl

Arrays with arbitrarily nested named components.
MIT License

Reexport `inv` and `diag` etc from `LinearAlgebra`? #187

Open Yuan-Ru-Lin opened 1 year ago

Yuan-Ru-Lin commented 1 year ago

Is there any reason not to reexport `inv`, `diag`, etc.? They can be implemented quite easily as follows:

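```julia
using ComponentArrays
import LinearAlgebra: inv, diag
# Rows of the inverse correspond to the original columns (and vice versa), so swap the axes
inv(x::ComponentMatrix)::ComponentMatrix = ComponentMatrix(inv(getdata(x)), reverse(getaxes(x))...)
# The diagonal takes the row axis
diag(x::ComponentMatrix)::ComponentVector = ComponentVector(diag(getdata(x)), getaxes(x)[1])
```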
jonniedie commented 1 year ago

Well, we probably wouldn't want to export them, so as not to clog people's namespaces (if they want `diag`, they can do `using LinearAlgebra`). But I think the bigger part of the question is whether we should specifically overload these functions to return ComponentArrays instead of falling back to plain Arrays. It tends to be a huge headache to chase down every method of every function like this and make a special case, so I just don't do that. Instead, I usually rely on functions properly calling `similar`, `convert`, or `copy`, so that ComponentArrays are created automatically when they should be.
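As a rough illustration of that approach (a sketch using only the public ComponentArrays API, not code from the package itself): operations that go through `similar`, `copy`, or broadcasting keep the named components without any function-specific overloads.

```julia
using ComponentArrays

x = ComponentVector(a=1.0, b=2.0)

# Each of these goes through generic Base machinery (similar, copy, broadcasting)
# that ComponentArrays provides methods for, so the named components survive.
y = similar(x)  # same axes as x, contents uninitialized
z = copy(x)
w = x .* 2      # broadcasting also yields a ComponentVector

w.b  # 4.0
```

A function that allocates its output some other way (for example, by building a plain `Matrix` internally) bypasses these hooks, which is why the result can come back unlabeled.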

With that said, implementing things like `inv` or `diag` is actually a little trickier than the code above suggests. For example, if you had

```julia
ab = ComponentVector(a=1, b=2)
xy = ComponentVector(x=3, y=4)
abxy = ab * xy'
```

what would you expect the axes of `diag(abxy)` to look like? There's no particularly good reason they should be `a`, `b` rather than `x`, `y`.
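To make the ambiguity concrete, here is a minimal sketch that builds the same numbers as `ab * xy'` explicitly, with rows labeled `a`, `b` and columns labeled `x`, `y`. The helper `labeled_diag` is hypothetical, not part of ComponentArrays: it only attaches names to the diagonal when the row and column axes agree and otherwise returns a plain `Vector`.

```julia
using LinearAlgebra, ComponentArrays

# Same values as ab * xy' above: rows labeled a, b; columns labeled x, y.
abxy = ComponentArray([3 4; 6 8], Axis(a=1, b=2), Axis(x=1, y=2))

# Hypothetical helper (not part of ComponentArrays): only label the diagonal
# when both axes are the same; otherwise fall back to a plain Vector.
function labeled_diag(M::ComponentMatrix)
    rowax, colax = getaxes(M)
    d = diag(getdata(M))
    return rowax == colax ? ComponentArray(d, rowax) : d
end

labeled_diag(abxy)  # plain Vector [3, 8], since a, b and x, y disagree
```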

Yuan-Ru-Lin commented 1 year ago

I agree there is ambiguity in the case of `diag`. Maybe we could have something like a `ComponentSquareMatrix` that enforces the same `Axis` in both dimensions?
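One possible reading of that suggestion, as a sketch only: a hypothetical wrapper type (the name `ComponentSquareMatrix` comes from the suggestion; the `mat` field and the methods below are assumptions) that checks once, at construction, that both axes are identical. With that invariant, `diag` and `inv` each have a single natural labeling.

```julia
using LinearAlgebra, ComponentArrays

# Hypothetical wrapper (not part of ComponentArrays): holds a ComponentMatrix
# whose row and column axes are verified to be identical when it is built.
struct ComponentSquareMatrix{M<:ComponentMatrix}
    mat::M
    function ComponentSquareMatrix(mat::ComponentMatrix)
        rowax, colax = getaxes(mat)
        rowax == colax || throw(ArgumentError("row and column axes must be identical"))
        return new{typeof(mat)}(mat)
    end
end

# With identical axes, the diagonal can unambiguously reuse that axis...
LinearAlgebra.diag(S::ComponentSquareMatrix) =
    ComponentArray(diag(getdata(S.mat)), first(getaxes(S.mat)))

# ...and the inverse keeps the same axes in both dimensions.
Base.inv(S::ComponentSquareMatrix) =
    ComponentSquareMatrix(ComponentArray(inv(getdata(S.mat)), getaxes(S.mat)...))

S = ComponentSquareMatrix(ComponentArray([1 2; 3 4], Axis(a=1, b=2), Axis(a=1, b=2)))
diag(S).b  # 4
```

Rejecting mismatched axes up front means `diag` never has to guess between row and column labels; the trade-off is an extra type that users would have to opt into.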

As for `inv`, it seems that LinearAlgebra's internals do call `convert`, but the result still falls back to an ordinary `Matrix`:

```julia
julia> using LinearAlgebra, ComponentArrays
julia> M = [1 2; 3 4];
julia> inv(ComponentMatrix(M, Axis(a=1, b=2), Axis(a=1, b=2)))
2×2 Matrix{Float64}:
 -2.0   1.0
  1.5  -0.5
```

while `@edit inv(M)` leads to

```julia
...
function inv(A::StridedMatrix{T}) where T
    checksquare(A)
    S = typeof((one(T)*zero(T) + one(T)*zero(T))/one(T))
    AA = convert(AbstractArray{S}, A)
    if istriu(AA)
        Ai = triu!(parent(inv(UpperTriangular(AA))))
    elseif istril(AA)
        Ai = tril!(parent(inv(LowerTriangular(AA))))
    else
        Ai = inv!(lu(AA))
        Ai = convert(typeof(parent(Ai)), Ai)
    end
    return Ai
end
...
```

Maybe I should open two separate issues for these?