Jutho / Strided.jl

A Julia package for strided array views and efficient manipulations thereof

Correct way to parallelize this code? #9

Status: open. cgarciae opened this issue 4 years ago.

cgarciae commented 4 years ago

Original function:

function distances_jl(data1, data2)
    data1 = deg2rad.(data1)
    data2 = deg2rad.(data2)
    lat1 = @view data1[:, 1]
    lng1 = @view data1[:, 2]
    lat2 = @view data2[:, 1]
    lng2 = @view data2[:, 2]
    diff_lat = @view(lat1[:, None]) .- @view(lat2[None, :])
    diff_lng = @view(lng1[:, None]) .- @view(lng2[None, :])
    data = @. (
        sin(diff_lat / 2)^2 +
        cos(@view(lat1[:, None])) * cos(lat2) * sin(diff_lng / 2)^2
    )
    data .= @. 2.0 * 6373.0 * atan(sqrt(abs(data)), sqrt(abs(1.0 - data)))

    return reshape(data, (size(data1, 1), size(data2, 1)))
end

Version 1, runs on a single core:

function distances_strided(data1, data2)
    data1 = @strided deg2rad.(data1)
    data2 = @strided deg2rad.(data2)
    lat1 = @view data1[:, 1]
    lng1 = @view data1[:, 2]
    lat2 = @view data2[:, 1]
    lng2 = @view data2[:, 2]
    diff_lat = @strided(@view(lat1[:, None]) .- @view(lat2[None, :]))
    diff_lng = @strided(@view(lng1[:, None]) .- @view(lng2[None, :]))
    data = @strided(@. (
        sin(diff_lat / 2)^2 +
        cos(@view(lat1[:, None])) * cos(lat2) * sin(diff_lng / 2)^2
    ))
    @strided data .= @. 2.0 * 6373.0 * atan(sqrt(abs(data)), sqrt(abs(1.0 - data)))

    return reshape(data, (size(data1, 1), size(data2, 1)))
end

Version 2, uses all the cores but somehow runs 10x slower:

function distances_strided(data1, data2)
    data1 = @strided deg2rad.(data1)
    data2 = @strided deg2rad.(data2)
    lat1 = @strided data1[:, 1]
    lng1 = @strided data1[:, 2]
    lat2 = @strided data2[:, 1]
    lng2 = @strided data2[:, 2]
    diff_lat = @strided( lat1[:, None] .-  lat2[None, :])
    diff_lng = @strided( lng1[:, None] .-  lng2[None, :])
    data = @strided(@. (
        sin(diff_lat / 2)^2 +
        cos(lat1[:, None]) * cos(lat2) * sin(diff_lng / 2)^2
    ))
    @strided data .= @. 2.0 * 6373.0 * atan(sqrt(abs(data)), sqrt(abs(1.0 - data)))

    return reshape(data, (size(data1, 1), size(data2, 1)))
end

Jutho commented 4 years ago

What data1 and data2 do you have in mind? I.e. could you just post a runnable example, using some random generated data, so that I can test what is going on? What is None?
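Editor's note: None is not defined in Julia. x[:, None] is the NumPy newaxis idiom. In Julia, the same outer difference comes from broadcasting a column vector against a 1×n row. Below is a minimal, Base-only runnable sketch of the original function with random coordinates. It assumes each row of the input is (latitude, longitude) in degrees. The name distances_base is hypothetical, and the sketch does not use Strided.jl.

```julia
using Random

# Hypothetical Base-only version of distances_jl: pairwise haversine distances (km).
# Assumes each row of data1/data2 is (latitude, longitude) in degrees.
function distances_base(data1::AbstractMatrix, data2::AbstractMatrix)
    d1 = deg2rad.(data1)
    d2 = deg2rad.(data2)
    lat1 = d1[:, 1]                # length-n1 column, plays the role of lat1[:, None]
    lng1 = d1[:, 2]
    lat2 = permutedims(d2[:, 1])   # 1×n2 row, plays the role of lat2[None, :]
    lng2 = permutedims(d2[:, 2])
    # Broadcasting a column against a row yields an n1×n2 matrix
    a = @. sin((lat1 - lat2) / 2)^2 + cos(lat1) * cos(lat2) * sin((lng1 - lng2) / 2)^2
    return @. 2.0 * 6373.0 * atan(sqrt(abs(a)), sqrt(abs(1.0 - a)))
end

Random.seed!(1234)
n1, n2 = 1000, 500
data1 = hcat(rand(n1) .* 180 .- 90, rand(n1) .* 360 .- 180)  # random (lat, lng) pairs
data2 = hcat(rand(n2) .* 180 .- 90, rand(n2) .* 360 .- 180)
D = distances_base(data1, data2)   # n1×n2 matrix of distances
```

The broadcast result is already an n1×n2 matrix, so the final reshape in the original function is not needed in this version.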

mcabbott commented 4 years ago

xref this thread, btw: https://discourse.julialang.org/t/improving-an-algorithm-that-compute-gps-distances/38213/38

johnomotani commented 3 years ago

Hi, I ran into what I guess is a similar issue. I've started looking at threading and hoped to keep things as simple as possible. From the comment here https://github.com/JuliaLang/julia/issues/19777#issuecomment-492457370 I thought Strided.jl might be a good solution. I made a test case to compare serial, @strided, @unsafe_strided and threaded-by-hand implementations (BTW, I don't think Strided.enable_threads() was needed or made any difference; I just added it at some point to see if it would help):

using Base.Threads: @threads
using BenchmarkTools: @benchmark
using Random
using Strided

Strided.enable_threads()

function dostuff_broadcast(x)
    return @. sin(x) + cos(x) * exp(x) - exp(x^2) * sin(2*x) + tan(3*x)
end

function dostuff_strided(x)
    return @strided @. sin(x) + cos(x) * exp(x) - exp(x^2) * sin(2*x) + tan(3*x)
end

function dostuff_unsafe_strided(x)
    return @unsafe_strided @. sin(x) + cos(x) * exp(x) - exp(x^2) * sin(2*x) + tan(3*x)
end

function dostuff_threaded(x)
    result = similar(x)
    @threads for i ∈ eachindex(x)
        result[i] = sin(x[i]) + cos(x[i]) * exp(x[i]) - exp(x[i]^2) * sin(2*x[i]) + tan(3*x[i])
    end
    return result
end

n = 1000000

Random.seed!(1234)
const x = rand(Float64, n)

expected = dostuff_broadcast(x)

println("serial")
println("------")

result = @benchmark dostuff_broadcast($x)

@assert dostuff_broadcast(x) == expected

display(result)
println("")
println("")
println("")

println("strided broadcast")
println("-----------------")

result = @benchmark dostuff_strided($x)

@assert dostuff_strided(x) == expected

display(result)
println("")
println("")
println("")

println("unsafe_strided broadcast")
println("------------------------")

result = @benchmark dostuff_unsafe_strided($x)

@assert dostuff_unsafe_strided(x) == expected

display(result)
println("")
println("")
println("")

println("threaded by hand")
println("----------------")

result = @benchmark dostuff_threaded($x)

@assert dostuff_threaded(x) == expected

display(result)
println("")
println("")
println("")

The result is that the @strided version seems to use all the threads, but on 10 threads it is 2x slower than the serial version, while @unsafe_strided seems to be identical to the serial version. I'd hoped that @strided would let the compiler expand the broadcasting into something like dostuff_threaded, which is 10x faster on 10 threads. Am I doing something wrong? Explicitly, the output I get is:

$ julia -O 3 --check-bounds=no -t 10 broadcast_test.jl 
serial
------
BenchmarkTools.Trial: 
  memory estimate:  7.63 MiB
  allocs estimate:  2
  --------------
  minimum time:     61.217 ms (0.00% GC)
  median time:      64.140 ms (0.00% GC)
  mean time:        69.287 ms (0.25% GC)
  maximum time:     121.405 ms (0.00% GC)
  --------------
  samples:          73
  evals/sample:     1

strided broadcast
-----------------
BenchmarkTools.Trial: 
  memory estimate:  312.82 MiB
  allocs estimate:  13000186
  --------------
  minimum time:     124.851 ms (0.00% GC)
  median time:      223.290 ms (59.69% GC)
  mean time:        224.429 ms (58.18% GC)
  maximum time:     265.213 ms (66.69% GC)
  --------------
  samples:          23
  evals/sample:     1

unsafe_strided broadcast
------------------------
BenchmarkTools.Trial: 
  memory estimate:  7.63 MiB
  allocs estimate:  2
  --------------
  minimum time:     61.101 ms (0.00% GC)
  median time:      61.872 ms (0.00% GC)
  mean time:        63.300 ms (0.36% GC)
  maximum time:     109.543 ms (0.00% GC)
  --------------
  samples:          80
  evals/sample:     1

threaded by hand
----------------
BenchmarkTools.Trial: 
  memory estimate:  7.64 MiB
  allocs estimate:  53
  --------------
  minimum time:     6.355 ms (0.00% GC)
  median time:      6.548 ms (0.00% GC)
  mean time:        6.833 ms (1.64% GC)
  maximum time:     28.334 ms (0.00% GC)
  --------------
  samples:          732
  evals/sample:     1

Edit: I'm using Strided v1.1.1 and Julia 1.5.3.

Jutho commented 3 years ago

I can reproduce your timings and observe that something is going wrong with threading. When I print threadid() inside the code that actually does the work, it shows 1 in most cases, which should not be happening. I switched from using Base.Threads.@threads to Base.Threads.@spawn when tagging Strided v1.0, but maybe I should reconsider this. Going through threadingconstructs.jl in Base, it seems like @threads is manually pinning tasks to certain threads.
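Editor's note: the threadid() check described above can be reproduced without Strided.jl. The minimal sketch below, which uses only Base, splits an index range into chunks and runs each chunk in a task created with Threads.@spawn. It records which thread executed each chunk. The helper name which_threads is hypothetical. When Julia is started with -t 1, every recorded id will be 1.

```julia
using Base.Threads: @spawn, threadid, nthreads

# Hypothetical diagnostic: split 1:n into chunks, run each chunk in a task
# created with @spawn, and record the id of the thread that ran it.
function which_threads(n::Integer, nchunks::Integer)
    chunks = collect(Iterators.partition(1:n, cld(n, nchunks)))
    ids = zeros(Int, length(chunks))
    acc = zeros(length(chunks))     # per-chunk results, so the work is not optimized away
    @sync for (k, r) in enumerate(chunks)
        @spawn begin
            s = 0.0
            for i in r
                s += sin(i)         # dummy per-element work
            end
            acc[k] = s
            ids[k] = threadid()     # thread that finished this chunk
        end
    end
    return ids, acc
end

ids, _ = which_threads(1_000_000, 4 * nthreads())
println("tasks per thread id: ", [count(==(t), ids) for t in sort(unique(ids))])
```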

Jutho commented 3 years ago

Although, digging a bit deeper, I don't think that's really the issue. Given the large number of allocations, there seems to be something going on with type inference. Note that I have my own mechanism for implementing broadcasting, which is different from the one in Base, and apparently, the way I do it, it is not inferable for such a complicated right-hand side. This is not because it cannot deal with complicated functions, but because it retains the whole expression, and the expression itself is too complicated. So if I change your code to

f(x) = sin(x) + cos(x) * exp(x) - exp(x^2) * sin(2*x) + tan(3*x)

function dostuff_broadcast(x)
    return @. f(x)
end

function dostuff_strided(x)
    return @strided @. f(x)
end

function dostuff_unsafe_strided(x)
    return @unsafe_strided x @. f(x)
end

function dostuff_threaded(x)
    result = similar(x)
    @threads for i ∈ eachindex(x)
        result[i] = f(x[i])
    end
    return result
end

I obtain

serial
------
BenchmarkTools.Trial: 
  memory estimate:  7.63 MiB
  allocs estimate:  2
  --------------
  minimum time:     55.495 ms (0.00% GC)
  median time:      55.876 ms (0.00% GC)
  mean time:        56.341 ms (0.81% GC)
  maximum time:     59.006 ms (5.37% GC)
  --------------
  samples:          89
  evals/sample:     1

strided broadcast
-----------------
BenchmarkTools.Trial: 
  memory estimate:  7.63 MiB
  allocs estimate:  101
  --------------
  minimum time:     14.422 ms (0.00% GC)
  median time:      15.398 ms (0.00% GC)
  mean time:        16.238 ms (3.99% GC)
  maximum time:     24.254 ms (24.87% GC)
  --------------
  samples:          308
  evals/sample:     1

unsafe_strided broadcast
------------------------
BenchmarkTools.Trial: 
  memory estimate:  7.63 MiB
  allocs estimate:  101
  --------------
  minimum time:     14.804 ms (0.00% GC)
  median time:      16.082 ms (0.00% GC)
  mean time:        16.666 ms (3.73% GC)
  maximum time:     21.236 ms (18.73% GC)
  --------------
  samples:          300
  evals/sample:     1

threaded by hand
----------------
BenchmarkTools.Trial: 
  memory estimate:  7.63 MiB
  allocs estimate:  23
  --------------
  minimum time:     15.041 ms (0.00% GC)
  median time:      16.132 ms (0.00% GC)
  mean time:        16.788 ms (2.54% GC)
  maximum time:     26.156 ms (0.00% GC)
  --------------
  samples:          298
  evals/sample:     1

As a side remark, note that with @unsafe_strided, you have to specify which variables are arrays. Without this, you are basically running the same as just Base broadcast, without any effect of Strided.jl. Also, as a consequence, the result is itself a StridedView. The better strategy is

function dostuff_unsafe_strided(x)
    y = similar(x)
    @unsafe_strided y x @. y = f(x)
    return y
end

to get a normal Vector{Float64} as return type.
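Editor's note: for comparison, here is the same preallocate-then-assign pattern in plain Base broadcasting, without Strided.jl. Here @. y = f(x) lowers to y .= f.(x), which fills y in place and leaves it a normal Vector{Float64}. The function name dostuff_inplace! is hypothetical, and this is only a sketch of the pattern.

```julia
# Scalar kernel from the comment above
f(x) = sin(x) + cos(x) * exp(x) - exp(x^2) * sin(2*x) + tan(3*x)

# Hypothetical in-place variant: the caller owns the output buffer y
function dostuff_inplace!(y::AbstractVector, x::AbstractVector)
    @. y = f(x)        # same as y .= f.(x); writes into y, no new result array
    return y
end

x = rand(1000)
y = similar(x)         # Vector{Float64}, uninitialized
dostuff_inplace!(y, x)
```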

Jutho commented 3 years ago

I guess this has to do with some compiler heuristics on the complexity of tuples and nested parametric types.
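Editor's note: Strided.jl's own broadcasting machinery is not reproduced here. However, Base.return_types offers a quick way to check whether a function's result type is inferred. Interactively, @code_warntype gives more detail. The sketch below uses plain Base broadcasting to compare the inline expression with the named-kernel refactor shown earlier in the thread. The names dostuff_inline and dostuff_kernel are hypothetical. A concrete inferred type such as Vector{Float64} means inference succeeded. Any would indicate the kind of inference failure discussed here.

```julia
# Scalar kernel, as in the refactor earlier in the thread
f(x) = sin(x) + cos(x) * exp(x) - exp(x^2) * sin(2*x) + tan(3*x)

dostuff_inline(x) = @. sin(x) + cos(x) * exp(x) - exp(x^2) * sin(2*x) + tan(3*x)
dostuff_kernel(x) = f.(x)

# Inferred return types for a Vector{Float64} argument
rt_inline = Base.return_types(dostuff_inline, (Vector{Float64},))
rt_kernel = Base.return_types(dostuff_kernel, (Vector{Float64},))
println("inline: ", rt_inline, "  kernel: ", rt_kernel)
```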

johnomotani commented 3 years ago

Thanks for the tips @Jutho , that helps a lot!

I noticed that on 10 threads the 'by hand' dostuff_threaded is slightly faster than the strided or unsafe_strided versions (6.7ms vs. 8.6ms or 8.3ms) - I guess it makes sense that there is a bit of overhead for @strided and this trivial example doesn't make use of the striding optimisations. But just playing around, I decreased n to 1000 (so there's hardly any work to do in each thread), and the 'by hand' version still gave a decent speed up (13 microseconds vs. 55 microseconds for serial), but strided and unsafe_strided were about the same as serial (55 microseconds and 56 microseconds) - is there some heuristic that turns off the threading when the amount of work gets small? (just curious!)

Jutho commented 3 years ago

There is a constant const MINTHREADLENGTH = 1<<15 in mapreduce.jl, so the array length must be at least 32768 before any threading is used. Hence, for n=1000, Strided.jl does not use threading at all. This is indeed an overly simplistic heuristic. It might be good for a simple operation like adding two arrays, but if the function applied to every element is rather costly, threading can already provide a sizeable speed-up on much smaller arrays, as you have observed.
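Editor's note: below is a minimal Base-only sketch of a length-threshold heuristic like the one described above. The helper falls back to a serial loop below a cutoff and uses Threads.@threads above it. Callers can lower the cutoff for expensive kernels. The names threaded_map!, DEFAULT_MINLENGTH and the minlength keyword are hypothetical. This is not Strided.jl's implementation, which handles far more general strided and permuted cases.

```julia
using Base.Threads: @threads, nthreads

# Hypothetical default cutoff, mirroring the 1 << 15 value quoted above
const DEFAULT_MINLENGTH = 1 << 15   # 32768

# Hypothetical helper: dest[i] = f(src[i]), threaded only when the array is long enough.
# eachindex(dest, src) throws a DimensionMismatch if the axes differ.
function threaded_map!(f, dest::AbstractVector, src::AbstractVector;
                       minlength::Integer = DEFAULT_MINLENGTH)
    if nthreads() == 1 || length(src) < minlength
        for i in eachindex(dest, src)           # serial path
            dest[i] = f(src[i])
        end
    else
        @threads for i in eachindex(dest, src)  # threaded path
            dest[i] = f(src[i])
        end
    end
    return dest
end

kernel(x) = sin(x) + cos(x) * exp(x) - exp(x^2) * sin(2*x) + tan(3*x)
x = rand(1000)
y_default = threaded_map!(kernel, similar(x), x)                 # serial: 1000 < 32768
y_eager = threaded_map!(kernel, similar(x), x; minlength = 256)  # threaded if nthreads() > 1
```

A per-call cutoff lets a costly kernel, like the one in this thread, benefit from threading at n = 1000. A fixed global threshold is tuned for cheap operations such as adding two arrays.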

But indeed, the analysis performed by Strided.jl also has some overhead, and it is quite generic. It tries to deal with combinations of permuted arrays, where there are several nested loops without a clear preferable loop order. The simple case of parallelising a plain broadcast operation was not my main concern when implementing this package. Maybe that particular case can be separated out early in the analysis, but probably there are already other packages which are much better at this.

johnomotani commented 3 years ago

:+1: thanks for the info!

probably there are already other packages which are much better at this.

Sorry, I'm very new to Julia... if it's obvious what any of those other packages are, I'd be interested to know.

Jutho commented 3 years ago

I have not been following all the recent developments very actively myself. I think you want to check out things like LoopVectorization.jl, which will also do other optimizations.