m3g / CellListMap.jl

Flexible implementation of cell lists to map the calculations of particle-pair dependent functions, such as forces, energies, neighbor lists, etc.
https://m3g.github.io/CellListMap.jl/
MIT License
87 stars 4 forks source link

Error on cross pair velocity histogram #60

Closed dforero0896 closed 2 years ago

dforero0896 commented 2 years ago

Hi, I am reimplementing something very similar to this https://discourse.julialang.org/t/pairwise-computation-slower-than-python-cython-code-balltree-very-slow/62273/24 for two different particle samples. Below are the relevant parts of my implementation:

using StaticArrays
using LinearAlgebra
using CellListMap

"""
    coordinate_separation(a, b, box_size)

Return the signed separation `a - b` along one axis of a periodic box of length
`box_size`. When the absolute difference exceeds half the box length, the box length
is subtracted from it before the sign of `a - b` is reapplied.
"""
function coordinate_separation(a, b, box_size)
    diff = a - b
    dist = abs(diff)
    if dist > 0.5 * box_size
        # Farther than half a box away: use the shorter route across the boundary.
        dist -= box_size
    end
    return dist * sign(diff)
end

"""
    separation_vector(a, b, box_size)

Build the `SVector{3,Float32}` whose `k`-th component is
`coordinate_separation(a[k], b[k], box_size[k])`.
"""
function separation_vector(a, b, box_size)
    components = ntuple(3) do k
        coordinate_separation(a[k], b[k], box_size[k])
    end
    return SVector{3,Float32}(components)
end

# Per-pair kernel: adds one pair's contribution to the two histograms in `output`.
# `output[1]` accumulates dot(velocity difference, separation) / distance and
# `output[2]` counts pairs, both indexed by the distance bin of the pair.
# NOTE(review): the caller passes the full arrays `vel_1`/`vel_2` as `vx`/`vy`, so
# `vx - vy` below subtracts whole arrays (the reported DimensionMismatch when their
# lengths differ); the per-pair values would be `vx[i]` and `vy[j]`.
function map_function(x, y, i, j, d2, output, vx, vy, bin_edges, box_size)
    s_vector = separation_vector(x, y, box_size)
    # Pair distance; `d2` is presumably the squared distance — confirm.
    norm = sqrt(d2)
    # NOTE(review): with the `- 1`, any `norm <= bin_edges[1]` gives bin 0 — confirm intent.
    bin_id = searchsortedfirst(bin_edges, norm) - 1
    output[1][bin_id] += (LinearAlgebra.dot(vx - vy, s_vector) / norm)
    output[2][bin_id] += 1
    return output
end

# Reduction passed to `map_pairwise!` through the `reduce` keyword: sums the
# histogram pairs stored in `hist_threaded` element-wise and returns the total.
function reduce_hist(hist,hist_threaded)
    # NOTE(review): this rebinds the local name `hist` to the first buffer, so the
    # `hist` argument supplied by the caller is never written to.
    hist = hist_threaded[1]
    # NOTE(review): assumes `hist_threaded` holds exactly `Threads.nthreads()`
    # entries — confirm against the CellListMap batching documentation.
    for i in 2:Threads.nthreads()
     hist[1] .+= hist_threaded[i][1]
     hist[2] .+= hist_threaded[i][2]
    end
    return hist
  end

"""
    pairwise_vel_cellist(sample_1, vel_1, sample_2, vel_2, bin_edges, box_size)

Map `map_function` over the pairs formed between `sample_1` and `sample_2` using a
`CellList` built in a `Box` of sides `box_size` with cutoff `bin_edges[end]`.

Two accumulators with the same size as `bin_edges` are filled: a `Float64` array
(`output[1]`) and an `Int32` array (`output[2]`). Partial results are combined with
`reduce_hist`. The cell list is printed before the mapping starts.

Returns the tuple `(output[2], output[1])`.
"""
function pairwise_vel_cellist(
    sample_1::Vector{SVector{3,Float32}},
    vel_1::Vector{SVector{3,Float32}},
    sample_2::Vector{SVector{3,Float32}},
    vel_2::Vector{SVector{3,Float32}},
    bin_edges::Vector,
    box_size::SVector{3},
)
    dims = size(bin_edges)
    # The last bin edge is used as the cutoff of the box.
    box = Box(box_size, bin_edges[end])
    cell_list = CellList(sample_1, sample_2, box)
    print(cell_list)
    output = (zeros(dims), zeros(Int32, dims))
    map_pairwise!(
        output, box, cell_list; reduce = reduce_hist, show_progress = true
    ) do x, y, i, j, d2, acc
        map_function(x, y, i, j, d2, acc, vel_1, vel_2, bin_edges, box_size)
    end
    return output[2], output[1]
end

When I call the function pairwise_vel_cellist on my data I get the following error

ERROR: LoadError: TaskFailedException

    nested task error: DimensionMismatch("dimensions must match: a has dims (Base.OneTo(72945),), b has dims (Base.OneTo(50383),), mismatch at 1")
    Stacktrace:
     [1] promote_shape
       @ ./indices.jl:178 [inlined]

