Anton-Le / PhysicsBasedBayesianInference

Implementation of ensemble-based HMC for multiple architectures
MIT License

Parallelise to multiple ensembles via MPI #89

Open Anton-Le opened 2 years ago

Anton-Le commented 2 years ago

We now have, in essence, a parallelization that can run on the GPU.

The next goal is to parallelise to multiple GPUs and/or multiple CPUs. Since, as @ThomasWarford has pointed out, pmap does not parallelise to multiple GPUs, we should be able to implement multi-GPU usage by adding an MPI "wrapping" in the following sense:

We can associate each MPI process with a GPU at runtime (SLURM, for instance, allows one to set CUDA_VISIBLE_DEVICES depending on the job configuration). Hence, given a number of particles we can distribute them across all MPI ranks and create an ensemble for each rank, which is then processed as it is now (vmap). See also #63 .

This way we will have the functionality necessary for the embedding of the microcanonical ensemble (HMC) into the canonical one.
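A minimal sketch of the particle distribution described above, assuming the particles are split as evenly as possible and any remainder goes to the lowest ranks. The helper name `local_particle_count` is hypothetical and not part of the repository; in an actual run `rank` and `size` would come from `comm.Get_rank()` and `comm.Get_size()`.

```python
def local_particle_count(total_particles: int, rank: int, size: int) -> int:
    """Number of particles handled by `rank` when `total_particles` are
    split as evenly as possible across `size` MPI ranks (hypothetical helper)."""
    if size < 1 or not 0 <= rank < size:
        raise ValueError("rank must satisfy 0 <= rank < size")
    base, remainder = divmod(total_particles, size)
    # The first `remainder` ranks take one extra particle each.
    return base + (1 if rank < remainder else 0)


# Example: 8192 particles on 2 ranks, and an uneven split of 10 on 4 ranks.
even_split = [local_particle_count(8192, r, 2) for r in range(2)]
uneven_split = [local_particle_count(10, r, 4) for r in range(4)]
```

Each rank would then build its own sub-ensemble of that size and process it with vmap as before.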

Anton-Le commented 2 years ago

As promised, here are the functions necessary to add the MPI layer around the current HMC implementation. The case is "embarrassingly" parallel, as we (for the moment) only have to distribute initial data to the workers and then let each worker initialize a sub-ensemble and run HMC on its part. The local estimates are then reduced to the root process (MPI rank 0) and printed/stored to file.

Initializing MPI

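# Runtime options must be set before MPI is imported, since the import initialises MPI.
from mpi4py import rc
rc.threads = True  # this option was called rc.threaded in old mpi4py releases
rc.thread_level = "funneled"

from mpi4py import MPI

comm = MPI.COMM_WORLD
rank = comm.Get_rank()
size = comm.Get_size()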

Distributing data

Starting from the inputs read by MPI rank 0, we need to broadcast the initialization data via result, _ = mpi4jax.bcast(data, 0, comm=comm). On root (process 0), result is simply a copy of data, since the broadcast generates a new array there. Prior to the broadcast we must, of course, ensure that only process 0 reads the inputs. We ignore the second output of the routine, the XLA token, as we do not need it.

Collection

For the reduction to one value available on all MPI processes: _sum, _ = mpi4jax.allreduce(_sum, op=MPI.SUM, comm=comm)

The simple accumulation of local means on the root process is achieved via: accArray, _ = mpi4jax.reduce(data, MPI.SUM, 0, comm=comm)

ThomasWarford commented 2 years ago

Looks good to me

ThomasWarford commented 2 years ago

How would this be implemented: should MPI be implemented in HMC.py or main.py?

What arguments are broadcast? An Integrator instance and an array of masses?

Each process will need a different PRNGKey for the simulation to get different results; some thought needs to be put into this.

I will try to figure out the answers to these questions over the next few days, but any advice is appreciated.

Thanks.

Anton-Le commented 2 years ago

I'll grant you time to think about the points until tomorrow and will post the "solution" as we have used it in C++ code.

ThomasWarford commented 2 years ago

Edit: I fixed this by setting export MV2_ENABLE_AFFINITY=0.

When running the following command, mpiexec -n 2 /Dev/MPI/local.node python3 MPI_demo.py, I get this error:

Hello, World! I am process 0 of 2 on tom-MS-7A34.
Hello, World! I am process 1 of 2 on tom-MS-7A34.
Error in system call pthread_mutex_destroy: Device or resource busy
    src/mpi/init/initthread.c:242
Error in system call pthread_mutex_destroy: Device or resource busy
    src/mpi/init/initthread.c:242

Do you have any idea what could be the cause? I get a similar error on my laptop. Here is the python program for reference:

from mpi4py import MPI
import sys

size = MPI.COMM_WORLD.Get_size()
rank = MPI.COMM_WORLD.Get_rank()
name = MPI.Get_processor_name()

sys.stdout.write(
    "Hello, World! I am process %d of %d on %s.\n"
    % (rank, size, name))

Thanks, Thomas

Anton-Le commented 2 years ago

I cannot reproduce your error on my machine. Storing your script in test_mpi.py, I can run mpiexec -n 4 -f ~/local.node python3 test_mpi.py to obtain:

Hello, World! I am process 1 of 4 on BPS-A.
Hello, World! I am process 3 of 4 on BPS-A.
Hello, World! I am process 0 of 4 on BPS-A.
Hello, World! I am process 2 of 4 on BPS-A.

ThomasWarford commented 2 years ago

Sorry, to clarify: I did get this to work; I had to change an environment variable.

Anton-Le commented 2 years ago

Interesting. Then your compilation appears to have MPI affinity baked in. Mine doesn't. I get the same result whether I run with or without the MV2_ENABLE_AFFINITY setting.

ThomasWarford commented 2 years ago

I've made some progress with this in feature-MPI, but it's not working just yet. I think I will focus on https://github.com/Anton-Le/PhysicsBasedBayesianInference/issues/98 before coming back to this.

Anton-Le commented 2 years ago

That is perfectly fine. I will have a look and try to run this, too, but the primary goal is now for me to rebase dev on main (to include the tests) and then rebase the feature-jax branch on the new dev. This way we should be able to run the models in parallel on the GPU - which will already be great and require further profiling...

ThomasWarford commented 2 years ago

I see, please note that there was a bug I was unable to fix before committing regarding mpi4jax.scatter.

I was using it to scatter keys but each rank received an array of different shape, which is not what I expected. I'm sure I'll figure it out once the changes to HMC are complete.

Anton-Le commented 2 years ago

MPI parallelism

Pseudo-random Number Generators

You are correct in recognising that the PRNG needs to be properly initialised in a parallel set-up. An improper initialisation and/or too short a period of the PRNG may result in overlapping random number sequences and thus a loss of entropy - which is to be avoided.

In C++ I generally use a proper parallel PRNG from the SPRNG suite or TRNG. A rather common approach, which is not guaranteed to provide non-overlapping random number sequences but does so for all intents and purposes, is to use the Mersenne Twister engine and initialize it with a different seed on each MPI rank and thread.

Hence, in our case, a first step towards ensuring independent PRN sequences would be to provide an integer seed for the PRNG on the rank-0 process (hard-code it, let the user provide it via arguments, or use a timestamp) and to broadcast that seed to all processes. Each MPI rank then modifies the seed in a deterministic fashion, i.e. seed += mpi_rank, and the resulting seed is used to initialize the PRNG via jax.random.PRNGKey(seed).
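A minimal sketch of the seed-offset scheme, using Python's standard-library `random.Random` (a Mersenne Twister) as a stand-in for the JAX PRNG so that it runs without MPI or JAX. The names `rank_seed`, `base_seed` and `size` are hypothetical; in the actual code each process would pass its own seed to `jax.random.PRNGKey`.

```python
import random


def rank_seed(base_seed: int, rank: int) -> int:
    """Deterministic per-rank seed: the broadcast base seed offset by the rank."""
    return base_seed + rank


base_seed = 1234  # assumed to be chosen on rank 0 and broadcast to all ranks
size = 4          # simulated number of MPI ranks

# One separately seeded generator per simulated rank, standing in for
# jax.random.PRNGKey(rank_seed(base_seed, rank)) on each process.
generators = [random.Random(rank_seed(base_seed, r)) for r in range(size)]
first_draws = [g.random() for g in generators]

# Re-creating a generator from the same seed reproduces its stream.
replay = random.Random(rank_seed(base_seed, 0)).random()
```

As noted above, distinct seeds make overlapping streams unlikely in practice but do not guarantee non-overlap.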

User-provided information

Using main.py from feature-AppliedHMC as the starting point we may fix the platform and model globally s.t. each process sets them correctly. Next we need to set up the modelDataDictionary s.t. statModel can be created. To this end the main process needs to read the data from the JSON file, distribute the keys to all processes (s.t. each process can initialize its dict) and then, for each key, distribute the data array (a better way is to set up the dictionary on the rank-0 process and then distribute a copy to each process).

Next the user needs to provide the temperature, number of particles, number of iteration steps and the step size (for now, until the automatic stepsize determination is implemented). Since we are initializing a thermal cloud at a certain (arbitrary) location we also need the ability to provide the centre of this cloud in position space (i.e., the initial guess for the coordinates). All of the above need to be distributed from the rank-0 process (which will ultimately fetch the info from command line arguments or a file) to all other processes.

Since the information needs to be distributed from rank-0 to all others, all of the above are broadcasts.
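A minimal sketch of such a broadcast for the run parameters, assuming they are gathered in a plain dictionary on rank 0. The function `broadcast_config` and the dictionary keys are hypothetical; `comm` is assumed to behave like an mpi4py communicator such as `MPI.COMM_WORLD`, whose lowercase `bcast` method pickles generic Python objects.

```python
def broadcast_config(comm, config=None, root=0):
    """Send the run configuration from `root` to every rank.

    Only the root needs to pass `config`; the other ranks pass None
    and receive the root's copy.
    """
    payload = config if comm.Get_rank() == root else None
    # Lowercase bcast pickles arbitrary Python objects such as dicts.
    return comm.bcast(payload, root=root)


# Hypothetical parameter set that rank 0 would read from the command line or a file.
example_config = {
    "temperature": 1.0,
    "num_particles": 512,
    "num_steps": 100,
    "step_size": 0.001,
    "initial_position": [0.5, 0.5],
}

# Usage on a real run (requires mpi4py and an MPI launcher):
#   from mpi4py import MPI
#   comm = MPI.COMM_WORLD
#   config = broadcast_config(comm, example_config if comm.Get_rank() == 0 else None)
```

Large numerical arrays, such as the model data, are better sent with the array-based mpi4jax.bcast shown earlier in this thread rather than pickled.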

Results

After the broadcasts are done we initialize the statistical model, local ensemble and HMC. Prior to running HMC we compute the local weighted mean of the positions and collect the weighted means onto rank-0 (along with the partition function value Z). On the master process we then compute the weighted mean of means, convert it from the unconstrained domain to constrained domain and print it.

This procedure is repeated after running HMC, too. Hopefully the estimated parameters are then closer to the true values than before.
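A minimal sketch of the weighted mean of means, under the assumption that each rank's partition function Z is the sum of its particle weights, so that weighting each local mean by its Z reproduces the global weighted mean. The function names are hypothetical, the positions are scalars for brevity, and the ranks are simulated as plain lists; in the MPI setting the sums over ranks would be done with the reductions shown earlier.

```python
def local_weighted_mean(positions, weights):
    """Return (weighted mean, Z) for one rank, with Z the sum of the weights."""
    z = sum(weights)
    mean = sum(w * x for w, x in zip(weights, positions)) / z
    return mean, z


def combine_means(local_results):
    """Weighted mean of means: each local mean is weighted by its local Z."""
    z_total = sum(z for _, z in local_results)
    return sum(mean * z for mean, z in local_results) / z_total


# Two simulated ranks holding scalar positions and made-up particle weights.
rank0 = local_weighted_mean([0.2, 0.4], [1.0, 3.0])
rank1 = local_weighted_mean([0.6, 0.8, 1.0], [2.0, 1.0, 1.0])
global_mean = combine_means([rank0, rank1])

# Reference value: weighted mean over all particles at once.
direct_mean, _ = local_weighted_mean(
    [0.2, 0.4, 0.6, 0.8, 1.0], [1.0, 3.0, 2.0, 1.0, 1.0]
)
```

The combined value agrees with the direct one up to floating-point rounding, which is why only the local mean and Z need to be sent to rank 0.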

ThomasWarford commented 2 years ago

I believe I have implemented this (although I'm sure there are changes that should be made): https://github.com/thomaswarford/PhysicsBasedBayesianInference/tree/feature-MPI

ThomasWarford commented 2 years ago

I pulled this into this repo's feature-MPI branch

Anton-Le commented 2 years ago

Alright. I'll have a look at this.

Anton-Le commented 2 years ago

After creating a wrapper script I can run in parallel on both GPUs in my system. Now to check the results and run this for multiple variations of the particle numbers.

Running on 2 GPUs until a final time of 0.1 with step size 0.001 and 512 particles:

Bias of coin 1:  0.5039773044386678
Absolute error:  0.003977304438667839
Relative error:  0.007954608877335678
Bias of coin 2:  0.7424705212244294
Absolute error:  0.007529478775570642
Relative error:  0.01003930503409419

run time around 6.8 [s].

Same but with 8192 particles:

Bias of coin 1:  0.4970526991453031
Absolute error:  0.0029473008546969104
Relative error:  0.005894601709393821
Bias of coin 2:  0.7469841989311692
Absolute error:  0.0030158010688308146
Relative error:  0.00402106809177442

runtime around 8.2 [s]

Same as above but additionally with a final time of 10.1:

Bias of coin 1:  0.49819534309256736
Absolute error:  0.0018046569074326446
Relative error:  0.0036093138148652892
Bias of coin 2:  0.7519177867051042
Absolute error:  0.0019177867051042297
Relative error:  0.002557048940138973

runtime around 15.2 [s]

We observe convergence to the true values. From case 1 to case 2 the run time grows by a factor of 1.2 whilst the load grows by a factor of 16. From case 2 to case 3 the workload grows by a factor of 101, but the run time grows only by a factor of 1.9.

This is not unexpected and essentially confirms that the model is far too small to utilise the resources properly. Hence what these measurements capture is mostly overhead (i.e., noise).
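The growth factors quoted above can be reproduced from the reported numbers. This is a small sketch that assumes the workload scales linearly with the particle count (case 1 to 2) and, at fixed step size, with the final time (case 2 to 3); the variable names are illustrative only.

```python
# Run times (seconds) and workloads quoted for the three cases above.
runtime = {"case1": 6.8, "case2": 8.2, "case3": 15.2}
particles = {"case1": 512, "case2": 8192}
final_time = {"case2": 0.1, "case3": 10.1}

runtime_growth_1_to_2 = runtime["case2"] / runtime["case1"]      # about 1.2
load_growth_1_to_2 = particles["case2"] / particles["case1"]     # 16
runtime_growth_2_to_3 = runtime["case3"] / runtime["case2"]      # about 1.9
load_growth_2_to_3 = final_time["case3"] / final_time["case2"]   # about 101
```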