Jutho / TensorOperations.jl

Julia package for tensor contractions and related operations
https://jutho.github.io/TensorOperations.jl/stable/

Enable multithreads when doing the permutedims in the TTGT algorithms #145

Closed jemiryguo closed 1 year ago

jemiryguo commented 1 year ago

TensorOperations uses the transpose-transpose-GEMM-transpose (TTGT) algorithm to compute tensor contractions. When the number of BLAS threads is set to a large value (e.g. 48), most of the contraction time is spent on the transposes, because the current implementation does not use multithreading when permuting dimensions. After multithreading the transposes in TTGT with the @strided macro from Strided.jl, another package by @Jutho, the total time drops considerably (from 4.124 s to 1.481 s in the benchmark below).

# `julia` is launched with multiple threads: `julia -t 48`.
using LinearAlgebra, MKL, Strided, BenchmarkTools, TensorOperations

function f()
    BLAS.set_num_threads(48)
    Strided.set_num_threads(48)
    a = rand(100, 1000, 1000)
    b = rand(1000, 1000, 10)
    c = rand(10, 100, 1000, 1000)
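    # Baseline: contraction through TensorOperations' @tensor (NCON-style index labels).
    @btime @tensor $c[:] = $a[-2, 1, -4] * $b[-3, 1, -1]
    # Manual TTGT with the transposes multithreaded through @strided.
    @btime begin
        # Move the contracted index to the last position of a and the first position of b.
        aa = similar($a)
        @strided aa .= permutedims($a, (1, 3, 2))
        bb = similar($b)
        @strided bb .= permutedims($b, (2, 1, 3))
        # Contract with a single GEMM on the matricized tensors.
        cc = Array{Float64}(undef, 100, 1000, 1000, 10)
        BLAS.gemm!('N', 'N', 1.0, reshape(aa, 100000, 1000), reshape(bb, 1000, 10000), 0.0, reshape(cc, 100000, 10000))
        # Permute the result into the index order of c.
        @strided $c .= permutedims(cc, (4, 1, 3, 2))
    end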
    nothing
end
f()

Output (first line: `@tensor`; second line: manual TTGT with `@strided`):

4.124 s (7 allocations: 8.27 GiB)
1.481 s (2643 allocations: 8.27 GiB)

I hope this feature, or at least a switch to turn it on, can be added.
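
For anyone who wants to check the index bookkeeping of the manual TTGT path without a large machine, here is a minimal sketch at toy sizes. It uses only Base `permutedims` and `mul!` from the LinearAlgebra standard library, so it involves no multithreaded transposes and no Strided.jl. The small dimensions and the loop-based reference result are assumptions made for illustration and are not part of the benchmark above.

```julia
using LinearAlgebra  # standard library; provides the BLAS-backed `mul!`

# Toy sizes (assumed), same index layout as the benchmark:
# a[i, k, j] and b[m, k, n], with k the contracted index.
a = rand(3, 4, 5)
b = rand(6, 4, 2)

# Reference result from explicit loops: c[n, i, m, j] = sum_k a[i, k, j] * b[m, k, n]
c_ref = zeros(2, 3, 6, 5)
for n in 1:2, i in 1:3, m in 1:6, j in 1:5, k in 1:4
    c_ref[n, i, m, j] += a[i, k, j] * b[m, k, n]
end

# Transpose: contracted index goes last in `a` and first in `b`.
aa = permutedims(a, (1, 3, 2))   # (i, j, k)
bb = permutedims(b, (2, 1, 3))   # (k, m, n)

# GEMM on the matricized tensors; `reshape` of an Array shares memory with `cc`.
cc = Array{Float64}(undef, 3, 5, 6, 2)   # (i, j, m, n)
mul!(reshape(cc, 15, 12), reshape(aa, 15, 4), reshape(bb, 4, 12))

# Final transpose into the output order (n, i, m, j).
c = permutedims(cc, (4, 1, 3, 2))

@assert c ≈ c_ref
```

The three `permutedims` calls here are the steps that the benchmark above runs through `@strided`; at realistic sizes they move the full arrays through memory, which is why they can dominate once the GEMM itself is well threaded.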

jemiryguo commented 1 year ago

I found that this issue has already been fixed in the latest version of the code, but the latest version in Julia's official registry is still 4.0.2. Perhaps a new release could be tagged to make this feature available. Thanks!

jemiryguo commented 1 year ago

I will close this issue to avoid bothering anyone.