and a much longer trace. It seems that the complaint arises because one set has ~50k particles and the other has ~70k particles.

Have you encountered this issue? I suspect I'm overlooking some detail but I'm pretty new to Julia so it's kind of hard to find.

Thanks in advance!

lmiq commented 2 years ago

The error there is in this line:

output[1][bin_id] += (LinearAlgebra.dot(vx - vy, s_vector) / norm)

because vx and vy are the full velocity arrays, of sizes 70k and 50k, so vx - vy fails. You probably want vx[i] and vy[j] there.

This code then runs:

using StaticArrays
using LinearAlgebra
using CellListMap

"""
    coordinate_separation(a, b, box_size)

Return the signed separation `a - b` along one axis of a periodic box of length
`box_size`. When the absolute difference exceeds half the box length, the box length
is subtracted from it before the sign of `a - b` is reapplied.
"""
function coordinate_separation(a, b, box_size)
    diff = a - b
    dist = abs(diff)
    if dist > 0.5 * box_size
        # Farther than half a box away: use the shorter route across the boundary.
        dist -= box_size
    end
    return dist * sign(diff)
end

"""
    separation_vector(a, b, box_size)

Build the `SVector{3,Float32}` whose `k`-th component is
`coordinate_separation(a[k], b[k], box_size[k])`.
"""
function separation_vector(a, b, box_size)
    components = ntuple(3) do k
        coordinate_separation(a[k], b[k], box_size[k])
    end
    return SVector{3,Float32}(components)
end

"""
    map_function(x, y, i, j, d2, output, vx, vy, bin_edges, box_size)

Per-pair kernel. Let `dist = sqrt(d2)` and let `bin` be the index returned by
`searchsortedfirst(bin_edges, dist)`. Then:

- `output[1][bin]` is incremented by the dot product of `vx[i] - vy[j]` with
  `separation_vector(x, y, box_size)`, divided by `dist`;
- `output[2][bin]` is incremented by one.

Returns `output`.
"""
function map_function(x, y, i, j, d2, output, vx, vy, bin_edges, box_size)
    dist = sqrt(d2)
    bin = searchsortedfirst(bin_edges, dist)
    separation = separation_vector(x, y, box_size)
    relative_velocity = vx[i] - vy[j]
    output[1][bin] += LinearAlgebra.dot(relative_velocity, separation) / dist
    output[2][bin] += 1
    return output
end

"""
    reduce_hist(hist, hist_threaded)

Overwrite the two arrays in `hist` with the element-wise sum of the corresponding
arrays of every entry of `hist_threaded`, and return `hist`.
"""
function reduce_hist(hist, hist_threaded)
    # Start from the first partial result so stale contents of `hist` are discarded.
    leading = hist_threaded[1]
    hist[1] .= leading[1]
    hist[2] .= leading[2]
    for partial in Iterators.drop(hist_threaded, 1)
        hist[1] .+= partial[1]
        hist[2] .+= partial[2]
    end
    return hist
end

"""
    pairwise_vel_cellist(sample_1, vel_1, sample_2, vel_2, bin_edges, box_size)

Map `map_function` over the pairs formed between `sample_1` and `sample_2` using a
`CellList` built in a `Box` of sides `box_size` with cutoff `bin_edges[end]`.

Two accumulators with the same size as `bin_edges` are filled: a `Float64` array
(`output[1]`) and an `Int32` array (`output[2]`). Partial results are combined with
`reduce_hist`. The cell list is printed before the mapping starts.

Returns the tuple `(output[2], output[1])`.
"""
function pairwise_vel_cellist(
    sample_1::Vector{SVector{3,Float32}},
    vel_1::Vector{SVector{3,Float32}},
    sample_2::Vector{SVector{3,Float32}},
    vel_2::Vector{SVector{3,Float32}},
    bin_edges::Vector,
    box_size::SVector{3},
)
    dims = size(bin_edges)
    # The last bin edge is used as the cutoff of the box.
    box = Box(box_size, bin_edges[end])
    cell_list = CellList(sample_1, sample_2, box)
    print(cell_list)
    output = (zeros(dims), zeros(Int32, dims))
    map_pairwise!(
        output, box, cell_list; reduce = reduce_hist, show_progress = true
    ) do x, y, i, j, d2, acc
        map_function(x, y, i, j, d2, acc, vel_1, vel_2, bin_edges, box_size)
    end
    return output[2], output[1]
end

"""
    set_data(T)

Generate example input for `pairwise_vel_cellist`, in argument order: two random
vectors of 70000 elements of type `T`, two random vectors of 50000 elements of type
`T`, the eleven bin edges `0.01, 0.02, …, 0.11`, and the box sides `T(1, 1, 1)`.
"""
function set_data(T)
    sample_1 = rand(T, 70000)
    vel_1 = rand(T, 70000)
    sample_2 = rand(T, 50000)
    vel_2 = rand(T, 50000)
    bin_edges = map(i -> 0.01 * i, 1:11)
    box_size = T(1, 1, 1)
    return (sample_1, vel_1, sample_2, vel_2, bin_edges, box_size)
