JuliaSIMD / LoopVectorization.jl

Macro(s) for vectorizing loops.
MIT License
745 stars 67 forks source link

Slow performance on one thread, deadlock on more; used to work in an earlier version. #337

Closed Chiil closed 3 years ago

Chiil commented 3 years ago

I reran an old benchmark that used to work with LoopVectorization, but it now gives very slow performance on one thread and deadlocks on four threads. I filed multiple bugs about earlier versions of this before, and at one point everything worked, but now I am back at a non-working version of my benchmark. This is the code. I also have a slightly modified version in which I put the rational constants in global constants, and that version works fine.

## Packages
using BenchmarkTools
using LoopVectorization
using Tullio

## Macros
"""
    make_index(a, arrays, i, j, k)

Return the indexing expression for symbol `a` at half-cell offsets `(i, j, k)`.

If `a` is not listed in `arrays` it is returned unchanged. Otherwise the staggered
variables are shifted by `+0.5` along their own axis: `u`/`ut` along i, `v`/`vt`
along j, `w`/`wt` along k. Each resulting offset must then be a whole number. It
becomes `i`, `i + n` or `i - n` (likewise for `j`, `k`) inside `a[...]`.
A non-integral offset raises an `InexactError` from `convert`.
"""
function make_index(a, arrays, i, j, k)
    a in arrays || return a

    # Half-cell stagger of `a` along one axis: 0.5 if `a` lives on that face, else 0.
    stagger(faces) = a in faces ? 0.5 : 0

    # Build the subscript for loop variable `name` displaced by the whole number `d`.
    function subscript(name, d)
        d < 0 && return :($name - $(convert(Int, abs(d))))
        d > 0 && return :($name + $(convert(Int, d)))
        return name
    end

    di = i + stagger((:u, :ut))
    dj = j + stagger((:v, :vt))
    dk = k + stagger((:w, :wt))

    return :($a[$(subscript(:i, di)), $(subscript(:j, dj)), $(subscript(:k, dk))])
end

# Map a user-facing gradient name to (internal stencil marker, inverse-spacing
# variable). Returns `nothing` for any other symbol.
function _grad_marker(op::Symbol)
    op === :gradx && return (:gradx_, :dxi)
    op === :grady && return (:grady_, :dyi)
    op === :gradz && return (:gradz_, :dzi)
    return nothing
end

