Open MilesCranmer opened 2 years ago
Here's a significantly faster version:
# Sum the full N-dimensional diagonal x[i, i, ..., i] of an array.
# Generated so the N-fold indexing expression is built once, at compile
# time, specialized on the array's dimensionality.
@generated function _tr_all_dims(x::AbstractArray{T, N}) where {T,N}
    diag_entry = Expr(:ref, :x, ntuple(_ -> :i, N)...)
    return quote
        acc = zero($T)
        for i in axes(x, 1)
            acc += $diag_entry
        end
        acc
    end
end
# Batched trace: apply the full-diagonal sum to each slice selected by
# `dims`, producing an array with those dimensions collapsed to length 1.
function tr(x::AbstractArray; dims)
    traced = mapslices(_tr_all_dims, x; dims=dims)
    return traced
end
although `mapslices` is quite slow for some reason.
Bit faster, but still slower than tr
# Trace over the dimensions encoded (as `Val`s) in the type of `dims`.
# All traced dimensions share one loop index `i`; the remaining dimensions
# are kept as `:` slices, so the result is an array unless every dimension
# is traced (in which case it is a scalar).
@generated function _tr_dims(x::AbstractArray{T,N}; dims) where {T,N}
    # `dims` is a Tuple type of Val{i}, Val{j}, etc.; recover the integers.
    traced = collect(val.parameters[1] for val in dims.parameters)
    indices = [
        (j in traced) ? :(i) : :(:) for j in 1:N
    ]
    x_part = :(x[$(indices...)])
    # Scalar accumulation when every dimension is traced; broadcasting
    # accumulation when some dimensions remain.
    summation = if :(:) in indices
        :(out .+= $x_part)
    else
        :(out += $x_part)
    end
    # BUG FIX: the loop must run over one of the *traced* axes, not axis 1 —
    # axis 1 may not be traced at all (e.g. dims=(2, 3) on a (5, 3, 3) array
    # previously looped over axes(x, 1) and went out of bounds).
    loop_axis = :(axes(x, $(first(traced))))
    quote
        i = first($loop_axis)   # so `zero($x_part)` below is well-defined
        out = zero($x_part)
        for i in $loop_axis
            $summation
        end
        out
    end
end
# Public entry point: lift each traced dimension into the type domain via
# `Val` so the generated `_tr_dims` can specialize on the dimension set.
function LinearAlgebra.tr(x; dims)
    val_dims = Tuple(collect(Val(d) for d in dims))
    return _tr_dims(x; dims=val_dims)
end
Actually for big arrays, they are basically comparable. I think the multi-dimensional trace simply has more things to do at startup.
Perhaps you will want to replace `fill(i, N)` and the like with `ntuple(_ -> i, N)`. That makes the allocations disappear.
Thanks. Here's the updated performance on a second-order tensor. The baseline, with normal `tr`, is 1.18 µs:
# Batched multi-dimensional trace via `mapslices`.
#
# BUG FIX: the original closure indexed the *outer* array `x` with N copies
# of `i`, so every slice received the full diagonal of `x` instead of its
# own diagonal. Each slice `_x` passed by `mapslices` has `length(dims)`
# dimensions, so we index `_x` with `ndims(_x)` copies of `i` and sum over
# its first axis.
function tr1(x::AbstractArray{T,N}; dims) where {T,N}
    return mapslices(x; dims=dims) do _x
        sum(i -> _x[ntuple(_ -> i, ndims(_x))...], axes(_x, 1))
    end
end
# Sum the main diagonal x[i, i, ..., i] over the first axis of `x`.
# A generated function so the N-index reference is constructed at
# compile time rather than splatted at runtime.
@generated function _tr_all_dims(x::AbstractArray{T, N}) where {T,N}
    diag_ref = :(x[$(ntuple(_ -> :i, N)...)])
    body = quote
        total = zero($T)
        for i in axes(x, 1)
            total += $diag_ref
        end
        total
    end
    return body
end
# Batched trace over the axes in `dims`: one full-diagonal sum per slice.
function tr2(x::AbstractArray; dims)
    result = mapslices(_tr_all_dims, x; dims=dims)
    return result
end
# Trace over the dimensions encoded (as `Val`s) in the type of `dims`.
# Traced dimensions share the loop index `i`; untraced dimensions stay as
# `:` slices, so the result is an array unless every dimension is traced.
@generated function _tr3_dims(x::AbstractArray{T,N}; dims) where {T,N}
    # `dims` is a Tuple type of Val{i}, Val{j}, etc.; recover the integers.
    traced = collect(val.parameters[1] for val in dims.parameters)
    indices = [
        (j in traced) ? :(i) : :(:) for j in 1:N
    ]
    x_part = :(x[$(indices...)])
    # Scalar accumulation when all dims are traced, broadcast otherwise.
    summation = if :(:) in indices
        :(out .+= $x_part)
    else
        :(out += $x_part)
    end
    # BUG FIX: loop over one of the *traced* axes, not axis 1 — axis 1 may
    # not be traced (e.g. dims=(2, 3) on a (5, 3, 3) array previously
    # iterated axes(x, 1) and indexed out of bounds).
    loop_axis = :(axes(x, $(first(traced))))
    quote
        i = first($loop_axis)   # makes `zero($x_part)` well-defined
        out = zero($x_part)
        for i in $loop_axis
            $summation
        end
        out
    end
