JuliaArrays / AxisArrays.jl

Performant arrays where each dimension can have a named axis with values
http://JuliaArrays.github.io/AxisArrays.jl/latest/

Faster version of permutation() and axisnames() #187

Open: maxfreu opened this issue 4 years ago


Hi! I came up with a faster version of `permutation()`. It may lack some functionality or generality, but I haven't found a case where it falls short yet. Can you think of one?

```julia
using AxisArrays
using BenchmarkTools

# The wooden hammer: a plain O(N^2) pairwise check, cheap for the few dimensions an array has.
# `@inbounds` goes on the loop; placed in front of `function` it has no effect on the body.
@inline function check_duplicates(arr)
    N = length(arr)
    @inbounds for i in 1:N-1
        for j in i+1:N
            arr[i] != arr[j] || throw(ArgumentError("duplicate index $(arr[i])"))
        end
    end
    return nothing
end

# Returns `res` such that `from[res[i]] == to[i]` for every i
function foo_perm(to, from)
    length(to) == length(from) || throw(ArgumentError("to and from must have the same length"))
    res = Vector{Int}(undef, length(from))
    @inbounds for (i, t) in enumerate(to)
        idx = findfirst(==(t), from)  # position of `t` in `from`, or `nothing`
        idx !== nothing || throw(ArgumentError("$t not found in from"))
        res[i] = idx
    end
    check_duplicates(res)
    return res
end

to = (:c, :w, :h, :d)
from = (:c, :h, :w, :d)
foo_perm(to, from) == AxisArrays.permutation(to, from)
@btime foo_perm($to, $from)  # 42ns
@btime AxisArrays.permutation($to, $from)  # 307ns
```
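For reference, the same idea could also return a tuple instead of a `Vector{Int}`. The sketch below is only an illustration under stated assumptions: `perm_tuple` is a hypothetical name (not part of AxisArrays), it assumes both arguments are equal-length tuples of `Symbol`s, and it swaps `check_duplicates` for Base's `allunique`. No timing is claimed for it. It also spells out what the indices mean: `p[i]` is the position of `to[i]` in `from`, which is the order `permutedims` expects when the array's dimensions are named by `from`.

```julia
# Hypothetical tuple-returning variant of `foo_perm` (not part of AxisArrays).
# For each name in `to`, find its position in `from`; the result `p` satisfies
# `map(i -> from[i], p) == to`.
function perm_tuple(to::NTuple{N,Symbol}, from::NTuple{N,Symbol}) where {N}
    p = map(to) do t
        idx = findfirst(==(t), from)
        idx === nothing && throw(ArgumentError("$t not found in $from"))
        idx
    end
    # Equal lengths + every name found + distinct indices => a true permutation
    allunique(p) || throw(ArgumentError("duplicate names in $to"))
    return p
end

to = (:c, :w, :h, :d)
from = (:c, :h, :w, :d)
p = perm_tuple(to, from)        # (1, 3, 2, 4)
map(i -> from[i], p) == to      # true

# Reordering plain data whose dimensions are named by `from`: c=2, h=3, w=4, d=5
A = rand(2, 3, 4, 5)
B = permutedims(A, p)
size(B)                         # (2, 4, 3, 5), i.e. ordered c, w, h, d
```

Because the result is a tuple whose length is known from the argument types, the compiler can keep it out of heap memory; whether that actually beats `foo_perm` would need its own benchmark.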

I think `axisnames()` is typically called with an `AxisArray` as its argument, but I found it a bit slow. Again, I came up with an alternative that might lack generality but is faster:

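```julia
# Read the name stored as the type parameter of an `Axis`
axname(a::AxisArrays.Axis{name}) where name = name
# `a.axes` is the tuple of `Axis` objects stored inside the AxisArray
axnames(a::AxisArray) = axname.(a.axes)

a = AxisArray(rand(3, 4, 5), :c, :h, :w)
axnames(a) == axisnames(a)
@btime axnames($a)  # 340ns
@btime axisnames($a)  # 4us
```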
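`axname` works because an `Axis` carries its name as a type parameter (`Axis{name}`), so reading it needs no access to the axis values. The self-contained sketch below illustrates that idea with hypothetical stand-in types `ToyAxis` and `ToyAxisArray`; they are not the real AxisArrays types, which have more parameters and methods. Besides broadcasting over the stored axes as above, it shows a second hypothetical variant, `axnames_from_type`, that derives the names from the axes' tuple type alone via Base's `fieldtypes`.

```julia
# Hypothetical stand-ins for AxisArrays.Axis / AxisArray, only to illustrate the idea:
# the axis name lives in the type parameter `name`.
struct ToyAxis{name,T}
    val::T
end
ToyAxis{name}(val::T) where {name,T} = ToyAxis{name,T}(val)

struct ToyAxisArray{D<:AbstractArray,Ax<:Tuple}
    data::D
    axes::Ax
end

# Instance-level: read the name parameter from an axis object
axname(::ToyAxis{name}) where {name} = name
# Type-level: the same, but dispatching on the axis type itself
axname(::Type{<:ToyAxis{name}}) where {name} = name

# Broadcast over the stored tuple of axes, as in the snippet above
axnames(a::ToyAxisArray) = axname.(a.axes)
# Use only the type: `fieldtypes` of a Tuple type lists its element types
axnames_from_type(::ToyAxisArray{D,Ax}) where {D,Ax} = map(axname, fieldtypes(Ax))

a = ToyAxisArray(rand(3, 4), (ToyAxis{:row}(1:3), ToyAxis{:col}(1:4)))
axnames(a)            # (:row, :col)
axnames_from_type(a)  # (:row, :col)
```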