JuliaLang / LinearAlgebra.jl

Julia Linear Algebra standard library

Feature request: multidimensional `LinearAlgebra.tr` #958

Open MilesCranmer opened 2 years ago

MilesCranmer commented 2 years ago

Right now, LinearAlgebra.tr only traces second-order tensors. However, it can be very useful to trace along a specific subset of axes (e.g., for batched calculations), and to take ND traces (e.g., summing x[i, i, i]). This can be done with Tullio.jl when the traced dimensions are known statically, but when you need to specify which axes are traced at runtime, it requires some metaprogramming. So it would be nicer if tr itself could take a dims argument and also handle more than two axes.
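For context, the Tullio.jl version alluded to above looks something like the following sketch (assuming Tullio's repeated-index syntax; the variable names are illustrative):

using Tullio

x = randn(5, 5, 5)
@tullio t := x[i, i, i]     # ND trace: sum over the repeated index i
@tullio v[k] := x[i, i, k]  # batched trace over the first two axes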

Therefore, I propose to extend tr to allow for a dims argument, which can be used for arbitrary dimension arrays. The following code gives a simple working implementation:

"""Perform the calculation sum_{i} x[i, i, ..., i]"""
function _tr_all_dims(x::AbstractArray{T,N}) where {T,N}
     sum(i -> x[fill(i, N)...], axes(x, 1))
 end

function tr(x::AbstractArray; dims)
    mapslices(_tr_all_dims, x; dims=dims)
end

So now we can do things like:

x = zeros(5, 5, 5)
for i=1:5
    x[i, i, i] = 1
end
# 3D trace:
tr(x; dims=(1, 2, 3))  # 1×1×1 array containing 5.0

# batched trace:
tr(x; dims=(1, 2))  # 1×1×5 array holding the trace of each 5×5 slice (here all ones)

Let me know what you think.

Also, this implementation (with splats) is very slow, so we'd probably want to compile in specialization on the dimensionality of the array.

Thanks, Miles

MilesCranmer commented 2 years ago

Here's a significantly faster version:

@generated function _tr_all_dims(x::AbstractArray{T, N}) where {T,N}
    x_part = :(x[$(fill(:i, N)...)])
    quote
        out = zero($T)
        for i in axes(x, 1)
            out += $x_part
        end
        out
    end
end

function tr(x::AbstractArray; dims)
    mapslices(_tr_all_dims, x; dims=dims)
end

although mapslices is quite slow for some reason.
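For illustration (a reconstruction, not output copied from the thread): for x::Array{Float64,3}, the generator returns

quote
    out = zero(Float64)
    for i in axes(x, 1)
        out += x[i, i, i]   # fill(:i, 3) spliced three literal `i`s at expansion time
    end
    out
end

so the runtime loop has no fill, no splat, and no allocation.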

MilesCranmer commented 2 years ago

A bit faster, but still slower than tr:

@generated function _tr_dims(x::AbstractArray{T,N}; dims) where {T,N}
    # in a generated function, `dims` is bound to the keyword's *type*,
    # Tuple{Val{i}, Val{j}, ...}; recover the integers from its parameters
    dims = collect(val.parameters[1] for val in dims.parameters)
    indices = [
        (j in dims) ? :(i) : :(:) for j in 1:N
    ]
    x_part = :(x[$(indices...)])
    summation = if :(:) in indices
        :(out .+= $x_part)
    else
        :(out += $x_part)
    end
    quote
        i = first(axes(x, 1))
        out = zero($x_part)
        for i in axes(x, 1)
            $summation
        end
        out
    end
end

function LinearAlgebra.tr(x; dims)
    dims = Tuple(collect(Val(dim) for dim in dims))
    return _tr_dims(x; dims=dims)
end
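To unpack what this generates (a reconstruction): for a 3-dimensional array with dims = (Val(1), Val(2)), indices == [:(i), :(i), :(:)], so the method body specializes to roughly

i = first(axes(x, 1))
out = zero(x[i, i, :])   # one slot per index along the untraced dimension 3
for i in axes(x, 1)
    out .+= x[i, i, :]   # note: each x[i, i, :] slice is a fresh copy
end
out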

MilesCranmer commented 2 years ago

Actually, for big arrays they are basically comparable. I think the multi-dimensional trace simply has more setup work to do.

dkarrasch commented 2 years ago

Perhaps you will want to replace fill(i, N) and the like with ntuple(_ -> i, N). That makes the allocations disappear.
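The difference, for the record: fill heap-allocates a Vector on every diagonal lookup, while ntuple builds a stack-allocated tuple. A minimal illustration:

A = rand(4, 4, 4); i = 2

A[fill(i, 3)...]          # allocates the temporary vector [2, 2, 2] on each call
A[ntuple(_ -> i, 3)...]   # builds the tuple (2, 2, 2); no heap allocation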

MilesCranmer commented 2 years ago

Thanks. Here are the updated performance numbers on a second-order tensor (a sketch of the benchmark harness follows the list). The baseline, with the normal tr, is 1.18 μs.

1. 1.02 ms

       function tr1(x::AbstractArray{T,N}; dims) where {T,N}
           return mapslices(_x -> sum(i -> _x[ntuple(_ -> i, ndims(_x))...], axes(_x, 1)), x; dims=dims)
       end

2. 1.02 ms

       @generated function _tr_all_dims(x::AbstractArray{T,N}) where {T,N}
           x_part = :(x[$(fill(:i, N)...)])
           quote
               out = zero($T)
               for i in axes(x, 1)
                   out += $x_part
               end
               out
           end
       end

       function tr2(x::AbstractArray; dims)
           mapslices(_tr_all_dims, x; dims=dims)
       end

3. 2.11 μs

       @generated function _tr3_dims(x::AbstractArray{T,N}; dims) where {T,N}
           # `dims` is the keyword's type here: Tuple{Val{i}, Val{j}, ...}
           dims = collect(val.parameters[1] for val in dims.parameters)
           indices = [(j in dims) ? :(i) : :(:) for j in 1:N]
           x_part = :(x[$(indices...)])
           summation = if :(:) in indices
               :(out .+= $x_part)
           else
               :(out += $x_part)
           end
           quote
               i = first(axes(x, 1))
               out = zero($x_part)
               for i in axes(x, 1)
                   $summation
               end
               out
           end
       end

       function tr3(x; dims)
           return _tr3_dims(x; dims=Tuple(collect(Val(dim) for dim in dims)))
       end

4. 1.583 μs (same as 3, but indexing into a precomputed tuple of Val(i)s rather than constructing them)

       function tr4(x::AbstractArray{T,N}; dims) where {T,N}
           possible_dims = ntuple(i -> Val(i), N)
           selected_dims = Tuple(collect(possible_dims[dim] for dim in dims))
           return _tr3_dims(x; dims=selected_dims)
       end
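A sketch of the benchmark harness (the 400×400 size is a guess, since the issue doesn't state the array size):

using BenchmarkTools, LinearAlgebra

x = randn(400, 400)
@btime tr($x)               # baseline
@btime tr1($x; dims=(1, 2))
@btime tr2($x; dims=(1, 2))
@btime tr3($x; dims=(1, 2))
@btime tr4($x; dims=(1, 2))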

MilesCranmer commented 2 years ago

@dkarrasch let me know what you think, and I can make a PR.

dkarrasch commented 2 years ago

I'm finding the metaprogramming stuff quite hard to read. A simple

using LinearAlgebra
import LinearAlgebra: checksquare

function checksquare(A::AbstractArray{<:Any,N}) where {N}
    sz = size(A)
    all(==(sz[1]), sz) || throw(DimensionMismatch("array is not square: dimensions are $(size(A))"))
    sz[1]
end

function mytr(A::Array{T,N}) where {T,N}
    n = checksquare(A)
    t = zero(T)
    @inbounds @simd for i in 1:n
        t += A[ntuple(_ -> i, N)...]
    end
    t
end

performs even better than LinearAlgebra.tr on matrices (due to the @simd annotation, which we should also apply to the current method in dense.jl; done in JuliaLang/julia#47585), and its performance depends only on the number of diagonal elements. I tested on 400x400 and 400x400x400 arrays. As for the batched computation, mapslices seems to add massive overhead; but on the other hand, I'm not sure we should cook up complicated, super-specialized code here just for tr. @mbauman, do you have some good advice here?
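A quick sanity check of the above (an added example; the sizes are arbitrary):

using LinearAlgebra

A = randn(400, 400)
mytr(A) ≈ tr(A)   # true; matrices reduce to the usual diagonal sum

B = zeros(5, 5, 5)
for i in 1:5
    B[i, i, i] = 1
end
mytr(B)           # 5.0: sums the diagonal entries B[i, i, i]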

mcabbott commented 2 years ago

It's not so complicated to make CartesianIndices do this.

tr5(A::AbstractArray; dims=:) = _mytr(dims, A)
function _mytr(dims::Tuple{Integer, Vararg{Integer}}, A::AbstractArray)
    dimaxes = map(d -> axes(A,d), dims)
    allequal(dimaxes) || _tr_error(dimaxes)
    mask = ntuple(d -> !(d in dims), ndims(A))
    B = similar(A, ifelse.(mask, axes(A), (Base.OneTo(1),)))
    for I in CartesianIndices(B)
        t = zero(eltype(A))
        @inbounds @simd for j in first(dimaxes)
            K = CartesianIndex(ifelse.(mask, Tuple(I), j))
            t += A[K]
        end
        B[I] = t
    end
    dropdims(B; dims)
end
@noinline _tr_error(ax) = throw(DimensionMismatch(
    "traced dimensions must agree, but got $(ax)"))
function _mytr(::Colon, A::AbstractArray)
    allequal(axes(A)) || _tr_error(axes(A))
    t = zero(eltype(A))
    @inbounds @simd for i in axes(A, 1)
        t += A[ntuple(_ -> i, ndims(A))...]
    end
    t
end
_mytr(dims::Integer, A::AbstractArray) = dropdims(sum(A; dims); dims)
_mytr(::Colon, A::AbstractVector) = sum(A)

x = rand(1:10, 3, 3, 2)
tr5(x, dims=(1, 2))

let n = 100
    dims = (2, 3)
    x = randn(n, n, n)
    a = @btime tr2($x; dims=$dims)  # mapslices
    b = @btime tr3($x; dims=$dims)  # @generated
    c = @btime tr5($x; dims=$dims)  # CartesianIndices
    a ≈ b ≈ c
end
# 1.213 ms (229 allocations: 83.02 KiB)
# 16.708 μs (106 allocations: 89.42 KiB)
# 3.792 μs (4 allocations: 992 bytes)

let n = 100
    x = randn(n, n, n)
    a = @btime _tr_all_dims($x)  # @generated
    b = @btime tr5($x)           # CartesianIndices
    a ≈ b
end
# 101.183 ns (0 allocations: 0 bytes)
# 41.582 ns (0 allocations: 0 bytes)
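A small worked check of tr5 (an added example, not from the thread):

A = reshape(1.0:12.0, 2, 2, 3)
tr5(A; dims=(1, 2))   # [5.0, 13.0, 21.0]: the trace of each 2×2 slice
# tr5(A) would throw DimensionMismatch, since the full trace needs all axes equal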

The big question seems to be whether LinearAlgebra should handle higher-rank objects. dot accepts anything; that might be the only function which goes beyond matrices.
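Indeed, dot already works on arrays of any rank via its generic method, e.g.:

using LinearAlgebra

x, y = randn(2, 3, 4), randn(2, 3, 4)
dot(x, y) ≈ sum(x .* y)   # true for real arrays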

MilesCranmer commented 2 years ago

Nice!

> The big question seems to be whether LinearAlgebra should handle higher-rank objects. dot accepts anything; that might be the only function which goes beyond matrices.

I could see this being attractive for operations that have a distinct meaning on higher-dimensional arrays, like dot (✓), norm (✓), or tr (✗): for those, you couldn't just reuse a 2D version from LinearAlgebra. But for most other operations, where you simply want to vectorize over a batch axis, the user could just loop it themselves.
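For instance, a hand-rolled batched trace (a sketch of "loop it yourself"):

using LinearAlgebra

x = randn(5, 5, 8)
[tr(@view x[:, :, k]) for k in axes(x, 3)]   # length-8 vector of per-slice traces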