OpenMendel / SnpArrays.jl

Compressed storage for SNP data
https://openmendel.github.io/SnpArrays.jl/latest

Improve linear algebra #56

Closed: kose-y closed this issue 4 years ago.

kose-y commented 4 years ago

@biona001 notes that linear algebra on SnpArray could be much improved using LoopVectorization.jl: https://github.com/OpenMendel/SnpArrays.jl/issues/51#issuecomment-607462229

kose-y commented 4 years ago

I think it would be ideal to eventually have two ways of doing linear algebra with SnpArray:

Maybe I should try both approaches for the CUDA implementation?

kose-y commented 4 years ago

After #57, it still could be improved with tiled computation.

kose-y commented 4 years ago

New benchmark with loop unrolling on direct linear algebra. It is slightly faster than SnpBitMatrix (1.27 s vs. 1.32 s median, below), although boundaries still need to be handled. Maybe we can deprecate SnpBitMatrix? We will see after trying tiled computation next week.

This approach is impossible for BitMatrix because the starting point of each column is not, in general, byte-aligned.
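To make the alignment point concrete, here is a small plain-Julia sketch comparing where each column starts in a PLINK-style packed matrix and in a BitMatrix with the same number of rows. The helper names are made up for illustration, and it assumes SNP-major storage with 2 bits per genotype and each column padded to whole bytes, as in a .bed file. The last check peeks at BitArray's internal `chunks` field, which is a Base implementation detail rather than a public API.

```julia
# Hypothetical helper: byte at which column j starts in a PLINK .bed-style
# matrix with m rows (2 bits per genotype, each column padded to whole bytes).
plink_col_start_byte(m, j) = (j - 1) * cld(m, 4)

# Hypothetical helper: bit at which column j starts in a BitMatrix with m rows.
# BitArray packs all bits contiguously in column-major order, with no padding.
bitmatrix_col_start_bit(m, j) = (j - 1) * m

m = 379  # number of samples in EUR_subset
[plink_col_start_byte(m, j) for j in 1:3]         # [0, 95, 190]: always whole bytes
[bitmatrix_col_start_bit(m, j) % 8 for j in 1:3]  # [0, 3, 6]: columns 2, 3 start mid-byte

# Cross-check via the internal `chunks::Vector{UInt64}` field: B[1, 2] is
# linear index m + 1, i.e. zero-based bit 379 = chunk 6, bit 59.
B = falses(m, 3)
B[1, 2] = true
B.chunks[379 ÷ 64 + 1] == UInt64(1) << (379 % 64)  # true
```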

versioninfo()
Julia Version 1.4.1
Commit 381693d3df* (2020-04-14 17:20 UTC)
Platform Info:
  OS: Linux (x86_64-pc-linux-gnu)
  CPU: Intel(R) Xeon(R) Silver 4114 CPU @ 2.20GHz
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-8.0.1 (ORCJIT, skylake)
using SnpArrays
const EUR = SnpArray(SnpArrays.datadir("EUR_subset.bed"))
379×54051 SnpArray:
 0x03  0x03  0x03  0x02  0x02  0x03  …  0x03  0x03  0x03  0x03  0x03  0x03
 0x03  0x02  0x03  0x02  0x03  0x03     0x03  0x03  0x03  0x03  0x03  0x03
 0x03  0x03  0x03  0x03  0x03  0x03     0x02  0x02  0x02  0x03  0x03  0x02
 0x03  0x03  0x03  0x00  0x03  0x03     0x03  0x03  0x03  0x03  0x03  0x03
 0x03  0x03  0x03  0x00  0x03  0x03     0x02  0x02  0x02  0x03  0x03  0x03
 0x02  0x03  0x03  0x03  0x03  0x03  …  0x03  0x03  0x03  0x03  0x03  0x02
 0x02  0x03  0x03  0x02  0x02  0x03     0x03  0x03  0x02  0x02  0x03  0x03
 0x02  0x03  0x03  0x03  0x02  0x02     0x03  0x03  0x03  0x03  0x03  0x03
 0x03  0x03  0x03  0x00  0x02  0x03     0x03  0x03  0x03  0x03  0x03  0x03
 0x02  0x03  0x03  0x02  0x03  0x02     0x03  0x03  0x03  0x03  0x03  0x03
 0x03  0x03  0x03  0x02  0x03  0x03  …  0x03  0x03  0x02  0x02  0x03  0x03
 0x03  0x03  0x03  0x02  0x03  0x03     0x03  0x03  0x03  0x03  0x03  0x02
 0x03  0x02  0x03  0x02  0x02  0x03     0x03  0x03  0x03  0x03  0x03  0x03
    ⋮                             ⋮  ⋱     ⋮                             ⋮
 0x03  0x03  0x03  0x00  0x02  0x03     0x03  0x03  0x03  0x03  0x03  0x03
 0x03  0x03  0x03  0x02  0x02  0x03     0x02  0x02  0x02  0x03  0x02  0x03
 0x03  0x03  0x03  0x02  0x02  0x03     0x03  0x03  0x03  0x03  0x03  0x03
 0x02  0x03  0x03  0x02  0x03  0x03  …  0x03  0x03  0x03  0x03  0x03  0x03
 0x03  0x03  0x03  0x00  0x00  0x03     0x02  0x02  0x02  0x03  0x03  0x03
 0x02  0x03  0x03  0x03  0x03  0x03     0x03  0x03  0x03  0x03  0x03  0x03
 0x03  0x03  0x03  0x02  0x03  0x03     0x03  0x03  0x03  0x03  0x03  0x03
 0x03  0x03  0x03  0x02  0x03  0x03     0x03  0x03  0x03  0x03  0x03  0x03
 0x02  0x03  0x03  0x03  0x03  0x03  …  0x03  0x03  0x02  0x02  0x03  0x03
 0x03  0x03  0x03  0x00  0x03  0x03     0x03  0x03  0x03  0x03  0x03  0x03
 0x02  0x03  0x03  0x02  0x00  0x02     0x03  0x03  0x03  0x03  0x03  0x03
 0x03  0x03  0x03  0x02  0x02  0x03     0x03  0x03  0x03  0x03  0x03  0x03

Let's try with EUR data repeated 100 times: 37900 by 54051.

EUR_10 = [EUR; EUR; EUR; EUR; EUR; EUR; EUR; EUR; EUR; EUR]
EUR_100 = [EUR_10; EUR_10; EUR_10; EUR_10; EUR_10; EUR_10; EUR_10; EUR_10; EUR_10; EUR_10];

We create instances of SnpBitMatrix and SnpLinAlg, plus a dense Matrix{Float64} for BLAS (the CuSnpArray lines are commented out):

EUR_100_bm = SnpBitMatrix{Float64}(EUR_100; model=ADDITIVE_MODEL, center=false, scale=false)
EUR_100_sla = SnpLinAlg{Float64}(EUR_100; model=ADDITIVE_MODEL, center=false, scale=false)
EUR_100_mat = convert(Matrix{Float64}, EUR_100, model=ADDITIVE_MODEL, center=true, scale=true);
# ENV["JULIA_CUDA_USE_BINARYBUILDER"] = "false"
# using CUDA
# EUR_100_cu = CuSnpArray{Float64}(EUR_100; model=ADDITIVE_MODEL, center=false, scale=false);

$y = Ax$

using LinearAlgebra
using BenchmarkTools
v1 = randn(size(EUR_100, 1))
v2 = randn(size(EUR_100, 2));

With 8-threaded OpenBLAS (standard binary installation):

BLAS.set_num_threads(8)
@benchmark LinearAlgebra.mul!($v1, $EUR_100_mat, $v2)
BenchmarkTools.Trial: 
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     436.460 ms (0.00% GC)
  median time:      443.273 ms (0.00% GC)
  mean time:        493.577 ms (0.00% GC)
  maximum time:     697.771 ms (0.00% GC)
  --------------
  samples:          11
  evals/sample:     1

With single-threaded OpenBLAS:

BLAS.set_num_threads(1)
@benchmark LinearAlgebra.mul!($v1, $EUR_100_mat, $v2)
BenchmarkTools.Trial: 
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     2.606 s (0.00% GC)
  median time:      2.609 s (0.00% GC)
  mean time:        2.609 s (0.00% GC)
  maximum time:     2.612 s (0.00% GC)
  --------------
  samples:          2
  evals/sample:     1

Direct linear algebra on a SnpArray:

@benchmark LinearAlgebra.mul!($v1, $EUR_100_sla, $v2)
BenchmarkTools.Trial: 
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     2.462 s (0.00% GC)
  median time:      2.473 s (0.00% GC)
  mean time:        2.475 s (0.00% GC)
  maximum time:     2.490 s (0.00% GC)
  --------------
  samples:          3
  evals/sample:     1

The benchmark for SnpBitMatrix:

@benchmark (LinearAlgebra.mul!($v1, $EUR_100_bm, $v2))
BenchmarkTools.Trial: 
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     1.302 s (0.00% GC)
  median time:      1.315 s (0.00% GC)
  mean time:        1.313 s (0.00% GC)
  maximum time:     1.319 s (0.00% GC)
  --------------
  samples:          4
  evals/sample:     1
using LoopVectorization
function _snparray_ax_additive!(out, s::Matrix{UInt8}, v)
    fill!(out, zero(eltype(out)))
    k = size(s, 1)
    @avx for j ∈ eachindex(v)
        for l in 1:k
            block = s[(j-1)*size(s, 1) + (l-1) + 1]

            # i = 4 * (l-1) + 1
            Aij = block & 3
            out[4*(l-1)+1] += ((Aij >= 2) + (Aij >= 3)) * v[j]

            # i = 4 * (l-1) + 2
            Aij = (block >> 2) & 3
            out[4*(l-1)+2] += ((Aij >= 2) + (Aij >= 3)) * v[j]

            # i = 4 * (l-1) + 3
            Aij = (block >> 4) & 3
            out[4*(l-1)+3] += ((Aij >= 2) + (Aij >= 3)) * v[j]

            # i = 4 * (l-1) + 4
            Aij = (block >> 6) & 3
            out[4*l] += ((Aij >= 2) + (Aij >= 3)) * v[j]

        end
    end
    out
end
_snparray_ax_additive! (generic function with 1 method)
@benchmark _snparray_ax_additive!($v1, $(EUR_100_sla.s.data), $v2)
BenchmarkTools.Trial: 
  memory estimate:  80 bytes
  allocs estimate:  1
  --------------
  minimum time:     1.246 s (0.00% GC)
  median time:      1.272 s (0.00% GC)
  mean time:        1.275 s (0.00% GC)
  maximum time:     1.309 s (0.00% GC)
  --------------
  samples:          4
  evals/sample:     1
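For checking kernels like the one above without LoopVectorization, a plain-Julia reference may help. This is a sketch, not the SnpArrays implementation: `additive`, `decode_additive` and `ax_unrolled!` are hypothetical names. Missing genotypes (0x01) map to 0, exactly as the `(Aij >= 2) + (Aij >= 3)` expression does, and it assumes the number of rows is exactly `4 * size(s, 1)` (no partial last byte).

```julia
using LinearAlgebra, Random

# Additive-model value of a 2-bit PLINK code, matching (Aij >= 2) + (Aij >= 3):
# 0x00 -> 0, 0x01 (missing) -> 0, 0x02 -> 1, 0x03 -> 2.
additive(code::UInt8) = (code >= 0x02) + (code >= 0x03)

# Decode a packed column-major byte matrix into a dense m × n Float64 matrix.
# Genotype i of a column sits in byte (i - 1) ÷ 4 + 1, lowest two bits first.
function decode_additive(s::Matrix{UInt8}, m::Integer)
    A = zeros(Float64, m, size(s, 2))
    for j in axes(s, 2), i in 1:m
        byte = s[((i - 1) >> 2) + 1, j]
        code = (byte >> (2 * ((i - 1) & 3))) & 0x03
        A[i, j] = additive(code)
    end
    return A
end

# out = A * v computed directly on the packed bytes: one byte load per (l, j)
# updates four consecutive rows. Assumes length(out) == 4 * size(s, 1).
function ax_unrolled!(out, s::Matrix{UInt8}, v)
    fill!(out, zero(eltype(out)))
    k = size(s, 1)
    @inbounds for j in eachindex(v), l in 1:k
        block = s[l, j]
        vj = v[j]
        for p in 1:4
            code = (block >> (2 * (p - 1))) & 0x03
            out[4 * (l - 1) + p] += additive(code) * vj
        end
    end
    return out
end

Random.seed!(1)
s = rand(UInt8, 8, 5)  # 32 rows packed into 8 bytes per column, 5 columns
v = randn(5)
ax_unrolled!(zeros(32), s, v) ≈ decode_additive(s, 32) * v  # should be true
```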
kose-y commented 4 years ago

I will just try tiling direct operations after seeing https://github.com/chriselrod/LoopVectorization.jl/issues/81#issuecomment-663469856

The easiest way to improve performance may be to define a custom BitMatrix/BitArray type that is padded to have its leading axis contain a multiple of 8 elements, so that each column will be byte-aligned and we don't need all the checks and shifts to get the correct loads into a vector we have now.

Another option is that I did add code for aligning columns to LoopVectorization. It didn't seem to help in the cases I tried, but it could help here if we can rely on it, and then have loads without all the checks. This would preclude unrolling the j loop, and thus lead to much worse performance than the former (custom type) option, but with the benefit of possibly improving multi-dimensional BitArray performance in general.

The PLINK format itself is already a byte-aligned data structure.

For some reason, tiling did not accelerate the direct operations much. I will continue with tiling BitMatrix for now.
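The padding suggestion above could look roughly like this; `padded_bitmatrix` is a hypothetical helper, not an existing SnpArrays or Base function. Rounding the leading dimension up to a multiple of 8 makes every column start on a byte boundary (`align = 64` would start each column on a fresh `UInt64` chunk).

```julia
# Hypothetical helper: allocate a BitMatrix whose leading dimension is rounded
# up to a multiple of `align` bits, so each column starts on an `align`-bit
# boundary. Returns the padded storage and a view of the logical m × n matrix.
function padded_bitmatrix(m::Integer, n::Integer; align::Integer = 8)
    mpad = cld(m, align) * align
    storage = falses(mpad, n)  # padding rows stay false, i.e. zero
    return storage, view(storage, 1:m, :)
end

storage, A = padded_bitmatrix(379, 3)
size(storage)                                  # (384, 3)
size(A)                                        # (379, 3)
[(j - 1) * size(storage, 1) % 8 for j in 1:3]  # [0, 0, 0]: byte-aligned columns
```

Because the padding rows are `false`, a kernel could sweep all `mpad` rows without boundary checks: for `A' * x` they add zeros (as long as `x` is padded to length `mpad` with finite values), and for `A * x` they only write zeros into the extra `mpad - m` output entries.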

kose-y commented 4 years ago

Here are some benchmarks with loop unrolling and tiling. I will return to this next week with boundary treatment.

Direct $Ax$: loop-unrolling and tiling

using LoopVectorization
function _snparray_ax_additive!(out, s::Matrix{UInt8}, v)
    fill!(out, zero(Float64))
    k = size(s, 1)
    @avx for j ∈ eachindex(v)
        for l in 1:k
            block = s[(j-1)*k + (l-1) + 1]

            # i = 4 * (l-1) + 1
            Aij_1 = block & 3
            out[4*(l-1)+1] += ((Aij_1 >= 2) + (Aij_1 >= 3)) * v[j]

            # i = 4 * (l-1) + 2
            Aij_2 = (block >> 2) & 3
            out[4*(l-1)+2] += ((Aij_2 >= 2) + (Aij_2 >= 3)) * v[j]

            # i = 4 * (l-1) + 3
            Aij_3 = (block >> 4) & 3
            out[4*(l-1)+3] += ((Aij_3 >= 2) + (Aij_3 >= 3)) * v[j]

            # i = 4 * (l-1) + 4
            Aij_4 = (block >> 6) & 3
            out[4*l] += ((Aij_4 >= 2) + (Aij_4 >= 3)) * v[j]

        end
    end
    out
end
_snparray_ax_additive! (generic function with 1 method)
using VectorizationBase, LoopVectorization
vstep = 1024
hstep = 1024
vstep_log2 = 10
hstep_log2 = 10
function _snparray_ax_additive_step!(out, s, v)
    @avx for j ∈ 1:hstep
        for l in 1:vstep
            block = s[l, j]

            # i = 4 * (l-1) + 1
            Aij_1 = block & 3
            out[4*(l-1)+1] += ((Aij_1 >= 2) + (Aij_1 >= 3)) * v[j]

            # i = 4 * (l-1) + 2
            Aij_2 = (block >> 2) & 3
            out[4*(l-1)+2] += ((Aij_2 >= 2) + (Aij_2 >= 3)) * v[j]

            # i = 4 * (l-1) + 3
            Aij_3 = (block >> 4) & 3
            out[4*(l-1)+3] += ((Aij_3 >= 2) + (Aij_3 >= 3)) * v[j]

            # i = 4 * (l-1) + 4
            Aij_4 = (block >> 6) & 3
            out[4*l] += ((Aij_4 >= 2) + (Aij_4 >= 3)) * v[j]
        end
    end
    out
end

function _snparray_ax_additive_tile!(c, A, b)
    fill!(c, zero(eltype(c)))

    M, N = size(A)
    Miter = M >>> vstep_log2 # fast div(M, vstep)
    Mrem = M & (vstep-1) # fast rem(M, vstep)
    Niter = N >>> hstep_log2
    Nrem = N & (hstep-1)
    GC.@preserve c A b for n in 0:Niter-1
        for m in 0:Miter-1
            _snparray_ax_additive_step!(
                gesp(stridedpointer(c), ((4 * vstep)*m,)),
                gesp(stridedpointer(A), (vstep*m, hstep*n)),
                gesp(stridedpointer(b), (hstep*n,))
            )
        end
        # TODO: handle mrem
    end
    # TODO: handle nrem
end
_snparray_ax_additive_tile! (generic function with 1 method)
using Random, Test
n = 1<<14
a = rand(UInt8, n, n)
v1 = zeros(n << 2)
v1_ = zeros(n << 2)
v1__ = zeros(n << 2);

v2 = rand(Float64, n);

Let's check correctness first:

using SnpArrays
SnpArrays._snparray_ax_additive!(v1, a, v2);

_snparray_ax_additive!(v1_, a, v2);

_snparray_ax_additive_tile!(v1__, a, v2);
@test isapprox(v1, v1_) && isapprox(v1, v1__)
true
using BenchmarkTools
@benchmark SnpArrays._snparray_ax_additive!(v1, a, v2)
BenchmarkTools.Trial: 
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     1.832 s (0.00% GC)
  median time:      1.850 s (0.00% GC)
  mean time:        1.861 s (0.00% GC)
  maximum time:     1.902 s (0.00% GC)
  --------------
  samples:          3
  evals/sample:     1
@benchmark _snparray_ax_additive!(v1_, a, v2)
BenchmarkTools.Trial: 
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     528.025 ms (0.00% GC)
  median time:      530.024 ms (0.00% GC)
  mean time:        530.313 ms (0.00% GC)
  maximum time:     533.325 ms (0.00% GC)
  --------------
  samples:          10
  evals/sample:     1
@benchmark _snparray_ax_additive_tile!(v1__, a, v2)
BenchmarkTools.Trial: 
  memory estimate:  148.13 KiB
  allocs estimate:  7655
  --------------
  minimum time:     517.248 ms (0.00% GC)
  median time:      519.453 ms (0.00% GC)
  mean time:        519.346 ms (0.00% GC)
  maximum time:     523.305 ms (0.00% GC)
  --------------
  samples:          10
  evals/sample:     1

Loop unrolling clearly accelerates the $Ax$ operation (1.850 s → 530 ms median), but tiling adds little on top of it (519 ms).

Direct $A^T x$: loop-unrolling and tiling

function _snparray_atx_additive!(out, s::Matrix{UInt8}, v)
    fill!(out, zero(Float64))
    k = size(s, 1)
    @avx for i ∈ eachindex(out)
        for l in 1:k
            block = s[(i-1)*k + (l-1) + 1]

            # j = 4 * (l-1) + 1
            Aij_1 = block & 3
            out[i] += ((Aij_1 >= 2) + (Aij_1 >= 3)) * v[4*(l-1)+1]

            # j = 4 * (l-1) + 2
            Aij_2 = (block >> 2) & 3
            out[i] += ((Aij_2 >= 2) + (Aij_2 >= 3)) * v[4*(l-1)+2]

            # j = 4 * (l-1) + 3
            Aij_3 = (block >> 4) & 3
            out[i] += ((Aij_3 >= 2) + (Aij_3 >= 3)) * v[4*(l-1)+3]

            # j = 4 * (l-1) + 4
            Aij_4 = (block >> 6) & 3
            out[i] += ((Aij_4 >= 2) + (Aij_4 >= 3)) * v[4*(l-1)+4]

        end
    end
    out
end
_snparray_atx_additive! (generic function with 1 method)
using VectorizationBase, LoopVectorization
vstep = 8192
hstep = 8192
vstep_log2 = 13
hstep_log2 = 13
function _snparray_atx_additive_step!(out, s, v)
    @avx for i ∈ 1:hstep
        for l in 1:vstep
            block = s[l, i]

            # j = 4 * (l-1) + 1
            Aij_1 = block & 3
            out[i] += ((Aij_1 >= 2) + (Aij_1 >= 3)) * v[4*(l-1)+1]

            # j = 4 * (l-1) + 2
            Aij_2 = (block >> 2) & 3
            out[i] += ((Aij_2 >= 2) + (Aij_2 >= 3)) * v[4*(l-1)+2]

            # j = 4 * (l-1) + 3
            Aij_3 = (block >> 4) & 3
            out[i] += ((Aij_3 >= 2) + (Aij_3 >= 3)) * v[4*(l-1)+3]

            # j = 4 * (l-1) + 4
            Aij_4 = (block >> 6) & 3
            out[i] += ((Aij_4 >= 2) + (Aij_4 >= 3)) * v[4*(l-1)+4]
        end
    end
    out
end

function _snparray_atx_additive_tile!(c, A, b)
    fill!(c, zero(eltype(c)))

    M, N = size(A)
    Miter = M >>> vstep_log2 # fast div(M, vstep)
    Mrem = M & (vstep-1) # fast rem(M, vstep)
    Niter = N >>> hstep_log2
    Nrem = N & (hstep-1)
    GC.@preserve c A b for n in 0:Niter-1
        for m in 0:Miter-1
            _snparray_atx_additive_step!(
                gesp(stridedpointer(c), ((hstep)*n,)),
                gesp(stridedpointer(A), (vstep*m, hstep*n)),
                gesp(stridedpointer(b), (4*vstep*m,))
            )
        end
        # TODO: handle mrem
    end
    # TODO: handle nrem
end
_snparray_atx_additive_tile! (generic function with 1 method)
using Random, Test
n = 1<<14
a = rand(UInt8, n, n)
v1 = rand(n << 2)

v2 = zeros(Float64, n);
v2_ = zeros(Float64, n);
v2__ = zeros(Float64, n);
SnpArrays._snparray_atx_additive!(v2, a, v1);
_snparray_atx_additive!(v2_, a, v1)
_snparray_atx_additive_tile!(v2__, a, v1);
@test isapprox(v2, v2_) && isapprox(v2, v2__)
Test Passed
@benchmark SnpArrays._snparray_atx_additive!(v2, a, v1)
BenchmarkTools.Trial: 
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     962.160 ms (0.00% GC)
  median time:      963.807 ms (0.00% GC)
  mean time:        967.119 ms (0.00% GC)
  maximum time:     983.573 ms (0.00% GC)
  --------------
  samples:          6
  evals/sample:     1
@benchmark _snparray_atx_additive!(v2_, a, v1)
BenchmarkTools.Trial: 
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     615.608 ms (0.00% GC)
  median time:      617.361 ms (0.00% GC)
  mean time:        618.730 ms (0.00% GC)
  maximum time:     624.571 ms (0.00% GC)
  --------------
  samples:          9
  evals/sample:     1
@benchmark _snparray_atx_additive_tile!(v2__, a, v1)
BenchmarkTools.Trial: 
  memory estimate:  2.44 KiB
  allocs estimate:  123
  --------------
  minimum time:     338.661 ms (0.00% GC)
  median time:      340.334 ms (0.00% GC)
  mean time:        353.059 ms (0.00% GC)
  maximum time:     444.355 ms (0.00% GC)
  --------------
  samples:          15
  evals/sample:     1

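This time, tiling clearly accelerates the operation: 617 ms → 340 ms median for the unrolled kernel, compared with 964 ms for the current SnpArrays implementation.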
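As a plain-Julia cross-check for the $A^T x$ kernels, here is a sketch with hypothetical names and no LoopVectorization. Unlike $Ax$, each output entry is a dot product down one column of the packed matrix, so it can be accumulated in a local variable and written once. It assumes `length(v) == 4 * size(s, 1)`.

```julia
using LinearAlgebra, Random

# Additive-model value of a 2-bit PLINK code, matching (Aij >= 2) + (Aij >= 3).
additive(code::UInt8) = (code >= 0x02) + (code >= 0x03)

# Dense decode for checking: genotype i of a column is in byte (i - 1) ÷ 4 + 1.
function decode_additive(s::Matrix{UInt8}, m::Integer)
    A = zeros(Float64, m, size(s, 2))
    for j in axes(s, 2), i in 1:m
        byte = s[((i - 1) >> 2) + 1, j]
        A[i, j] = additive((byte >> (2 * ((i - 1) & 3))) & 0x03)
    end
    return A
end

# out = transpose(A) * v on the packed bytes. The column index j is never
# reassigned inside the loop body; each column's sum lives in `acc`.
function atx_unrolled!(out, s::Matrix{UInt8}, v)
    k = size(s, 1)
    @inbounds for j in eachindex(out)
        acc = zero(eltype(out))
        for l in 1:k
            block = s[l, j]
            for p in 1:4
                code = (block >> (2 * (p - 1))) & 0x03
                acc += additive(code) * v[4 * (l - 1) + p]
            end
        end
        out[j] = acc
    end
    return out
end

Random.seed!(2)
s = rand(UInt8, 8, 5)  # 32 rows, 5 columns
v = randn(32)
atx_unrolled!(zeros(5), s, v) ≈ transpose(decode_additive(s, 32)) * v  # should be true
```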

BitMatrix $Ax$: tiling

using VectorizationBase: gesp, stridedpointer
function gemv_avx!(c, A, b)
    @avx for j in 1:size(A, 2), i in 1:size(A, 1)
        c[i] += A[i, j] * b[j]
    end
end
vstep = 512
hstep = 512
vstep_log2 = 9
hstep_log2 = 9
function gemv_avx_step!(c, A, b)
    @avx for j in 1:hstep, i in 1:vstep
        c[i] += A[i, j] * b[j]
    end
end
function gemv_tile!(c, A, b)
    M, N = size(A)
    Miter = M >>> vstep_log2 # fast div(M, 512)
    Mrem = M & (vstep-1) # fast rem(M, 512)
    Niter = N >>> hstep_log2
    Nrem = N & (hstep-1)
    GC.@preserve c A b for n in 0:Niter-1
        for m in 0:Miter-1
            gemv_avx_step!(
                gesp(stridedpointer(c), (vstep*m,)),
                gesp(stridedpointer(A), (vstep*m, hstep*n)),
                gesp(stridedpointer(b), (hstep*n,))
            )
        end
        # TODO: handle mrem
    end
    # TODO: handle nrem
end
gemv_tile! (generic function with 1 method)
n = 1 << 14
A = bitrand(n<<1, n);
b = rand(n); c1 = zeros(2n); c2 = zeros(2n);
gemv_tile!(c1, A, b);
gemv_avx!(c2, A, b);
@test c1 ≈ c2
Test Passed
@benchmark gemv_avx!($c2, $A, $b)
BenchmarkTools.Trial: 
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     312.292 ms (0.00% GC)
  median time:      318.512 ms (0.00% GC)
  mean time:        319.752 ms (0.00% GC)
  maximum time:     343.857 ms (0.00% GC)
  --------------
  samples:          16
  evals/sample:     1
@benchmark gemv_tile!($c1, $A, $b)
BenchmarkTools.Trial: 
  memory estimate:  1023.09 KiB
  allocs estimate:  49029
  --------------
  minimum time:     155.368 ms (0.00% GC)
  median time:      157.639 ms (0.00% GC)
  mean time:        167.359 ms (0.00% GC)
  maximum time:     212.287 ms (0.00% GC)
  --------------
  samples:          30
  evals/sample:     1

As noted before, tiling accelerates BitMatrix gemv (319 ms → 158 ms median).
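A plain-Julia sketch of the same blocking idea, without `gesp`/`stridedpointer` and without SIMD tuning, so it shows the loop structure rather than the timings above. Instead of separate calls for leftover rows and columns, each tile's end is clipped with `min`. The function names and the default tile sizes `tm`, `tn` are illustrative assumptions.

```julia
using LinearAlgebra, Random

# Untiled reference: y = A * x, column by column (column-major friendly).
function gemv_naive!(y, A, x)
    fill!(y, zero(eltype(y)))
    @inbounds for j in axes(A, 2), i in axes(A, 1)
        y[i] += A[i, j] * x[j]
    end
    return y
end

# Tiled y = A * x: walk A in tm × tn blocks so that the slices y[i0:i1] and
# x[j0:j1] used by one block can stay in cache. Clipping the block ends with
# `min` covers the leftover rows and columns without separate remainder calls.
function gemv_tiled!(y, A, x; tm::Int = 512, tn::Int = 512)
    fill!(y, zero(eltype(y)))
    M, N = size(A)
    @inbounds for j0 in 1:tn:N
        j1 = min(j0 + tn - 1, N)
        for i0 in 1:tm:M
            i1 = min(i0 + tm - 1, M)
            for j in j0:j1
                xj = x[j]
                for i in i0:i1
                    y[i] += A[i, j] * xj
                end
            end
        end
    end
    return y
end

Random.seed!(3)
A = bitrand(1000, 777)  # neither dimension is a multiple of the tile sizes
x = rand(777)
gemv_tiled!(zeros(1000), A, x; tm = 128, tn = 64) ≈ Matrix{Float64}(A) * x  # should be true
```

The outer loop runs over column tiles and the inner loop over row tiles, matching the `n`/`m` order of the tile drivers in this thread.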

BitMatrix $A^Tx$: tiling

using VectorizationBase: gesp, stridedpointer
using LoopVectorization
function gemv_t_avx!(c, A, b)
    fill!(c, 0)
    @avx for j in 1:size(A, 1), i in 1:size(A, 2)
        c[i] += A[j, i] * b[j]
    end
end
vstep = 512
hstep = 512
vstep_log2 = 9
hstep_log2 = 9
function gemv_t_avx_step!(c, A, b)
    @avx for j in 1:vstep, i in 1:hstep  # j: tile rows, i: tile columns
        c[i] += A[j, i] * b[j]
    end
end
function gemv_t_tile!(c, A, b)
    fill!(c, 0)
    M, N = size(A)
    Miter = M >>> vstep_log2 # fast div(M, 512)
    Mrem = M & (vstep-1) # fast rem(M, 512)
    Niter = N >>> hstep_log2
    Nrem = N & (hstep-1)
    GC.@preserve c A b for n in 0:Niter-1
        for m in 0:Miter-1
            gemv_t_avx_step!(
                gesp(stridedpointer(c), (hstep*n,)),
                gesp(stridedpointer(A), (vstep*m, hstep*n)),
                gesp(stridedpointer(b), (vstep*m,))
            )
        end
        # TODO: handle mrem
    end
    # TODO: handle nrem
end
gemv_t_tile! (generic function with 1 method)
using Random
n = 1 << 14
A = bitrand(n<<1, n);
b1 = rand(n); b2 = rand(n); c1 = rand(2n); c2 = rand(2n);

gemv_t_avx!(b2, A, c1);
gemv_t_tile!(b1, A, c1);
@test b1 ≈ b2
Test Passed
using BenchmarkTools
@benchmark gemv_t_avx!(b2, A, c1)
BenchmarkTools.Trial: 
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     77.863 ms (0.00% GC)
  median time:      80.086 ms (0.00% GC)
  mean time:        80.263 ms (0.00% GC)
  maximum time:     84.744 ms (0.00% GC)
  --------------
  samples:          63
  evals/sample:     1
@benchmark gemv_t_tile!(b1, A, c1)
BenchmarkTools.Trial: 
  memory estimate:  1023.09 KiB
  allocs estimate:  49029
  --------------
  minimum time:     98.377 ms (0.00% GC)
  median time:      100.314 ms (0.00% GC)
  mean time:        101.601 ms (0.00% GC)
  maximum time:     121.993 ms (0.00% GC)
  --------------
  samples:          50
  evals/sample:     1

Tiling is not helpful in this direction: the tiled version is slower (100 ms vs. 80 ms median).

kose-y commented 4 years ago

Updated with boundary treatment.
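The boundary treatment relies on splitting each dimension into full tiles plus a remainder, and each packed row count into full bytes plus a partial byte. With a power-of-two tile size `step = 2^s`, `M >>> s` and `M & (step - 1)` equal `div(M, step)` and `rem(M, step)` for non-negative `M`. A small sketch with hypothetical helper names, using the sizes from the test below (`4n - 22` rows with `n = 2^14`):

```julia
# Hypothetical helper: split a length M into full tiles of size 2^s plus a
# leftover, using a shift and a mask (valid for non-negative M).
function tile_split(M::Integer, s::Integer)
    step = 1 << s
    Miter = M >>> s          # number of full tiles, == div(M, step)
    Mrem  = M & (step - 1)   # leftover,             == rem(M, step)
    return Miter, Mrem
end

# Hypothetical helper: full packed bytes (4 genotypes each) and the number of
# genotypes in a trailing partial byte, as in `k = rows >> 2; rem = rows & 3`.
packed_split(rows::Integer) = (rows >> 2, rows & 3)

packed_split(65514)    # (16378, 2)
tile_split(16378, 10)  # (15, 1018)
```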

Direct $Ax$: loop-unrolling and tiling

using LoopVectorization
function _snparray_ax_additive_rem!(out, s, v)
    maxp = length(out)
    @avx for j in eachindex(v)
        block = s[1, j]
        for p in 1:maxp
            Aij = (block >> (2*(p-1))) & 3
            out[p] += ((Aij >= 2) + (Aij == 3)) * v[j]
        end
    end
end
function _snparray_ax_additive!(out, s, v, rows, cols)
    # `out` is not zeroed here; callers accumulate into it across column tiles
    k = rows >> 2
    rem = rows & 3

    @avx for j ∈ 1:cols
        for l in 1:k
            block = s[l, j]

            for p in 1:4
                Aij = (block >> (2 *(p-1))) & 3
                out[4*(l-1)+p] += ((Aij >= 2) + (Aij == 3)) * v[j]
            end

        end
    end
    if rem != 0
        _snparray_ax_additive_rem!(@view(out[4k+1:end]), @view(s[k+1:k+1, :]), v)
    end
    out
end

function _snparray_ax_additive!(out, s, v)
    _snparray_ax_additive!(out, s, v, length(out), length(v))
end
using VectorizationBase, LoopVectorization

function _snparray_ax_additive_tile!(c, A, b)
    vstep = 1024
    hstep = 1024
    vstep_log2 = 10
    hstep_log2 = 10

    fill!(c, zero(eltype(c)))

    M = length(c) >> 2
    N = size(A, 2)
    Miter = M >>> vstep_log2 # fast div(M, vstep)
    Mrem = M & (vstep-1) # fast rem(M, vstep)
    Niter = N >>> hstep_log2
    Nrem = N & (hstep-1)
    GC.@preserve c A b for n in 0:Niter-1
        for m in 0:Miter-1
            _snparray_ax_additive!(
                gesp(stridedpointer(c), ((4 * vstep)*m,)),
                gesp(stridedpointer(A), (vstep*m, hstep*n)),
                gesp(stridedpointer(b), (hstep*n,)),
                vstep << 2,
                hstep
            )
        end
        if Mrem != 0
            _snparray_ax_additive!(@view(c[(4*vstep)*Miter+1:end]), 
                @view(A[vstep*(Miter)+1:end, hstep*n+1:hstep*(n+1)]),
                @view(b[hstep*n+1:hstep*(n+1)]),
                length(c) - 4*vstep*Miter,
                hstep
            )
        end
    end
    if Nrem != 0
        _snparray_ax_additive!(c, @view(A[:, (hstep*Niter+1):end]), @view(b[hstep*Niter+1:end]),
            length(c), Nrem
        )
    end
end
_snparray_ax_additive_tile! (generic function with 1 method)
using Random, Test
n = 1<<14
a = rand(UInt8, n - 5, n+1)
v1 = zeros(n << 2 - 22)
v1_ = zeros(n << 2 - 22)
v1__ = zeros(n << 2 - 22);

v2 = rand(Float64, n+1);

Let's check correctness first:

using SnpArrays
fill!(v1, 0)
SnpArrays._snparray_ax_additive!(v1, a, v2);

_snparray_ax_additive!(v1_, a, v2);
_snparray_ax_additive_tile!(v1__, a, v2);
@test isapprox(v1, v1_) 
@test isapprox(v1_, v1__)
Test Passed
using BenchmarkTools
@benchmark SnpArrays._snparray_ax_additive!(v1, a, v2)
BenchmarkTools.Trial: 
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     1.275 s (0.00% GC)
  median time:      1.278 s (0.00% GC)
  mean time:        1.278 s (0.00% GC)
  maximum time:     1.281 s (0.00% GC)
  --------------
  samples:          4
  evals/sample:     1
@benchmark _snparray_ax_additive!(v1_, a, v2)
BenchmarkTools.Trial: 
  memory estimate:  112 bytes
  allocs estimate:  2
  --------------
  minimum time:     673.050 ms (0.00% GC)
  median time:      675.058 ms (0.00% GC)
  mean time:        683.039 ms (0.00% GC)
  maximum time:     735.571 ms (0.00% GC)
  --------------
  samples:          8
  evals/sample:     1
@benchmark _snparray_ax_additive_tile!(v1__, a, v2)
BenchmarkTools.Trial: 
  memory estimate:  111.78 KiB
  allocs estimate:  5651
  --------------
  minimum time:     530.016 ms (0.00% GC)
  median time:      531.549 ms (0.00% GC)
  mean time:        540.101 ms (0.00% GC)
  maximum time:     571.763 ms (0.00% GC)
  --------------
  samples:          10
  evals/sample:     1

The $Ax$ operation benefits from both loop unrolling and tiling: 1.278 s (current SnpArrays) → 675 ms (unrolled) → 532 ms (unrolled and tiled), median times.

Direct $A^T x$: loop-unrolling and tiling

function _snparray_atx_additive_rem!(out, s, v)
    maxp = length(v)
    @avx for i in eachindex(out)
        block = s[1, i]
        for p in 1:maxp
            Aij = (block >> (2*(p-1))) & 3
            out[i] += ((Aij >= 2) + (Aij == 3)) * v[p]
        end
    end
end

function _snparray_atx_additive!(out, s, v, rows, cols)
    # `out` is not zeroed here; callers accumulate into it across row tiles
    k = rows >> 2
    rem = rows & 3

    @avx for i ∈ 1:cols
        for l in 1:k
            block = s[l, i]

            for p in 1:4
                Aij = (block >> (2 *(p-1))) & 3
                out[i] += ((Aij >= 2) + (Aij == 3)) * v[4*(l-1)+p]
            end

        end
    end
    if rem != 0
        _snparray_atx_additive_rem!(out, @view(s[k+1:k+1, :]), @view(v[4k+1:end]))
    end
    out
end

function _snparray_atx_additive!(out, s, v)
    _snparray_atx_additive!(out, s, v, length(v), length(out))
end
_snparray_atx_additive! (generic function with 2 methods)
using VectorizationBase, LoopVectorization
vstep = 1024
hstep = 1024
vstep_log2 = 10
hstep_log2 = 10

function _snparray_atx_additive_tile!(c, A, b)
    fill!(c, zero(eltype(c)))

    M = length(b) >> 2
    N = length(c)
    Miter = M >>> vstep_log2 # fast div(M, vstep)
    Mrem = M & (vstep-1) # fast rem(M, vstep)
    Niter = N >>> hstep_log2
    Nrem = N & (hstep-1)
    GC.@preserve c A b for n in 0:Niter-1
        for m in 0:Miter-1
            _snparray_atx_additive!(
                gesp(stridedpointer(c), ((hstep)*n,)),
                gesp(stridedpointer(A), (vstep*m, hstep*n)),
                gesp(stridedpointer(b), (4*vstep*m,)),
                vstep << 2,
                hstep
            )
        end
        if Mrem != 0
            _snparray_atx_additive!(@view(c[hstep*n+1:hstep*(n+1)]), 
                @view(A[vstep*(Miter)+1:end, hstep*n+1:hstep*(n+1)]),
                @view(b[(4*vstep)*Miter+1:end]),
                length(b) - 4*vstep*Miter,
                hstep
            )
        end
    end
    if Nrem != 0
        _snparray_atx_additive!(@view(c[hstep*Niter+1:end]), 
            @view(A[:, (hstep*Niter+1):end]), 
            b, length(b), Nrem
        )
    end
end
_snparray_atx_additive_tile! (generic function with 1 method)
using Random, Test
n = 1<<14
a = rand(UInt8, n-5, n+1)
v1 = rand(n << 2 - 22)

v2 = zeros(Float64, n+1);
v2_ = zeros(Float64, n+1);
v2__ = zeros(Float64, n+1);
SnpArrays._snparray_atx_additive!(v2, a, v1);
_snparray_atx_additive!(v2_, a, v1)
_snparray_atx_additive_tile!(v2__, a, v1);
@test isapprox(v2, v2_) && isapprox(v2, v2__)
Test Passed
@benchmark SnpArrays._snparray_atx_additive!(v2, a, v1)
BenchmarkTools.Trial: 
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     786.159 ms (0.00% GC)
  median time:      793.560 ms (0.00% GC)
  mean time:        801.601 ms (0.00% GC)
  maximum time:     844.316 ms (0.00% GC)
  --------------
  samples:          7
  evals/sample:     1
@benchmark _snparray_atx_additive!(v2_, a, v1)
BenchmarkTools.Trial: 
  memory estimate:  112 bytes
  allocs estimate:  2
  --------------
  minimum time:     341.200 ms (0.00% GC)
  median time:      342.957 ms (0.00% GC)
  mean time:        348.617 ms (0.00% GC)
  maximum time:     395.670 ms (0.00% GC)
  --------------
  samples:          15
  evals/sample:     1
@benchmark _snparray_atx_additive_tile!(v2__, a, v1)
BenchmarkTools.Trial: 
  memory estimate:  111.78 KiB
  allocs estimate:  5651
  --------------
  minimum time:     488.820 ms (0.00% GC)
  median time:      521.252 ms (0.00% GC)
  mean time:        521.462 ms (0.00% GC)
  maximum time:     544.272 ms (0.00% GC)
  --------------
  samples:          10
  evals/sample:     1

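This time, however, tiling does not help: the tiled version (521 ms median) is slower than the untiled unrolled kernel (343 ms), although both beat the current SnpArrays implementation (794 ms).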

BitMatrix $Ax$: tiling

using VectorizationBase: gesp, stridedpointer
function gemv_avx!(c, A, b)
    fill!(c, 0)
    @avx for j in 1:size(A, 2), i in 1:size(A, 1)
        c[i] += A[i, j] * b[j]
    end
end
function gemv_avx_sized!(c, A, b, rows, cols)
    @avx for j in 1:cols, i in 1:rows
        c[i] += A[i, j] * b[j]
    end
end

function gemv_avx_step!(c, A, b)
    @avx for j in 1:hstep, i in 1:vstep
        c[i] += A[i, j] * b[j]
    end
end
function gemv_tile!(c, A, b)
    vstep = 512
    hstep = 512
    vstep_log2 = 9
    hstep_log2 = 9
    M, N = size(A)
    Miter = M >>> vstep_log2 # fast div(M, 512)
    Mrem = M & (vstep-1) # fast rem(M, 512)
    Niter = N >>> hstep_log2
    Nrem = N & (hstep-1)
    GC.@preserve c A b for n in 0:Niter-1
        for m in 0:Miter-1
            gemv_avx_sized!(
                gesp(stridedpointer(c), (vstep*m,)),
                gesp(stridedpointer(A), (vstep*m, hstep*n)),
                gesp(stridedpointer(b), (hstep*n,)),
                vstep, hstep
            )
        end
        if Mrem != 0
            gemv_avx_sized!(
                gesp(stridedpointer(c), (vstep*Miter,)), 
                gesp(stridedpointer(A), (vstep*Miter, hstep*n)),
                gesp(stridedpointer(b), (hstep*n,)),
                Mrem, hstep
            )
        end
    end
    if Nrem != 0
        gemv_avx_sized!(
            gesp(stridedpointer(c), (0,)),
            gesp(stridedpointer(A), (0, hstep*Niter)),
            gesp(stridedpointer(b), (hstep*Niter,)),
            length(c),
            Nrem
        )
    end
end
gemv_tile! (generic function with 1 method)
n = 1 << 14
A = bitrand(4n-22, n+1);
b = rand(n+1); c1 = zeros(4n-22); c2 = zeros(4n-22);
gemv_tile!(c1, A, b);
gemv_avx!(c2, A, b);
@test c1 ≈ c2
Test Passed
@benchmark gemv_avx!($c2, $A, $b)
BenchmarkTools.Trial: 
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     350.635 ms (0.00% GC)
  median time:      352.818 ms (0.00% GC)
  mean time:        355.604 ms (0.00% GC)
  maximum time:     394.634 ms (0.00% GC)
  --------------
  samples:          15
  evals/sample:     1
@benchmark gemv_tile!($c1, $A, $b)
BenchmarkTools.Trial: 
  memory estimate:  1.99 MiB
  allocs estimate:  97841
  --------------
  minimum time:     290.880 ms (0.00% GC)
  median time:      293.888 ms (0.00% GC)
  mean time:        294.088 ms (0.00% GC)
  maximum time:     298.778 ms (0.00% GC)
  --------------
  samples:          18
  evals/sample:     1

As before, tiling accelerates BitMatrix gemv (353 ms → 294 ms median), though the gain is smaller than in the well-aligned case above, where the number of rows is a multiple of 8.

BitMatrix $A^Tx$: tiling

using VectorizationBase: gesp, stridedpointer
using LoopVectorization
function gemv_t_avx!(c, A, b)
    fill!(c, 0)
    @avx for j in 1:size(A, 1), i in 1:size(A, 2)
        c[i] += A[j, i] * b[j]
    end
end

function gemv_t_avx_sized!(c, A, b, rows, cols)
    @avx for j in 1:rows, i in 1:cols
        c[i] += A[j, i] * b[j]
    end
end
vstep = 512
hstep = 512
vstep_log2 = 9
hstep_log2 = 9
function gemv_t_avx_step!(c, A, b)
    @avx for j in 1:vstep, i in 1:hstep  # j: tile rows, i: tile columns
        c[i] += A[j, i] * b[j]
    end
end
function gemv_t_tile!(c, A, b)
    fill!(c, 0)
    M, N = size(A)
    Miter = M >>> vstep_log2 # fast div(M, 512)
    Mrem = M & (vstep-1) # fast rem(M, 512)
    Niter = N >>> hstep_log2
    Nrem = N & (hstep-1)
    GC.@preserve c A b for n in 0:Niter-1
        for m in 0:Miter-1
            gemv_t_avx_step!(
                gesp(stridedpointer(c), (hstep*n,)),
                gesp(stridedpointer(A), (vstep*m, hstep*n)),
                gesp(stridedpointer(b), (vstep*m,))
            )
        end
        if Mrem != 0
            gemv_t_avx_sized!(
                gesp(stridedpointer(c), (hstep*n,)), 
                gesp(stridedpointer(A), (vstep*Miter, hstep*n)),
                gesp(stridedpointer(b), (vstep*Miter,)),
                Mrem, hstep
            )
        end
    end
    if Nrem != 0
        gemv_t_avx_sized!(
            gesp(stridedpointer(c), (hstep*Niter,)),
            gesp(stridedpointer(A), (0, hstep*Niter)),
            gesp(stridedpointer(b), (0,)),
            length(b),
            Nrem
        )
    end
end
gemv_t_tile! (generic function with 1 method)
using Random
n = 1 << 14
A = bitrand(4n-22, n+1);
b1 = rand(n+1); b2 = rand(n+1); c1 = rand(4n-22); c2 = rand(4n-22);

gemv_t_avx!(b1, A, c1);
gemv_t_tile!(b2, A, c1);
@test b1 ≈ b2
Test Passed
using BenchmarkTools
@benchmark gemv_t_avx!(b2, A, c1)
BenchmarkTools.Trial: 
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     160.416 ms (0.00% GC)
  median time:      161.961 ms (0.00% GC)
  mean time:        163.193 ms (0.00% GC)
  maximum time:     173.785 ms (0.00% GC)
  --------------
  samples:          31
  evals/sample:     1
@benchmark gemv_t_tile!(b1, A, c1)
BenchmarkTools.Trial: 
  memory estimate:  1.99 MiB
  allocs estimate:  97841
  --------------
  minimum time:     186.909 ms (0.00% GC)
  median time:      188.484 ms (0.00% GC)
  mean time:        190.167 ms (0.00% GC)
  maximum time:     211.447 ms (0.00% GC)
  --------------
  samples:          27
  evals/sample:     1

Tiling is not helpful for this direction either: the tiled version is slower (188 ms vs. 162 ms median).