# 4-point staggered first-derivative weights (fourth-order), as unevaluated
# `Rational` literals so the generated code stays element-type generic.
_grad_coefs() = (:(1//24), :(-27//24), :(27//24), :(-1//24))

# 4-point staggered interpolation weights (fourth-order), as unevaluated `Rational`
# literals.
_interp_coefs() = (:(-1//16), :(9//16), :(9//16), :(-1//16))

# Look up a stencil marker. Returns `(axis, weights)` with axis 1/2/3 meaning i/j/k,
# or `nothing` if `op` is not a stencil operator.
function _stencil_spec(op::Symbol)
    op === :gradx_ && return (1, _grad_coefs())
    op === :grady_ && return (2, _grad_coefs())
    op === :gradz_ && return (3, _grad_coefs())
    op === :interpx && return (1, _interp_coefs())
    op === :interpy && return (2, _interp_coefs())
    op === :interpz && return (3, _interp_coefs())
    return nothing
end

# Shift the (i, j, k) offsets by `d` along `axis` (1 = i, 2 = j, 3 = k).
_shift_point(i, j, k, axis, d) =
    axis == 1 ? (i + d, j, k) : axis == 2 ? (i, j + d, k) : (i, j, k + d)

# Replace the one-argument stencil call occupying `args[n:n+1]` with
# `+(w1*p1, w2*p2, w3*p3, w4*p4)`. Here `p_m` is the operand re-indexed at offset
# -1.5, -0.5, +0.5, +1.5 along `axis`. Mutates and returns `args`.
function _expand_stencil!(args, n, arrays, i, j, k, axis, weights)
    op = args[n]
    # The expansion rewrites args[n:end]; any extra argument would be silently mangled.
    n + 1 == length(args) || throw(ArgumentError(
        "stencil operator `$op` must be called with exactly one argument"))
    operand = args[n + 1]
    if operand isa Expr
        # `process_expr` rewrites its argument in place, so each stencil point gets
        # its own copy, all taken before any of them is processed.
        points = (copy(operand), operand, copy(operand), copy(operand))
        reindex = process_expr
    elseif operand isa Symbol
        points = (operand, operand, operand, operand)
        reindex = make_index
    else
        throw(ArgumentError(
            "stencil operator `$op` expects a symbol or an expression, got $(repr(operand))"))
    end
    offsets = (-1.5, -0.5, 0.5, 1.5)
    terms = map(points, offsets, weights) do point, d, weight
        ii, jj, kk = _shift_point(i, j, k, axis, d)
        :($weight * $(reindex(point, arrays, ii, jj, kk)))
    end
    splice!(args, n:n+1, (:+, terms...))
    return args
end

"""
    process_expr(ex, arrays, i, j, k)

Recursively expand the finite-difference operators in `ex` into explicit indexing.

- `gradx(x)`, `grady(x)` and `gradz(x)` become a 4-point derivative stencil of `x`
  multiplied by `dxi`, `dyi` or `dzi`.
- `interpx(x)`, `interpy(x)` and `interpz(x)` become a 4-point interpolation stencil.
- Every other symbol is passed to `make_index`, which turns the names listed in
  `arrays` into `a[i ± n, j ± n, k ± n]` at the current half-cell offsets
  `(i, j, k)`.

`ex` is rewritten in place. The return value is the rewritten expression. It is a
new outer `*` node when `ex` itself is a gradient call.

Throws `ArgumentError` if a stencil operator is not called with exactly one symbol
or expression argument.
"""
function process_expr(ex, arrays, i, j, k)
    # `gradx(x)` becomes `gradx_(x) * dxi` (likewise y/z). The trailing-underscore
    # marker is what triggers the stencil expansion in the loop below.
    if !isempty(ex.args) && ex.args[1] isa Symbol
        marker = _grad_marker(ex.args[1])
        if marker !== nothing
            ex.args[1] = marker[1]
            ex = :($ex * $(marker[2]))
        end
    end

    args = ex.args
    n = 1
    while n <= length(args)
        arg = args[n]
        if arg isa Expr
            args[n] = process_expr(arg, arrays, i, j, k)
            n += 1
        elseif arg isa Symbol
            spec = _stencil_spec(arg)
            if spec === nothing
                args[n] = make_index(arg, arrays, i, j, k)
                n += 1
            else
                _expand_stencil!(args, n, arrays, i, j, k, spec...)
                # Skip the `+` and the four weighted terms that were just inserted.
                n += 5
            end
        else
            n += 1
        end
    end
    return ex
end

"""
    @fd arrays expr

Expand the finite-difference shorthand in `expr` into explicit array indexing.

`arrays` is a single symbol or a tuple of symbols naming the arrays to subscript.
All other symbols are left as they are. The first argument of `expr` (the assigned
variable) sets the starting half-cell offset: `u`/`ut` start at -0.5 along i,
`v`/`vt` along j, and `w`/`wt` along k.

The generated expression is printed at expansion time. It is returned escaped, so
`i`, `j`, `k`, `dxi`, `dyi`, `dzi` and the arrays resolve in the caller's scope.
"""
macro fd(arrays, ex)
    target = ex.args[1]
    i = target in (:u, :ut) ? -0.5 : 0
    j = target in (:v, :vt) ? -0.5 : 0
    k = target in (:w, :wt) ? -0.5 : 0

    # A lone symbol is treated as a one-element list of array names.
    names = arrays isa Symbol ? Any[arrays] : arrays.args
    expanded = process_expr(ex, names, i, j, k)

    println("Generated stencil: ")
    println(expanded)
    println("")

    return esc(expanded)
end

## Advection, diffusion, time kernel.
"""
    kernel!(ut, u, v, w, visc, dxi, dyi, dzi, dt, is, ie, js, je, ks, ke)

Advance `u` by one explicit time step over the index box `is:ie`, `js:je`, `ks:ke`.

The first `@tturbo` loop nest adds the following to `ut`, using the `@fd` stencil
macro (gradients are scaled by `dxi`, `dyi`, `dzi`):

- advection terms: `-gradx(interpx(u)*interpx(u))`, `-grady(interpx(v)*interpy(u))`
  and `-gradz(interpx(w)*interpz(u))`;
- diffusion terms: `visc` times `gradx(gradx(u))`, `grady(grady(u))` and
  `gradz(gradz(u))`.

The second loop nest applies `u += dt*ut` and then sets `ut` to zero.

`u` and `ut` are modified in place. `v` and `w` are only read.

NOTE(review): by my reading of `@fd`, the nested 4-point stencils (e.g.
`gradx(gradx(u))`) index up to 3 cells outside the loop range. The arrays
presumably need at least 3 ghost cells per side; please confirm.
"""
function kernel!(
        ut, u, v, w,
        visc, dxi, dyi, dzi, dt,
        is, ie, js, je, ks, ke)

    # Accumulate the x-, y- and z-direction tendencies of u into ut.
    # (Kept as three separate statements: see the thread above for why merging
    # them into one @fd expression caused very long compile times.)
    @tturbo for k in ks:ke
        for j in js:je
            for i in is:ie
                @fd (ut, u, v, w) ut += (
                    - gradx(interpx(u) * interpx(u)) + visc * (gradx(gradx(u))) )
                @fd (ut, u, v, w) ut += (
                    - grady(interpx(v) * interpy(u)) + visc * (grady(grady(u))) )
                @fd (ut, u, v, w) ut += (
                    - gradz(interpx(w) * interpz(u)) + visc * (gradz(gradz(u))) )
            end
        end
    end

    # Forward-Euler update of u, then clear the tendency for the next step.
    @tturbo for k in ks:ke
        for j in js:je
            for i in is:ie
                @fd (ut, u) u += dt*ut
                @fd (ut, u) ut = 0
            end
        end
    end
end

## Initialize the grid.
itot = 512; jtot = 512; ktot = 512
igc = 4; jgc = 4; kgc = 4

dx = 1/itot; dy = 1/jtot; dz = 1/ktot
dxi = 1/dx; dyi = 1/dy; dzi = 1/dz
x = dx*collect(0:itot-1)
y = dy*collect(0:jtot-1)
z = dz*collect(0:ktot-1)

## Solve the problem in double precision.
visc = 1.5
dt = 1.e-3

# Size each axis with its own ghost-cell count on both sides.
field_size = (itot+2*igc, jtot+2*jgc, ktot+2*kgc)
u = zeros(Float64, field_size)
v = zeros(Float64, field_size)
w = zeros(Float64, field_size)
ut = zeros(Float64, field_size)

# Interior (non-ghost) index range, shared by the initialization and the kernel.
is = igc+1; ie = igc+itot; js = jgc+1; je = jgc+jtot; ks = kgc+1; ke = kgc+ktot

## Initialize with a sine wave.
n_waves = 3
uc = @view u[is:ie, js:je, ks:ke]
# Index the left-hand side so Tullio writes element-wise into the view (and so into
# `u`). A bare `uc = ...` has no indices and cannot be written element-wise.
@tullio uc[i, j, k] = sin(n_waves*2*pi*x[i]) + cos(n_waves*2*pi*y[j]) + sin(n_waves*2*pi*z[k])

## Run kernel.
@btime kernel!(
        $ut, $u, $v, $w,
        $visc, $dxi, $dyi, $dzi, $dt,
        $is, $ie, $js, $je, $ks, $ke)
chriselrod commented 3 years ago

Bad performance: because of the rationals, the code hits a bunch of non-inlined fallback methods. I will basically have to redefine all of these fallback methods and mark them `@inline`...

Julia's inliner thinks llvmcall is very expensive, so it's unlikely to choose to inline a function using it. This forces basically all code using llvmcall or calling functions using llvmcall to manually add @inline if the method is supposed to be inlined.

Deadlock: Not sure yet, but things seem to be getting corrupted.

chriselrod commented 3 years ago

> Deadlock: Not sure yet, but things seem to be getting corrupted.

Memory is getting corrupted because it blew ThreadingUtilities' buffers (corrupting its state). These buffers were blown because LoopVectorization tried to pass 170 `Rational{Int64}`s as function arguments.

`Float64`s wouldn't have been passed as arguments at all (they would have been inserted directly), and they also wouldn't have caused the non-inlining performance problem.

So, I'm guessing that you weren't using Rationals before, and that both the deadlock and the bad performance started when you began using Rationals?

A fix in LoopVectorization would be to check for them when handling constant literals, and automatically convert.

Chiil commented 3 years ago

I switched to Rational for the reason explained in an earlier issue I reported: https://github.com/JuliaSIMD/LoopVectorization.jl/issues/320

I have code that can run in both Float32 and Float64 and by using rationals I could (according to the Julia manual) avoid having to cast the literal constant from Float64 to Float32. I followed the advice at the bottom of this page: https://docs.julialang.org/en/v1/manual/style-guide/

chriselrod commented 3 years ago

Once I release LoopVectorization 0.12.75 (which should be in the next few hours), the examples of using rationals here will be fine. LoopVectorization should be able to convert Float64 to Float32 if it needs to, so that using Float64 won't cause a problem.

julia> using VectorizationBase

julia> vx32 = Vec(ntuple(_ -> randn(Float32), pick_vector_width(Float32))...)
Vec{16, Float32}<-1.2874495f0, -0.5067953f0, 0.31757203f0, 1.067527f0, 0.3329838f0, 0.28481305f0, 0.49678314f0, -0.010635474f0, 0.5358165f0, 0.70434624f0, -1.0242392f0, -2.159845f0, 0.041450497f0, 1.5327083f0, -0.14729454f0, 0.3033066f0>

julia> vx32 * 34. # scalar `Float64` gets demoted to `Float32` when used with vector of `Float32`s.
Vec{16, Float32}<-43.77328f0, -17.23104f0, 10.797449f0, 36.29592f0, 11.321449f0, 9.683643f0, 16.890627f0, -0.36160612f0, 18.21776f0, 23.947773f0, -34.82413f0, -73.43473f0, 1.4093169f0, 52.112083f0, -5.008014f0, 10.312425f0>

julia> -1.287 * 34. 
-43.757999999999996
Chiil commented 3 years ago

Merging the lines above

    @tturbo for k in ks:ke
        for j in js:je
            for i in is:ie
                @fd (ut, u, v, w) ut += (
                    - gradx(interpx(u) * interpx(u)) + visc * (gradx(gradx(u))) )
                @fd (ut, u, v, w) ut += (
                    - grady(interpx(v) * interpy(u)) + visc * (grady(grady(u))) )
                @fd (ut, u, v, w) ut += (
                    - gradz(interpx(w) * interpz(u)) + visc * (gradz(gradz(u))) )
            end
        end
    end

into

    @tturbo for k in ks:ke
        for j in js:je
            for i in is:ie
                @fd (ut, u, v, w) ut += (
                    - gradx(interpx(u) * interpx(u)) + visc * (gradx(gradx(u)))
                    - grady(interpx(v) * interpy(u)) + visc * (grady(grady(u)))
                    - gradz(interpx(w) * interpz(u)) + visc * (gradz(gradz(u))) )
            end
        end
    end

Results in a deadlock on 4 threads again in the most recent version.

chriselrod commented 3 years ago

Are you sure it's deadlocking? I tried and it worked, BUT it did take obscenely long to compile:

julia> @time kernel!(
         ut, u, v, w,
         visc, dxi, dyi, dzi, dt,
         is, ie, js, je, ks, ke)
151.048107 seconds (172.85 M allocations: 7.728 GiB, 11.76% gc time, 99.73% compilation time)
chriselrod commented 3 years ago

The reason is that LV decided to do something wonky when optimizing the second (merged) version compared to the first. EDIT: It was doing something wonky because of a bug.