ECP-WarpX / WarpX

WarpX is an advanced electromagnetic & electrostatic Particle-In-Cell code.
https://ecp-warpx.github.io

Poor scaling to multiple GPUs with electrostatic solver #5036

Open archermarx opened 1 month ago

archermarx commented 1 month ago

EDIT: see the more recent results (with profiling) further down in this thread.

Hi all,

I'm trying to scale up an electrostatic simulation to multiple GPUs and am getting very poor results. To diagnose things, I ran uniform plasma simulations using the attached PICMI file, with the only difference between runs being the workload and the solver choice.

The results are shown in the figure below. Each line represents a different workload, and the y-axis is the speedup: the time per step on 1 GPU divided by the time per step on the number of GPUs shown on the x-axis. All simulations use the electromagnetic solver, except for the line labelled "ES".

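[Figure: speedup over 1 GPU vs. number of GPUs, one line per workload; the "ES" line uses the electrostatic solver]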

As I increase the workload when using the electromagnetic solver, the simulation scaling approaches the ideal scaling, which is great. However, the same is not true for the electrostatic (ES) solver, which does not see a speedup at all even at the largest problem sizes.
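For reference, the y-axis quantity is the usual strong-scaling speedup (fixed problem size, more GPUs), and dividing it by the GPU count gives the parallel efficiency; ideal scaling means speedup N and efficiency 1. A minimal sketch of how both can be computed from measured per-step times (`scaling_metrics` is a hypothetical helper, not part of WarpX):

```python
def scaling_metrics(time_per_step):
    """Strong-scaling metrics from per-step wall times.

    time_per_step maps a GPU count to the measured seconds per step for one
    fixed problem size; it must contain the 1-GPU baseline.
    Returns {n_gpus: (speedup, efficiency)} with
    speedup = t(1) / t(N) and efficiency = speedup / N.
    """
    t1 = time_per_step[1]
    metrics = {}
    for n, t in sorted(time_per_step.items()):
        speedup = t1 / t
        metrics[n] = (speedup, speedup / n)
    return metrics
```

Plotting efficiency instead of speedup makes the gap to ideal scaling easier to compare across GPU counts, since the ideal line becomes flat at 1.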


Some more details

The base workload ("Work = 1") is 2 particles per cell in each dimension, with 32 x 32 x 32 cells. Work = 4096 corresponds to 8 ppc in each dimension with 128 x 128 x 128 cells, which was the largest simulation I could fit in memory on a single GPU.
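Reading "Work" as the total number of macroparticles relative to the base case reproduces these numbers: (8 × 128)^3 / (2 × 32)^3 = 16^3 = 4096. A small sketch of that arithmetic, assuming a cubic domain with the same ppc in every direction (the function names are illustrative only):

```python
def total_macroparticles(ppc_per_dim, cells_per_dim, ndim=3):
    """Macroparticles in a cubic domain with uniform ppc in each direction."""
    return (ppc_per_dim * cells_per_dim) ** ndim


def work(ppc_per_dim, cells_per_dim, base_ppc=2, base_cells=32, ndim=3):
    """Workload relative to the base case (2 ppc per dim, 32^3 cells)."""
    return (total_macroparticles(ppc_per_dim, cells_per_dim, ndim)
            // total_macroparticles(base_ppc, base_cells, ndim))


print(work(8, 128), total_macroparticles(8, 128))
```

Note that only a factor of 64 of this comes from the grid (32^3 to 128^3 cells); the other 64 comes from particles per cell, so the grid-based work, including the Poisson solve, grows much less than the particle work between the two extremes.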

I'm running WarpX on GPU with single precision particles. The nodes I'm using have 8x NVIDIA H100 GPUs, so all of these computations are on a single node. I also tested adding a second node (16 GPUs total), but the results were equally poor.


The PICMI input file is here:

from pywarpx import picmi
constants = picmi.constants

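# Domain decomposition: the product of numprocs must equal the number of
# MPI ranks (4*3*2 = 24 as written)
numprocs = [4,3,2]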

##########################
# physics parameters
##########################

plasma_density = 1.e25
plasma_thermal_velocity = 0.01*constants.c

##########################
# numerics parameters
##########################

# --- Number of time steps
max_steps = 1000
diagnostic_interval = max_steps

# --- Grid
N = 128
nx = N
ny = N
nz = N

xmin = -20e-6
xmax = +20e-6
ymin = -20e-6
ymax = +20e-6
zmin = -20e-6
zmax = +20e-6

particles_per_cell_each_dim = [8, 8, 8]

##########################
# physics components
##########################

uniform_plasma = picmi.UniformDistribution(density = plasma_density,
                                           lower_bound = [xmin, ymin, zmin],
                                           upper_bound = [xmax, ymax, zmax],
                                           rms_velocity = [plasma_thermal_velocity, plasma_thermal_velocity, plasma_thermal_velocity]
                                           )

electrons = picmi.Species(particle_type='electron', name='electrons', initial_distribution=uniform_plasma)

##########################
# numerics components
##########################

grid = picmi.Cartesian3DGrid(
    number_of_cells = [nx, ny, nz],
    lower_bound = [xmin, ymin, zmin],
    upper_bound = [xmax, ymax, zmax],
    lower_boundary_conditions = ['periodic', 'periodic', 'periodic'],
    upper_boundary_conditions = ['periodic', 'periodic', 'periodic'],
    moving_window_velocity = [0., 0., 0.]
)

# Swap the two solver lines below to switch between the EM and ES runs
#solver = picmi.ElectromagneticSolver(grid=grid, cfl=1.)
solver = picmi.ElectrostaticSolver(grid=grid, required_precision=1e-5)

##########################
# diagnostics
##########################

field_diag1 = picmi.FieldDiagnostic(name = 'diag1',
                                    grid = grid,
                                    period = diagnostic_interval,
                                    data_list = ['Ex', 'Jx'],
                                    write_dir = '.',
                                    warpx_file_prefix = 'Python_Uniform_plt')

##########################
# simulation setup
##########################

sim = picmi.Simulation(solver = solver,
    max_steps = max_steps,
    verbose = 1,
    warpx_current_deposition_algo = 'direct',
    warpx_numprocs = numprocs,
    warpx_amrex_use_gpu_aware_mpi = True,
    time_step_size = 5e-12,
)

sim.add_species(electrons,
                layout = picmi.GriddedLayout(n_macroparticle_per_cell=particles_per_cell_each_dim, grid=grid))

sim.add_diagnostic(field_diag1)

sim.step()
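As far as I know, WarpX requires that, when `warpx.numprocs` is given, the product of its entries equals the number of MPI processes, so a hard-coded `[4,3,2]` only matches a 24-rank run. A minimal sketch that derives `numprocs` from the rank count instead, assuming the script is launched by `srun` so that `SLURM_NTASKS` is set (`balanced_decomposition` is a hypothetical helper, not part of pywarpx). Alternatively, leaving `warpx_numprocs` unset lets WarpX choose the decomposition itself.

```python
import os


def balanced_decomposition(nranks, ndim=3):
    """Split nranks into ndim factors whose product is exactly nranks.

    Greedy heuristic: hand the largest prime factors to whichever factor is
    currently smallest, which gives a roughly balanced decomposition.
    """
    # Prime factorization by trial division
    primes = []
    n, p = nranks, 2
    while p * p <= n:
        while n % p == 0:
            primes.append(p)
            n //= p
        p += 1
    if n > 1:
        primes.append(n)
    # Assign primes, largest first, to the smallest current factor
    dims = [1] * ndim
    for prime in sorted(primes, reverse=True):
        i = dims.index(min(dims))
        dims[i] *= prime
    return sorted(dims, reverse=True)


# SLURM_NTASKS is assumed to be set inside the job; default to 1 for local runs
nranks = int(os.environ.get("SLURM_NTASKS", "1"))
numprocs = balanced_decomposition(nranks)
```

For 24 ranks this yields `[4, 3, 2]`, and for 8 ranks `[2, 2, 2]`. Since the grid here is a cube, it should not matter which axis gets the largest factor.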

My job script is below. I just change ntasks-per-node to set the number of GPUs.

#!/bin/bash
#SBATCH --account=#####
#SBATCH --partition=#####
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=8
#SBATCH --cpus-per-task=8
#SBATCH --gpus-per-task=h100:1
#SBATCH --gpu-bind=single:1
#SBATCH --time=0-01:00:00
#SBATCH --mem=80g

# Load required modules
source ~/warpx.profile
source ${HOME}/sw/lighthouse/h100/venvs/warpx-h100/bin/activate

# Executable and input file (or python and picmi script)
EXE=python3
INPUTS=picmi.py
ARGS=""

# CPU setup
export SRUN_CPUS_PER_TASK=8
export OMP_NUM_THREADS=${SRUN_CPUS_PER_TASK}

# Run simulation
srun --cpu-bind=cores ${EXE} ${INPUTS} ${ARGS}

RemiLehe commented 1 month ago

Thanks for reporting this.

I think that the fact that the ES solver does not scale as well as the EM solver is indeed expected. The ES solver does require more MPI communications than the EM solver, and your observations are in line with what other WarpX users have seen when trying to scale the ES solver with multiple GPUs.

Nevertheless, it might still be possible to find ways to improve the scaling. One thing you could try is to set amrex.use_gpu_aware_mpi=1, as this could potentially speed up the GPU-to-GPU MPI communications. Ah, I just saw that you are already using this.

Additionally, it could be helpful if you can post the TPROF output (at the end of the WarpX simulation) for e.g. the two-GPU simulation, just to confirm that most of the time is being spent in the Poisson solver. If you have the time, it could also be interesting to use the NVIDIA profiler to check where the code is spending most of its time.

I also know that @pmessmer is interested in speeding up the ES solver in WarpX; maybe he'd have some suggestions.

RemiLehe commented 1 month ago

Btw, @archermarx, when attempting to run the Python script that you posted (but with numprocs = [1,1,1]), I get:

MLMG: Iteration 197 Fine resid/bnorm = 0.9891505021
MLMG: Iteration 198 Fine resid/bnorm = 0.9891505021
MLMG: Iteration 199 Fine resid/bnorm = 0.9891505021
MLMG: Iteration 200 Fine resid/bnorm = 0.9891505021
MLMG: Failed to converge after 200 iterations. resid, resid/bnorm = 287945.6678, 0.9891505021
amrex::Abort::0::MLMG failed. !!!

at the first iteration.

Is that your case too? Or am I missing something (e.g. are you compiling a modified/older version of WarpX? or are you using non-default compiler flags?)

archermarx commented 1 month ago

Hi Remi,

No, on one process this runs to completion on my end. My build options are listed below. The only non-default setting I'm using (I think) is single-precision particles. I'm running WarpX v24.07.

# Build warpx
cmake -S . -B build \
        -DWarpX_LIB=ON \
        -DWarpX_APP=OFF \
        -DWarpX_MPI=ON \
        -DWarpX_COMPUTE=CUDA \
        -DWarpX_DIMS="1;2;3" \
        -DWarpX_PYTHON=ON \
        -DWarpX_PRECISION=DOUBLE \
        -DWarpX_PARTICLE_PRECISION=SINGLE

cmake --build build --target pip_install -j 20

archermarx commented 1 month ago

EDIT: issue resolved

archermarx commented 1 month ago

After resolving some issues, I have more realistic scaling results. Not nearly as bad as before, but still suboptimal. First, I show the speedup over 1 GPU for different workloads on 1, 2, 4, and 8 GPUs:

[Figure: speedup over 1 GPU vs. number of GPUs (1, 2, 4, 8), one line per workload]

Next, I show how the speedup grows as a function of workload:

[Figure: speedup over 1 GPU vs. workload]

TinyProf insights

I've attached the TinyProfiler output for 1 GPU and 8 GPUs to this issue. Here are some of the main insights:

Name                                                  NCalls  Excl. Min  Excl. Avg  Excl. Max   Max %
FillBoundary_nowait()                                 392555      8.867      10.11      14.23  17.20%
FabArray::ParallelCopy_finish()                        31000      1.413      10.11      11.51  13.92%
FillBoundary_finish()                                 392555      9.539      10.58      11.33  13.69%

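This is a huge fraction: together, these three ghost-cell exchange and parallel-copy regions account for roughly 45% of the run time (summing the Max % column). Any idea how to speed this up?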
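To compare the 1-GPU and 8-GPU reports quickly, the exclusive-time rows can be parsed and the communication share summed. A minimal sketch, assuming rows follow the column layout shown above (name, NCalls, Min, Avg, Max, Max %); feed it only the exclusive-time section, because the inclusive table uses the same layout and would overwrite entries with the same name. The helper names are hypothetical:

```python
import re

# One table row: name, NCalls, Min, Avg, Max, Max %.
# The name group is non-greedy so names containing spaces still parse.
ROW = re.compile(
    r"^(.+?)\s+(\d+)\s+([\d.]+)\s+([\d.]+)\s+([\d.]+)\s+([\d.]+)%\s*$"
)


def parse_tinyprof(text):
    """Parse rows of a TinyProfiler exclusive-time table.

    Header and separator lines are skipped because they do not match ROW.
    """
    rows = {}
    for line in text.splitlines():
        m = ROW.match(line.strip())
        if m:
            name, ncalls, tmin, tavg, tmax, pct = m.groups()
            rows[name] = {
                "ncalls": int(ncalls),
                "min": float(tmin),
                "avg": float(tavg),
                "max": float(tmax),
                "max_pct": float(pct),
            }
    return rows


def share(rows, prefixes):
    """Sum the Max % column over regions whose name starts with a prefix."""
    return sum(r["max_pct"] for name, r in rows.items()
               if name.startswith(tuple(prefixes)))


excerpt = """
FillBoundary_nowait()                                 392555      8.867      10.11      14.23  17.20%
FabArray::ParallelCopy_finish()                        31000      1.413      10.11      11.51  13.92%
FillBoundary_finish()                                 392555      9.539      10.58      11.33  13.69%
"""

rows = parse_tinyprof(excerpt)
comm = share(rows, ["FillBoundary", "FabArray::ParallelCopy"])
print(f"communication share: {comm:.2f}%")
```

Running the same summary on both attached reports shows how much of the added time at 8 GPUs is communication rather than computation.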


Attachments: tinyprof_1gpu.txt, tinyprof_8gpu.txt, picmi.txt, warpx_inputs.txt