end

Run with:

julia> pairwise_vel_cellist(set_data(SVector{3,Float32})...)

I have fixed the histogram binning for this example (removed the -1 from bin_id = searchsortedfirst(bin_edges, norm) - 1, to be consistent with the histogram I was generating), and the fact that the reduction function was not mutating the input hist, but rebinding it instead. It is also important to reduce over 2:length(hist_threaded), and not 2:Threads.nthreads(), because the length of the threaded output is no longer necessarily equal to the number of threads (see https://m3g.github.io/CellListMap.jl/stable/parallelization/#Number-of-batches).

That said, I think there are some code simplifications possible:

The x and y that are given to the mapped function are already wrapped, so you probably don't need the coordinate_separation and separation_vector functions. You can just use x - y as the separation vector. With that, you have:

"""
    map_function2(x, y, i, j, d2, output, vx, vy, bin_edges, box_size)

Per-pair kernel that uses `x - y` directly as the separation vector. Let
`dist = sqrt(d2)` and let `bin` be the index returned by
`searchsortedfirst(bin_edges, dist)`. Then:

- `output[1][bin]` is incremented by the dot product of `vx[i] - vy[j]` with
  `x - y`, divided by `dist`;
- `output[2][bin]` is incremented by one.

`box_size` is accepted but not used. Returns `output`.
"""
function map_function2(x, y, i, j, d2, output, vx, vy, bin_edges, box_size)
    dist = sqrt(d2)
    bin = searchsortedfirst(bin_edges, dist)
    separation = x - y
    relative_velocity = vx[i] - vy[j]
    output[1][bin] += LinearAlgebra.dot(relative_velocity, separation) / dist
    output[2][bin] += 1
    return output
end

(and you can remove the box_size parameter of this function).

Thus, your complete code could be:

using StaticArrays
using LinearAlgebra
using CellListMap

"""
    map_function(x, y, i, j, d2, output, vx, vy, bin_edges)

Per-pair kernel. Let `dist = sqrt(d2)` and let `bin` be the index returned by
`searchsortedfirst(bin_edges, dist)`. Then:

- `output[1][bin]` is incremented by the dot product of `vx[i] - vy[j]` with
  `x - y`, divided by `dist`;
- `output[2][bin]` is incremented by one.

Returns `output`.
"""
function map_function(x, y, i, j, d2, output, vx, vy, bin_edges)
    dist = sqrt(d2)
    bin = searchsortedfirst(bin_edges, dist)
    relative_velocity = vx[i] - vy[j]
    projection = LinearAlgebra.dot(relative_velocity, x - y) / dist
    output[1][bin] += projection
    output[2][bin] += 1
    return output
end

"""
    reduce_hist(hist, hist_threaded)

Overwrite the two arrays in `hist` with the element-wise sum of the corresponding
arrays of every entry of `hist_threaded`, and return `hist`.
"""
function reduce_hist(hist, hist_threaded)
    # Start from the first partial result so stale contents of `hist` are discarded.
    leading = hist_threaded[1]
    hist[1] .= leading[1]
    hist[2] .= leading[2]
    for partial in Iterators.drop(hist_threaded, 1)
        hist[1] .+= partial[1]
        hist[2] .+= partial[2]
    end
    return hist
end

"""
    pairwise_vel_cellist(sample_1, vel_1, sample_2, vel_2, bin_edges, box_size)

Map `map_function` over the pairs formed between `sample_1` and `sample_2` using a
`CellList` built in a `Box` of sides `box_size` with cutoff `bin_edges[end]`.

Two accumulators with the same size as `bin_edges` are filled: a `Float64` array
(`output[1]`) and an `Int32` array (`output[2]`). Partial results are combined with
`reduce_hist`. The cell list is printed before the mapping starts.

Returns the tuple `(output[2], output[1])`.
"""
function pairwise_vel_cellist(
    sample_1::Vector{SVector{3,Float32}},
    vel_1::Vector{SVector{3,Float32}},
    sample_2::Vector{SVector{3,Float32}},
    vel_2::Vector{SVector{3,Float32}},
    bin_edges::Vector,
    box_size::SVector{3},
)
    dims = size(bin_edges)
    # The last bin edge is used as the cutoff of the box.
    box = Box(box_size, bin_edges[end])
    cell_list = CellList(sample_1, sample_2, box)
    print(cell_list)
    output = (zeros(dims), zeros(Int32, dims))
    map_pairwise!(
        output, box, cell_list; reduce = reduce_hist, show_progress = true
    ) do x, y, i, j, d2, acc
        map_function(x, y, i, j, d2, acc, vel_1, vel_2, bin_edges)
    end
    return output[2], output[1]
end

Finally, I strongly recommend following the examples given in the user manual:

https://m3g.github.io/CellListMap.jl/stable/examples/

and on this site:

https://github.com/m3g/CellListMapArticleCodes

This one is particularly relevant to you:

https://m3g.github.io/CellListMapArticleCodes/CodeBlock10.jl.html

Because the examples on the discourse threads may be out of date.

dforero0896 commented 2 years ago

Thanks a lot!