end
# Wrap the integer `dims` into `Val`s so `_tr3_dims` specializes on them.
function tr3(x; dims)
    val_dims = Tuple(collect(Val(d) for d in dims))
    return _tr3_dims(x; dims=val_dims)
end
(Reusing precomputed `Val(i)` objects rather than creating them at runtime:)
# Same as `tr3`, but picks the `Val(i)` objects out of a precomputed tuple
# so no `Val` has to be constructed at runtime.
function tr4(x::AbstractArray{T,N}; dims) where {T,N}
    all_vals = ntuple(Val, N)
    chosen = Tuple(collect(all_vals[d] for d in dims))
    return _tr3_dims(x; dims=chosen)
end
@dkarrasch let me know what you think, and I can make a PR.
I'm finding the metaprogramming stuff quite hard to read. A simple
using LinearAlgebra
import LinearAlgebra: checksquare
# Extend `checksquare` to N-dimensional arrays: every dimension must have
# the same length. Returns that common length.
function checksquare(A::AbstractArray{<:Any,N}) where {N}
    dims = size(A)
    if !all(==(dims[1]), dims)
        throw(DimensionMismatch("array is not square: dimensions are $(size(A))"))
    end
    return dims[1]
end
# Trace of a "hyper-square" N-dimensional array: the sum of A[i, i, ..., i].
# `@simd` lets the reduction vectorize; `@inbounds` is safe because
# `checksquare` has verified every axis has length `n`.
function mytr(A::Array{T,N}) where {T,N}
    n = checksquare(A)
    acc = zero(T)
    @inbounds @simd for i in 1:n
        acc += A[ntuple(_ -> i, N)...]
    end
    return acc
end
performs even better than `LinearAlgebra.tr` on matrices (due to the `@simd` annotation, which we should also apply to the current method in dense.jl; done in JuliaLang/julia#47585), and its performance depends only on the number of diagonal elements. I tested on 400x400 and 400x400x400 arrays. As for the batched computation, `mapslices` seems to add massive overhead; but on the other hand, I'm not sure whether — even if we have this machinery — we should cook up some complicated, super-specialized code here just for `tr`. @mbauman, do you have some good advice here?
It's not so complicated to make CartesianIndices
do this.
# Entry point: trace `A` over the dimensions in `dims` (a tuple of axes).
function tr5(A::AbstractArray; dims=:)
    return _mytr(dims, A)
end
# Trace over the axes listed in `dims`: for every position along the kept
# axes, sum A over the shared diagonal index of the traced axes. The traced
# axes are collapsed to length 1 in a temporary, then dropped from the
# result.
function _mytr(dims::Tuple{Integer, Vararg{Integer}}, A::AbstractArray)
    traced_axes = map(d -> axes(A, d), dims)
    allequal(traced_axes) || _tr_error(traced_axes)
    # `true` for axes we keep, `false` for axes being traced over.
    keep = ntuple(d -> !(d in dims), ndims(A))
    out = similar(A, ifelse.(keep, axes(A), (Base.OneTo(1),)))
    for pos in CartesianIndices(out)
        acc = zero(eltype(A))
        @inbounds @simd for j in first(traced_axes)
            # Diagonal index: kept axes take `pos`, traced axes all take `j`.
            acc += A[CartesianIndex(ifelse.(keep, Tuple(pos), j))]
        end
        out[pos] = acc
    end
    return dropdims(out; dims)
end
# Out-of-line error path, kept `@noinline` so the hot loop in the caller
# stays small.
@noinline function _tr_error(ax)
    throw(DimensionMismatch("traced dimensions must agree, but got $(ax)"))
end
The big question seems to be whether LinearAlgebra should handle higher-rank objects. `dot` accepts anything; it might be the only function which goes beyond matrices.
Nice!
The big question seems to be whether LinearAlgebra should handle higher-rank objects.
dot
accepts anything, that might be the only function which goes beyond matrices.
I could see this being attractive for operations which have distinct meaning when used on higher dimensional arrays, like dot
( $\checkmark$ ), norm
( $\checkmark$ ), or tr
( $\times$ ) - for those operations you couldn't use a 2D version from LinearAlgebra. But most other operations where you would simply want to vectorize it over a batch axis, the user could just loop it themselves.
Right now, `LinearAlgebra.tr` only traces second-order tensors. However, it can be very useful to trace along a specific set of axes (e.g., for batched calculations), as well as to take ND traces (e.g., `x[i, i, i]`). This can be done with Tullio.jl for static dimension specification, but when you need to specify which axes are being traced, it requires some meta-programming. So it would be nicer if `tr` itself could take in a `dims` argument, and also handle more than two axes.

Therefore, I propose to extend `tr` to allow for a `dims` argument, which can be used for arbitrary-dimension arrays. The following code gives a simple working implementation. So now we can do things like:
Let me know what you think.
Also - this implementation (with splats) is very slow, so we'd probably want to compile-in specialization for the dimensionality of the array.
Thanks, Miles