JaydevSR closed this 10 months ago.
I have now changed the implementation so that it uses two kernels: one with a neighbor list and one without.
Codecov report: patch coverage is 7.31% and the project coverage change is -0.87%.
Comparison is base (17f20ed) 73.12% compared to head (841d291) 72.25%. Current head 841d291 differs from the pull request's most recent head b349ac1. Consider uploading reports for commit b349ac1 to get more accurate results.
Looks like a good start. With regards to benchmarking, I would use a system of 1000-5000 atoms and use @benchmark on the force function for now.
I am facing some issues, most probably with synchronization in the kernel, which give wrong results for the forces.
This stuff can be hard to debug. I would use @cushow and @cuprintln on a small system to print out the relevant data and check that it is correct and gets reduced correctly. Also, you could try a sync_threads() after setting up and zeroing the shared memory, or before the final reduction.
It might be worth asking on the #gpu Slack channel about sparse matrices. CUDA has sparse matrices but I guess the issue is that these don't work in kernels?
I would also be interested to see the speed of the OpenMM atom block strategy, even when using the dense adjacency matrix. Another idea is to use a dense adjacency matrix but create a sparse data structure with the first thread inside the kernel itself. You could have two shared memory int vectors for the interacting indices. Threads would then loop over those pairs.
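The compaction idea above can be sketched on the CPU in plain Julia. This is only an illustration under stated assumptions. The names compact_block, block_sum and toy_force are hypothetical, and the "force" is a placeholder. The two shared-memory index vectors, filled by a single thread on a GPU, are modelled here as ordinary Vectors filled serially.

```julia
# Compact one block of a dense Boolean adjacency matrix into two index vectors
# (is, js) holding the interacting pairs, then loop over those pairs only.
function compact_block(adj::AbstractMatrix{Bool}, irange, jrange)
    is = Int[]
    js = Int[]
    for j in jrange, i in irange          # column-major order
        if i < j && adj[i, j]             # keep each unordered pair once
            push!(is, i)
            push!(js, j)
        end
    end
    return is, js
end

# Placeholder pair term standing in for a real interaction
toy_force(i, j) = 1.0 / (i + j)

function block_sum(adj, irange, jrange)
    is, js = compact_block(adj, irange, jrange)
    total = 0.0
    for k in eachindex(is)                # on a GPU, threads would stride over k
        total += toy_force(is[k], js[k])
    end
    return total, length(is)
end

adj = [abs(i - j) == 1 for i in 1:8, j in 1:8]   # chain-like neighbour pattern
total, npairs = block_sum(adj, 1:8, 1:8)
```

Whether the compaction step pays for itself on a GPU would need benchmarking.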
I'm not the best CUDA programmer, but I thought I'd throw this out there. I'm sure I'm missing some intricacy, as I have literally spent 0 minutes thinking about force kernels on a GPU, but the CUDA part of Molly is something I'd like to understand more. What are the downsides of using a kernel like this, besides not manually choosing the block sizes etc.? Too many kernel invocations?
function kernel(atom_idx, interacting_atoms)
    # Sum the forces on atom_idx from every interacting atom j
    return mapreduce(j -> force(atom_idx, j), +, interacting_atoms)
end
@ejmeitz Right now the kernel is essentially doing what you have pointed out. Maybe I will try and see how mapreduce compares to the current implementation. But I think the problem with just using this is that mapreduce is well suited to reductions where the reduced-over indices are mostly independent and no extra optimizations can be made anyway.
But in our case we are dealing with N^2 interactions, where we can certainly benefit from data reuse and operation ordering. This can only happen if the kernel works at the warp level, so that we can shuffle data around while doing the reduction and the force calculation at once. I am working on this right now and trying to fix a few bugs. I also have little experience with CUDA apart from examples for now; as far as I understand, this is the main motivation for starting with this barebones implementation. I am not sure how much these optimizations will help, but they will surely be worth something 😄 (unless the neighbor data is so sparse that the extra overhead of the loops takes over, but that can be dealt with as another case?)
Benchmark results for the forces function (non-tiled). The implementation loops over all j values for a particular i using a block stride, and the forces are summed in shared memory using a binary reduction. Each block calculates one row of the interaction pairs, hence there are n_atoms blocks.
Device: Tesla V100-PCIE-16GB
Simulation | Atoms | Total Pairs | Min (ms) | Median (ms) | Mean (ms) | Mean per 10000 pairs (μs) |
---|---|---|---|---|---|---|
NONL | 3200 | 5118400 | 2.450 | 2.577 | 2.576 | 5.03 |
NONL f32 | 3200 | 5118400 | 0.854 | 1.098 | 1.093 | 2.13 |
NL (atomic) | 3200 | 335323 | 0.215 | 0.225 | 0.226 | 6.74 |
NL f32 (atomic) | 3200 | 334149 | 0.193 | 0.210 | 0.211 | 6.31 |
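Here is a CPU sketch of the non-tiled scheme described above, for one block computing row i. Each simulated thread accumulates a block-stride partial sum over j. The partials are then combined with a binary (tree) reduction, and comments mark where a GPU would need sync_threads(). This rests on assumptions: the names are hypothetical, the pair function is a toy placeholder rather than a real potential, and nthreads is assumed to be a power of two.

```julia
# CPU simulation of one block summing a scalar pair term over row i.
function row_force_block(f, i, n_atoms, nthreads)
    @assert ispow2(nthreads) "tree reduction assumes a power-of-two thread count"
    partial = zeros(nthreads)                # stands in for shared memory
    for t in 1:nthreads                      # these run concurrently on a GPU
        for j in t:nthreads:n_atoms          # block-stride loop over j
            j == i && continue               # skip the self-interaction
            partial[t] += f(i, j)
        end
    end
    # GPU: sync_threads() here, before reading other threads' partials
    stride = nthreads ÷ 2
    while stride > 0
        for t in 1:stride                    # only threads t <= stride are active
            partial[t] += partial[t + stride]
        end
        # GPU: sync_threads() after each level of the tree
        stride ÷= 2
    end
    return partial[1]
end

pair_f(i, j) = 1.0 / (1 + abs(i - j))        # toy pair term, not a real potential
r = row_force_block(pair_f, 5, 100, 32)
```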
Btw I have access to a cluster of A40s and another of 2080s, if you want to test on distributed systems or just on a different single GPU.
@ejmeitz Thanks, it would be great to have more devices to test on. Currently I have access to a V100, a P100 and a GTX 1080 Ti, but I have only tried the V100 so far. Later we can maybe benchmark on all of them.
Also, I have written the following kernel, which works by tiling a block into 32 * 32 tiles and summing the forces over each tile diagonal by diagonal. Each block is 1D with n_threads threads, and the grid is 2D with (n_blocks, n_blocks) blocks covering all N^2 pairs, as defined below.
n_threads = min(n_atoms, parse(Int, get(ENV, "MOLLY_GPUNTHREADS_PAIRWISE", "512")))
n_blocks = cld(n_atoms, n_threads)
function pairwise_force_kernel_nonl!(forces::AbstractArray{T}, coords_var, atoms_var, boundary, inters,
                                     neighbors_var, ::Val{D}, ::Val{F}) where {T, D, F}
    coords = CUDA.Const(coords_var)
    atoms = CUDA.Const(atoms_var)
    tidx = threadIdx().x
    threads = blockDim().x
    i_0_block = (blockIdx().x - 1) * threads + 1
    j_0_block = (blockIdx().y - 1) * threads + 1
    lane = (tidx - 1) % WARPSIZE + 1
    warpidx = cld(tidx, WARPSIZE)
    # @cushow tidx i_0_block j_0_block lane warpidx threads
    # Per-thread force accumulators in shared memory, zeroed before use
    forces_shmem = @cuStaticSharedMem(T, (3, 1024))
    @inbounds for dim in 1:3
        forces_shmem[dim, tidx] = zero(T)
    end
    # Iterate over horizontal tiles of size WARPSIZE * WARPSIZE to cover all j's
    i_0_tile = i_0_block + (warpidx - 1) * WARPSIZE
    j_0_tile = j_0_block
    while j_0_tile < j_0_block + threads # TODO: Check i, j in bounds
        # Load data on the diagonal
        i = i_0_tile + lane - 1
        j = j_0_tile + lane - 1
        atom_i, coord_i = atoms[i], coords[i]
        tilesteps = WARPSIZE
        if i_0_tile == j_0_tile # Don't compute i-i forces
            # Shift j by one within the tile, wrapping around to j_0_tile
            j = j_0_tile + (j - j_0_tile + 1) % WARPSIZE
            tilesteps -= 1
        end
        for _ in 1:tilesteps
            sync_warp()
            atom_j, coord_j = atoms[j], coords[j] # TODO: shuffle this as well
            f = sum_pairwise_forces(inters, coord_i, coord_j, atom_i, atom_j, boundary, false, F)
            for dim in 1:D
                forces_shmem[dim, tidx] += -ustrip(f[dim])
            end
            # Each lane takes j from the next lane, so the j values rotate around the warp
            j = shfl_sync(FULLMASK, j, lane + 1)
        end
        j_0_tile += WARPSIZE
    end
    sync_warp()
    # Accumulate this thread's partial forces into the global array
    for dim in 1:D
        Atomix.@atomic :monotonic forces[dim, i_0_block + tidx - 1] += forces_shmem[dim, tidx]
    end
    return nothing
end
Currently, this only gives correct results if there is only one block. So the tiling scheme seems to be correct, and I am trying to figure out what is going wrong for multiple blocks. (Embarrassing mistake, fixed it 😓) This kernel is also slower than the current one (I think mostly because of the added complexity), but I hope the main benefits will come after we shuffle atom_j and coord_j as well.
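The diagonal traversal can be checked on the CPU by simulating one warp. In this sketch the names are hypothetical and the lanes run serially. Lane l owns i = i0 + l - 1 and starts at j = j0 + l - 1. After every step it takes the j of the next lane, wrapping around, which is assumed to match what shfl_sync(FULLMASK, j, lane + 1) does for a full 32-lane warp. For the diagonal tile, the starting j is shifted by one inside the tile with j0 + (j - j0 + 1) % W, so j stays within the tile and i == j is never visited.

```julia
# Count how often each (i, j) of one W x W tile is visited by the rotation scheme.
function visit_tile(W, i0, j0)
    visited = zeros(Int, W, W)                 # visited[i - i0 + 1, j - j0 + 1]
    js = [j0 + l - 1 for l in 1:W]             # lane l starts on the diagonal
    steps = W
    if i0 == j0                                # diagonal tile: skip i == j
        js = [j0 + (j - j0 + 1) % W for j in js]
        steps -= 1
    end
    for _ in 1:steps
        for l in 1:W                           # all lanes work in lockstep
            i = i0 + l - 1
            visited[i - i0 + 1, js[l] - j0 + 1] += 1
        end
        js = [js[mod1(l + 1, W)] for l in 1:W] # emulate the shuffle from lane + 1
    end
    return visited
end

v_off = visit_tile(32, 33, 1)    # off-diagonal tile
v_diag = visit_tile(32, 33, 33)  # diagonal tile, not starting at atom 1
```

Every pair in an off-diagonal tile should be visited exactly once, and every off-diagonal pair of a diagonal tile exactly once with the diagonal skipped.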
One more possible optimization would be to remove the loop over tiles that covers all j's horizontally, and to redefine the grid so that each block consists of only as many tiles as it has warps?
Benchmarks of the forces function using a tiled kernel without a neighbor list. Approach 0 is Molly.jl/main, Approach 1 uses (n_threads, n_threads) blocks and Approach 2 uses (WARPSIZE, n_threads) blocks.
Simulation | Atoms | Number of Pairs | Min time (ms) | Median time (ms) | Mean time (ms) | Mean time per 10000 pairs (μs) |
---|---|---|---|---|---|---|
Approach 0 | 3072 | 4717056 | 5.193 | 5.331 | 5.339 | 11.31 |
Approach 0 f32 | 3072 | 4717056 | 5.306 | 5.466 | 5.484 | 11.63 |
Approach 1 | 3072 | 4717056 | 1.402 | 1.572 | 1.568 | 3.32 |
Approach 1 f32 | 3072 | 4717056 | 0.870 | 1.065 | 1.061 | 2.24 |
Approach 2 | 3072 | 4717056 | 2.015 | 2.213 | 2.220 | 4.71 |
Approach 2 f32 | 3072 | 4717056 | 0.502 | 0.512 | 0.513 | 1.08 |
Takeaway: getting rid of the internal loop gives a 2x improvement for f32, but we pay the price with worse f64 performance. I think this may be because approach 2 increases the number of blocks by a factor of n_threads / WARPSIZE, which causes a performance drop when dispatching the f64 calculations to the SMs of the GPU? If this is the case, it could be improved by having each warp calculate 2 or 4 tiles at a time (without a loop, just repeating the same code).
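To put the factor n_threads / WARPSIZE in numbers for the 3072-atom benchmark, here is the arithmetic as a small Julia snippet. It assumes the default n_threads = 512 from MOLLY_GPUNTHREADS_PAIRWISE and WARPSIZE = 32. It also assumes that Approach 1 uses the (n_blocks, n_blocks) grid of the kernel above. The variable names are illustrative.

```julia
# Block counts for the two tiled layouts (assumed configuration, see above)
n_atoms = 3072
n_threads = 512
warpsize = 32
n_blocks = cld(n_atoms, n_threads)                           # blocks along each grid axis

blocks_approach1 = n_blocks * n_blocks                       # (n_blocks, n_blocks) grid
blocks_approach2 = blocks_approach1 * (n_threads ÷ warpsize) # more blocks by n_threads / WARPSIZE
```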
Cool, I would focus on f32 performance since force calculation is usually safe to run in f32. I know other software has a severe slowdown with f64; it might be unavoidable.
Am I right in interpreting that the f32 median time before this PR was 1.098 ms for no NL / 0.210 ms for NL and now is 0.512 ms for no NL? That's getting somewhere.
I would aim to move to the NL version soon, particularly if there are things you have learned with the no NL case that can be used there. Can approach 2 above be applied in a performant way to the case where interactions are skipped if they are not in a dense neighbour matrix?
In its current state the kernel is quite minimal, and the only major optimization opportunity I can see is to shuffle the atom and coordinate data as well. That will certainly require some preprocessing, and I haven't yet worked out how to integrate it with the rest of the interface.
The NONL version gives quite a motivation for a similar implementation for NL as well, if we can think of a way to construct some dense tiles out of the sparse adjacency matrix. I am not quite sure about this, but it would work quite naturally if we were dealing with cell lists!
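One way to get dense tiles out of a sparse adjacency matrix is to record which WARPSIZE × WARPSIZE tiles contain at least one interacting pair, and to process only those densely. This minimal CPU sketch assumes a SparseMatrixCSC{Bool} adjacency from the SparseArrays standard library, 32-atom tiles and a made-up banded neighbour pattern. occupied_tiles is a hypothetical helper.

```julia
using SparseArrays

# Return the sorted (tile_i, tile_j) indices of tiles holding any nonzero.
function occupied_tiles(adj::SparseMatrixCSC{Bool}, tile::Int)
    rows, cols, _ = findnz(adj)
    tiles = Set{Tuple{Int, Int}}()
    for (i, j) in zip(rows, cols)
        push!(tiles, (cld(i, tile), cld(j, tile)))
    end
    return sort!(collect(tiles))
end

n = 96
# Toy neighbour pattern: atoms interact only with atoms within 3 indices
adj = sparse([abs(i - j) <= 3 && i != j for i in 1:n, j in 1:n])
tiles = occupied_tiles(adj, 32)
```

Only 7 of the 9 tiles need work here. With spatially sorted atoms, the fraction of empty tiles grows with system size.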
@jgreener64 No, the 1.098 ms is the median time for the kernel that simply has a for loop over all j's for a particular i (using a block stride and then a binary reduction over shared memory). For the NL case I haven't touched anything, so yes, that is the same.
Looking at the benchmarks of the atomic version (before the PR) that I have just included, you can see this kernel is 10x faster than that. If we can achieve similar performance with a NL, it would certainly be something to aim for.
EDIT: Also, the current kernel only works if the number of atoms is a multiple of n_threads and n_threads is a multiple of WARPSIZE. This removes the need for if conditions on i and j, but the general case would introduce quite some complexity, which would certainly impact performance. What should be done in this case?
this kernel is 10x faster that that
Great, even better.
Yes I think the NL kernel would be based on this new no NL one, the current NL kernel is probably too simple to speed up a lot.
How easy would it be to have the above approach 2 but with another input, neighbors, the dense adjacency matrix, and just have an if neighbors[i, j] switch before the force calculation? Does that have any hope of being fast? It might be worth benchmarking if it isn't too hard. You might get use out of @inbounds around indexing expressions too.
What should be done in this case?
I would try adding the i and j conditions in the simplest way you can and benchmarking the performance. It may not be that bad? The alternative is to pad the inputs, but that adds its own complexity.
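Here is a CPU sketch of the simplest form of those i and j conditions. Tiles are laid out with cld, so the last one may be partial, and each (i, j) is checked against n_atoms before use. The function name and the serial loops standing in for blocks and threads are assumptions for illustration, not Molly code.

```julia
# Count the (i, j) pairs a guarded tiled traversal would compute.
function count_pairs_guarded(n_atoms, W)
    n_tiles = cld(n_atoms, W)                  # last tile may be partial
    npairs = 0
    for ti in 1:n_tiles, tj in 1:n_tiles       # one "block" per (ti, tj) tile
        for li in 1:W, lj in 1:W               # one "thread" per entry of the tile
            i = (ti - 1) * W + li
            j = (tj - 1) * W + lj
            # The extra conditions: stay in bounds and skip self-interaction
            if i <= n_atoms && j <= n_atoms && i != j
                npairs += 1                    # a real kernel would compute the force here
            end
        end
    end
    return npairs
end

total = count_pairs_guarded(3219, 32)          # 3219 is not a multiple of 32
```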
Okay I will try both the things and benchmark them 👍
When using if neighbors[i, j] in the NONL kernel to calculate forces with a NL, there is a significant impact and the function becomes ~3 times slower than even the NONL version (I am not exactly sure why, maybe a lot of thread divergence). That's why I have dropped the idea of doing this.
I have implemented a NeighborListOfLists type and modified the find_neighbors function for GPU. In terms of performance, I was initially getting very bad benchmarks for this function (~18 seconds for 3200 atoms), but somehow I managed to bring this down to 600 ms. (I first reported this as 3 times slower than the earlier version of this function, but that was not true: I was benchmarking the CPU version by mistake. The changed version gives similar performance and allocations to the present version in Molly.) find_neighbors has a lot of allocations, so if anyone has any ideas on how to bring this down, please let me know.
Benchmarks for the forces function:
Simulation | Atoms | Pairs | Min Time (ms) | Median Time (ms) | Mean Time (ms) |
---|---|---|---|---|---|
NONL with general n_atoms (Float64) | 3219 | 5179371 | 2.348 | 2.499 | 2.499 |
NONL with general n_atoms (Float32) | 3219 | 5179371 | 0.504 | 0.520 | 0.530 |
NL with if neighbors[i, j] (Float64) | 3219 | 676932 | 3.344 | 3.537 | 3.558 |
NL with if neighbors[i, j] (Float32) | 3219 | 676932 | 1.712 | 1.896 | 1.904 |
Benchmarks for the find_neighbors function:
Type | Atoms | Min Time (ms) | Median Time (ms) | Mean Time (ms) | Allocations |
---|---|---|---|---|---|
NeighborList | 3200 | 559.527 | 645.718 | 624.101 | Memory estimate: 29.29 MiB, allocs estimate: 745752 |
NeighborListOfLists | 3200 | 562.857 | 637.270 | 621.885 | Memory estimate: 29.29 MiB, allocs estimate: 745752 |
Here's how LAMMPS implements neighbor lists. There might be something useful in there. The if statement might be slow because, inside a kernel, divergent branches within a warp are executed serially on a GPU.
Is the find_neighbors function that allocates too much something in Molly right now, or one of your changes? I can check it out.
I was benchmarking find_neighbors on CPU by mistake, so the statement about the newer implementation being slower is false. It performs the same, so I guess there is no issue with going ahead with this.
@ejmeitz Yes, the CPU version has a lot fewer allocations, due to lazily updating the neighbor list I guess. But the GPU version for the DistanceNeighborFinder constructs the neighbor list from an adjacency matrix each time, so it has a lot more allocations. This may be unavoidable, but I am not too sure.
One format that might be useful for the neighbor lists is a sparse array. I'm guessing your list of lists is basically doing the same thing, but by using one of the built-in types you might be able to use some optimized functions. Granted, they might not be that sparse.
For the allocations why can't you re-use the same memory from the previous calculation?
Yes, NeighborListOfLists is basically just a different sparse representation of the adjacency matrix, one with a 2D structure that can be used for tiling in the pairwise force kernel. But I couldn't find this particular format in SparseArrays.jl (they have CSC and COO).
I was planning to do the same thing and reuse the memory from the previous calculation, but it didn't fit well with how things are right now, so I am just reconstructing it every time. This can still be improved later, once I am sure that the force kernel using this representation performs well enough to keep it.
I feel like a CSC is basically a list of lists, and by using the Julia type you don't have to worry about your implementation being optimal, since someone already did that. Also, CUDA.jl has a CSC format. Not sure what algorithms they have implemented for it though.
Finch.jl might have some stuff also.
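A CSC matrix stores each column's row indices contiguously. Reading column j as the neighbours of atom j therefore gives a list of lists packed into two flat arrays (colptr and rowval). Here is a small sketch with the SparseArrays standard library. The toy neighbour lists are made up for illustration.

```julia
using SparseArrays

neighbour_lists = [[2, 3], [1], [1, 4], [3]]      # toy neighbours of atoms 1..4
n = length(neighbour_lists)
rows = reduce(vcat, neighbour_lists)              # neighbour indices, column by column
cols = reduce(vcat, [fill(j, length(l)) for (j, l) in enumerate(neighbour_lists)])
A = sparse(rows, cols, trues(length(rows)), n, n)

# Recover the list of lists from the flat CSC arrays: column j's row indices
recovered = [rowvals(A)[nzrange(A, j)] for j in 1:n]
```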
The kernel has been modified to support an arbitrary number of atoms, and there seems to be no significant performance impact from doing so.
Great.
there is a significant impact and the function becomes ~3 times slower than even the NONL version
Fair enough, move on from that idea then.
After some thinking I have come to the conclusion that the best representation of the neighbors list will be list of lists.
Sounds like it's worth pursuing, particularly if you can see how it links with your new no NL kernel. I wouldn't worry much about find_neighbors for now, especially since your new version hasn't regressed from the old one. Just assume you have the neighbours in whatever form is best.
There is likely a better way to reuse the memory for DistanceNeighborFinder, but I wasn't motivated to find it because a) O(N^2) neighbour finding is bad even with no allocations, and the cell list approach is better long term, and b) the storage of the neighbours will depend on what works best for the GPU kernels, so that should be decided first.
Let me know how integrating the list of lists with the no NL kernel goes.
The NeighborListOfLists format was not suitable for a GPU kernel, as only arrays whose elements are stored inline can be passed to a kernel, not arrays of arrays. That's why I have switched to the CuSparseMatrixCSC format (@ejmeitz thanks for the suggestion). I defined a wrapper, NeighborListCSC, to pass some relevant information to the kernel, such as the maximum number of pairs per atom, which is used to calculate the number of threads. Tweaking the find_neighbors function to construct the sparse array efficiently was a bit tricky, but it was a success after some struggle 😅.
Benchmarks for the find_neighbors function:
Type | Atoms | Min time (ms) | Median time (ms) | Mean time (ms) | Allocations |
---|---|---|---|---|---|
NeighborListCSC | 3200 | 227.940 | 233.630 | 259.534 | Memory estimate: 11.59 MiB, allocs estimate: 303800 |
Benchmarks for the forces function:
Simulation | Atoms | Pairs | Min time (ms) | Median time (ms) | Mean time (ms) |
---|---|---|---|---|---|
NoNeighborList (Float64) | 3219 | 5179371 | 2.391 | 2.539 | 2.537 |
NoNeighborList (Float32) | 3219 | 5179371 | 0.505 | 0.520 | 0.521 |
NeighborListCSC (Float64) | 3219 | 676932 | 0.315 | 0.328 | 0.328 |
NeighborListCSC (Float32) | 3219 | 676932 | 0.208 | 0.214 | 0.215 |
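This is a rough sketch of why the list of lists could not be passed to the kernel while a CSC layout can. It also shows a wrapper that carries the maximum number of pairs per atom. ToyNeighborListCSC and its fields are hypothetical and are not Molly's NeighborListCSC. The sketch uses a CPU SparseMatrixCSC instead of CuSparseMatrixCSC so it runs without a GPU.

```julia
using SparseArrays

# GPU arrays need isbits element types stored inline:
inline_vec = isbitstype(Vector{Int32})   # false: a Vector is a heap object
inline_int = isbitstype(Int32)           # true: flat integer arrays are fine

# Hypothetical wrapper: the CSC matrix plus the largest neighbour count of any
# atom, which a launcher could use to choose a thread count.
struct ToyNeighborListCSC{M}
    adj::M
    max_pairs::Int
end

function ToyNeighborListCSC(adj::SparseMatrixCSC)
    max_pairs = maximum(diff(adj.colptr))   # nonzeros in the fullest column
    return ToyNeighborListCSC(adj, max_pairs)
end

adj = sparse([2, 3, 1, 1, 4, 3], [1, 1, 2, 3, 3, 4], trues(6), 4, 4)
nl = ToyNeighborListCSC(adj)
```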
Great work! Glad you've found a NL solution that works.
One thing to note is that there is only around a 2.5x improvement over the NONL kernel (Float32), even though the number of pairs is reduced 10x.
This is probably okay since the number of atoms is fairly small. For the force function, i.e. excluding the O(N^2) neighbour finder, the NL case should scale as O(N) whereas the no NL case should scale as O(N^2). It would be good to plot the performance of @benchmark forces($sys, $neighbors) from 100 to 100,000+ atoms and see whether that does hold and whether memory issues arise.
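A back-of-the-envelope check of the expected scaling follows. It assumes atoms at roughly uniform density in a periodic cubic box, with the cutoff smaller than half the box. Using the 3219-atom system, 6 nm box and 1.5 nm neighbour cutoff from the benchmark script in this thread, the estimate lands close to the 338466 pairs reported for the master neighbour list. The function names are illustrative.

```julia
# Expected unique pairs within `cutoff` at uniform density in a periodic cube
expected_nl_pairs(n_atoms, box, cutoff) =
    n_atoms * (n_atoms / box^3) * (4 / 3) * π * cutoff^3 / 2

all_pairs(n_atoms) = n_atoms * (n_atoms - 1) ÷ 2

p = expected_nl_pairs(3219, 6.0, 1.5)              # ≈ 3.39e5

# Doubling the atom count at fixed density (box volume doubled)
p2 = expected_nl_pairs(6438, 6.0 * cbrt(2), 1.5)   # ≈ 2p: O(N)
a2 = all_pairs(6438) / all_pairs(3219)             # ≈ 4: O(N^2)
```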
On that note I tried with a small solvated protein, 16k atoms with a density like water, and found things were slower than before. I am on Julia 1.9.2 and CUDA 4.4.0, running Ubuntu with an unloaded A6000. The script is:
using Molly, CUDA, BenchmarkTools
data_dir = normpath(dirname(pathof(Molly)), "..", "data")
pdb_file = joinpath(data_dir, "6mrr_equil.pdb")
gpu = true
f32 = true
T = f32 ? Float32 : Float64
units = false
dist_neighbors = units ? T(1.2u"nm") : T(1.2)
ff_dir = joinpath(data_dir, "force_fields")
ff = MolecularForceField(
T,
joinpath.(ff_dir, ["ff99SBildn.xml", "tip3p_standard.xml", "his.xml"])...;
units=units,
)
sys = System(
pdb_file,
ff;
units=units,
gpu=gpu,
dist_neighbors=dist_neighbors,
)
neighbors = find_neighbors(sys)
sim = VelocityVerlet(dt=units ? T(0.0005u"ps") : T(0.0005), remove_CM_motion=false)
@benchmark forces($sys, $neighbors)
# With master: 2.7 ms
# With this PR: 13 ms
simulate!(sys, sim, 50)
@time simulate!(sys, sim, 500)
# With master: 3.0 s
# With this PR: 105 s
It seems that find_neighbors is much slower, and forces is slower too. Do you get similar changes with the above script? That would help work out whether the slowdown is due to a larger, dense system or due to differences in hardware/software versions. Profiling suggests that almost all the time in find_neighbors is spent in the map!(x -> count(==(x), colVal), (@view colPtr[2:end]), 1:n_atoms) line.
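The profiled line scans all of colVal once per atom, so its cost grows as n_atoms × nnz. A single pass that increments one counter per nonzero, followed by a prefix sum, gives the same column pointers in O(nnz). Here is a CPU sketch for comparison, with hypothetical function names. On a GPU the increments would need atomics (or a sort). This is not the count_occurances_kernel! mentioned in the thread, whose code is not shown.

```julia
# O(n_atoms * nnz): one full scan of colval per column, as in the profiled map!
function colptr_by_count(colval, n_atoms)
    colptr = ones(Int, n_atoms + 1)
    map!(x -> count(==(x), colval), view(colptr, 2:n_atoms + 1), 1:n_atoms)
    return cumsum(colptr)
end

# O(nnz): one pass over the nonzeros, then a prefix sum
function colptr_by_histogram(colval, n_atoms)
    colptr = zeros(Int, n_atoms + 1)
    colptr[1] = 1
    for c in colval
        colptr[c + 1] += 1
    end
    return cumsum(colptr)
end

colval = [1, 1, 2, 3, 3, 4]   # column index of each nonzero
n_atoms = 4
```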
So, I tried the protein simulation script and see similar issues, except for the force! function, where the performance is not that drastically worse. To fix the problem in the neighbor finding part, I replaced the use of map! with another GPU kernel, and that improved things significantly. Using the same setup I got the following results (Julia 1.9.0, Tesla V100-PCIE-16GB, CUDA runtime 12.1):
neighbors = find_neighbors(sys)
@benchmark find_neighbors($sys)
# With master: 24.7 ms
# With this PR: 2.3 s (with map!), 106.9 ms (with count_occurances_kernel!)
@benchmark forces($sys, $neighbors)
# With master: 3.7 ms
# With this PR: 4.0 ms
simulate!(sys, sim, 50)
@time simulate!(sys, sim, 500)
# With master: 3.45 s
# With this PR: 152 s (with map!), 7.8 s (with count_occurances_kernel!)
For the plots I was trying to benchmark the forces! function for different numbers of atoms, but at 50,000 atoms the GPU runs out of memory. I didn't try numbers of atoms between 10,000 and 50,000, so this may also happen with fewer atoms. The main issue seems to be in the call to findall on nf.neighbors.
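Part of the memory problem is plain arithmetic. Any dense n_atoms × n_atoms Boolean matrix, such as the eligible matrix passed to DistanceNeighborFinder or a dense neighbour matrix, is already large at 50,000 atoms, before findall allocates its output. This sketch assumes one byte per Bool element, as for Julia Array{Bool} and CuArray{Bool}. dense_bool_bytes is a hypothetical helper.

```julia
# Bytes for one dense n x n Bool matrix (a BitMatrix would pack 8 entries per byte)
dense_bool_bytes(n) = n^2 * sizeof(Bool)

b = dense_bool_bytes(50_000)    # 2.5e9 bytes
gib = b / 2^30                  # ≈ 2.33 GiB
```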
Also, sadly, I mixed up the benchmarks for the master branch, and the performance improvement is not there when using a neighbor list. The benchmarks for the master branch are:
Simulation | Atoms | Pairs | Median time (ms) |
---|---|---|---|
NoNeighborList (Float64) | 3219 | 5179371 | 5.884 |
NoNeighborList (Float32) | 3219 | 5179371 | 6.007 |
NeighborList (Float64) | 3219 | 338466 | 0.252 |
NeighborList (Float32) | 3219 | 338466 | 0.228 |
@JaydevSR can you not just use the nf.neighbors array as the mask you use to index into pairs? It's already a boolean array.
There's also a bunch of allocations in find_neighbors that I think can be pre-allocated and reused.
Okay, so it looks like we will need a different strategy to speed things up with the neighbour list. Do you have any ideas for that at the moment?
It may be worth going back and looking at other implementations again, you might get more out of them now that you are more familiar with GPU kernels. You could also try using a GPU profiling tool like Nsight to see where time is being spent in the kernel, see https://cuda.juliagpu.org/stable/development/profiling.
As far as other approaches to the neighbor list go, I'm stumped at the moment. I am also thinking about taking a step back to see what can be changed.
For profiling, I am accessing cyclops using ssh, so I was not able to figure out how to use the GUI of the profiling tool. Do you have any idea how I can do that? There seems to be remote GUI support in Nsight Systems; I wonder why I didn't notice that before 😓
@JaydevSR I've used the Nsight Compute profiler with Julia before, but it only looks at the execution of a kernel and not at communication with the CPU. It tells you if you're doing uncoalesced accesses, if your threads are diverging and how actively you are using the cache. I can help with that if needed (or run it quickly locally if you send me some test code).
Also, are your benchmarks right now just for the kernel and not a whole time step? If the force kernel itself is slow on the GPU, I'd start by looking at the output from Nsight Compute rather than Nsight Systems, since something other than communication is the bottleneck.
The HoomD code might be worth looking at. There is a CUDA forum thread talking about it a little as well, and the original HoomD paper (outdated).
Yes, the current benchmarks are just for the forces function, which calls the kernel after some setup. I am using this script for the benchmarks:
using Molly
using BenchmarkTools
using CUDA
function setup_sim(nl::Bool, f32::Bool)
    local n_atoms = 3219
    atom_mass = f32 ? 10.0f0u"u" : 10.0u"u"
    boundary = f32 ? CubicBoundary(6.0f0u"nm") : CubicBoundary(6.0u"nm")
    starting_coords = place_atoms(n_atoms, boundary; min_dist=0.2u"nm")
    starting_velocities = [random_velocity(atom_mass, 1.0u"K") for i in 1:n_atoms]
    starting_coords_f32 = [Float32.(c) for c in starting_coords]
    starting_velocities_f32 = [Float32.(c) for c in starting_velocities]
    simulator = VelocityVerlet(dt=f32 ? 0.02f0u"ps" : 0.02u"ps")
    neighbor_finder = NoNeighborFinder()
    cutoff = DistanceCutoff(f32 ? 1.0f0u"nm" : 1.0u"nm")
    pairwise_inters = (LennardJones(use_neighbors=false, cutoff=cutoff),)
    if nl
        neighbor_finder = DistanceNeighborFinder(
            eligible=CuArray(trues(n_atoms, n_atoms)),
            n_steps=10,
            dist_cutoff=f32 ? 1.5f0u"nm" : 1.5u"nm",
        )
        pairwise_inters = (LennardJones(use_neighbors=true, cutoff=cutoff),)
    end
    coords = CuArray(f32 ? starting_coords_f32 : starting_coords)
    velocities = CuArray(f32 ? starting_velocities_f32 : starting_velocities)
    atoms = CuArray([Atom(charge=f32 ? 0.0f0 : 0.0, mass=atom_mass, σ=f32 ? 0.2f0u"nm" : 0.2u"nm",
                          ϵ=f32 ? 0.2f0u"kJ * mol^-1" : 0.2u"kJ * mol^-1") for i in 1:n_atoms])
    sys = System(
        atoms=atoms,
        coords=coords,
        boundary=boundary,
        velocities=velocities,
        pairwise_inters=pairwise_inters,
        neighbor_finder=neighbor_finder,
    )
    return sys, simulator
end
runs = [
    ("GPU NONL"    , [false, false]),
    ("GPU NONL f32", [false, true ]),
    ("GPU NL"      , [true , false]),
    ("GPU NL f32"  , [true , true ]),
]
for (name, args) in runs
    println("*************** Run: $name ************************")
    sys, sim = setup_sim(args...)
    n_atoms = length(sys)
    neighbors = find_neighbors(sys)
    nbrs = isnothing(neighbors) ? n_atoms * (n_atoms - 1) ÷ 2 : length(neighbors)
    println("> Total Pairs = $nbrs")
    f = forces(sys, neighbors)
    b = @benchmark CUDA.@sync forces($sys, $neighbors)
    display(b)
end
I will also try to set up Nsight Compute, but if you could take a quick look at the profiles for the neighbor list kernel as well, that would be great.
Here's the GPU NL f32 output. A lot more kernels were generated than I expected. The first 3 were generated by setup_sim, find_neighbors launched the next ~100 (not sure why so many), and the last 3 kernels were generated by forces (starting with "Z25pairwise_force_kernelnl"). I did run all of the kernels, but the output for this one alone is already a lot to go through lol. If you want them I can send them, but in general they look to have similar issues to this kernel.
If you unzip the file below and double-click on one of the kernels, it will dump all sorts of information and things that could be optimized. I looked quickly at the pairwise_force_kernel and it's only using maybe 30% of my GPU's resources. It looks like there are a lot of uncoalesced accesses and a lot of warp divergence.
Thanks a lot @ejmeitz. I will go through these and try to see if fixing these problems can improve things.
So... out of all the different approaches for the NL kernel, everything has almost the same performance, even the atomic approach. I think this suggests that the computation makes little difference and the performance is almost entirely determined by memory reads and writes. The profiles also show the same thing. I am not really sure how to deal with this memory bottleneck.
Even using shared memory to store the atoms during the computation does not deal with this issue. I had a look at the generated PTX code, which shows that many intermediate results are stored in local memory instead of in registers, and using shared memory does not get rid of the majority of memory stalls. Maybe there is some way to deal with this, but I will have to look into it more.
For now, it would be great if at least the NONL kernel could be merged, and I will work on the NL kernel from scratch in a separate PR with my newly gained experience :) According to the benchmarks for different system sizes, the NONL kernel is 20-30 times faster than the atomic approach, so that's quite an improvement. Any further improvement faces the same issue: memory reads and writes are so slow that the computation time is negligible in comparison.
Sounds good, I can try and review this PR next week. You will need to rebase/merge. The no NL kernel is definitely an improvement. One thing that would be nice to add is a comment by the kernel with a description/diagram of the approach used. I don't know how the kernel will play with Enzyme, but I can look into any problems there myself.
Well done for doggedly pursuing the NL kernel, I know how frustrating it can be. Something to consider is starting from the other direction, completely ablating everything, just loop over the neighbouring pairs and return zero forces. I guess that should be fast? Then see what minimal addition causes the slowdown, make a self-contained example and discuss it here, on Slack or on the CUDA.jl issues.
Thanks @jgreener64, I will rebase the incoming branch. Also, should I make similar changes to the pairwise potential energy kernel?
If it's easy to change the potential energy kernel then do. If it will take up anything more than a short time then focus on the NL kernel and I will implement the potential energy kernel later as a way to learn what is going on.
I am not sure why the CI runs are failing. Everything works for me locally.
This looks good bar the failure when CUDA is not available. That will be due to CUDA.attribute(device(), CUDA.DEVICE_ATTRIBUTE_MAX_THREADS_PER_BLOCK) or maybe the CUDA.shfl_recurse lines. You could enclose the offending code in if CUDA.functional().
Do the GPU tests pass for you? I can run them locally tomorrow too. Long term we should get set up on the Julia GPU CI infrastructure to test this kind of thing.
Yes, the problem seems to lie with CUDA.attribute(device(), CUDA.DEVICE_ATTRIBUTE_MAX_THREADS_PER_BLOCK). I am not using this anyway, so I have commented it out.
The GPU tests also pass for me locally.
pairwise_force_kernel! needs to be renamed to pairwise_force_kernel_nl! on line 313 of src/chain_rules.jl; the GPU tests fail otherwise.
I am also seeing an occasional failure on the Monte Carlo anisotropic barostat GPU tests but I'm not sure it is due to this change, do you see that?
so I have commented it out
May as well remove these lines if they are not needed.
The barostat tests fail intermittently on my local machine for CPU. I'm not sure that is related to the changes made in this PR.
Benchmark results before any changes
Job Properties
JULIA_NUM_THREADS => 16
Results
Below is a table of this job's results, obtained by running the benchmarks. The values listed in the ID column have the structure [parent_group, child_group, ..., key], and can be used to index into the BaseBenchmarks suite to retrieve the corresponding benchmarks. The percentages accompanying time and memory values in the below table are noise tolerances. The "true" time/memory value for a given benchmark is expected to fall within this percentage of the reported value. An empty cell means that the value was zero.
["interactions", "Coulomb energy"]
["interactions", "Coulomb force"]
["interactions", "HarmonicBond energy"]
["interactions", "HarmonicBond force"]
["interactions", "LennardJones energy"]
["interactions", "LennardJones force"]
["protein", "CPU parallel NL"]
["simulation", "CPU NL"]
["simulation", "CPU f32 NL"]
["simulation", "CPU f32"]
["simulation", "CPU parallel NL"]
["simulation", "CPU parallel f32 NL"]
["simulation", "CPU parallel f32"]
["simulation", "CPU parallel"]
["simulation", "CPU"]
["simulation", "GPU NL"]
["simulation", "GPU f32 NL"]
["simulation", "GPU f32"]
["simulation", "GPU"]
["spatial", "vector"]
["spatial", "vector_1D"]
Benchmark Group List
Here's a list of all the benchmark groups executed by this job:
["interactions"]
["protein"]
["simulation"]
["spatial"]
Julia versioninfo