Anton-Le / PhysicsBasedBayesianInference

Implementation of ensemble-based HMC for multiple architectures
MIT License

Results for Final Report #114

Open ThomasWarford opened 2 years ago

ThomasWarford commented 2 years ago

Hi Anton and Bruno, our YouTube video is due on 28/08 (this Sunday) and our final report is due on 31/08. As a result, I think we should run some benchmarks on GPU if possible to showcase the speedup vs. the number of GPUs/CPUs. I have created an animation for HMC which might be good in an introductory slide.

I am not entirely sure yet what plots we will use to demonstrate what we have done over the past few weeks. Any thoughts here are appreciated.

Thanks, Thomas

Anton-Le commented 2 years ago

I tried something fun while looking at the bugfix of the gradient:

I have run our implementation of HMC with 1024 particles on the GPU up to a final time of 110.1 with a step size of 0.001 and compared it to NumPyro's own HMC with similar parameters (no step-size adaptation, the same final time and step size fixed, and I think 1500 samples overall) for the coin-toss model.

The result being roughly:

  * Our implementation: 44 s
  * Reference implementation: 4 m 56 s

There are major caveats to this comparison, of course, one being generalizability (mostly guessing initial parameters), but still... I think it's rather impressive.
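For reference, the NumPyro side of such a comparison can be set up roughly as follows. This is only a sketch: the model and data below are hypothetical stand-ins for our coin-toss model, not the exact code used for the timing above.

```python
import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, HMC

# Hypothetical stand-in for the coin-toss model; the actual model in the repo may differ.
def coin_toss(tosses1, tosses2):
    p1 = numpyro.sample("p1", dist.Beta(1.0, 1.0))
    p2 = numpyro.sample("p2", dist.Beta(1.0, 1.0))
    numpyro.sample("obs1", dist.Bernoulli(p1), obs=tosses1)
    numpyro.sample("obs2", dist.Bernoulli(p2), obs=tosses2)

# Made-up data so the snippet runs stand-alone.
key = jax.random.PRNGKey(0)
tosses1 = jax.random.bernoulli(key, 0.5, shape=(100,)).astype(jnp.int32)
tosses2 = jax.random.bernoulli(key, 0.75, shape=(100,)).astype(jnp.int32)

# Fixed step size and trajectory length, no adaptation, to mirror our settings.
kernel = HMC(coin_toss, step_size=0.001, trajectory_length=110.1,
             adapt_step_size=False, adapt_mass_matrix=False)
mcmc = MCMC(kernel, num_warmup=0, num_samples=1500)
mcmc.run(jax.random.PRNGKey(1), tosses1, tosses2)
mcmc.print_summary()
```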

For final benchmarks we'll take the simple model, the harmonic oscillator, the linear acceleration model and possibly something larger (if I can fix that up). We can't tackle the really large models at this stage (at least not without major cheating), since we only have plain HMC without the wrapping in a larger method and no step-size adaptation yet.

To visualise the process we can use the harmonic oscillator potential or the multinormal distribution, as mentioned in the past. To demonstrate the use of a stochastic model we can use the coin-toss example and run HMC for a progressively growing number of time steps (varying the final time) for the same initial ensemble and compute, after HMC, the mean and standard deviation of the parameters. Plotting this as a time series should show a nice convergence towards the known solution (0.5, 0.75).
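A rough sketch of what such a convergence plot could look like. The HMC results are replaced by placeholder random numbers here, purely so the plotting code runs stand-alone; in practice `samples_by_time[t]` would hold the post-HMC ensemble for final time `t`.

```python
import numpy as np
import matplotlib.pyplot as plt

# Placeholder ensembles (n_particles x n_params) per final time; replace with real HMC output.
final_times = np.linspace(10.0, 200.0, 20)
rng = np.random.default_rng(0)
samples_by_time = {t: rng.normal(loc=(0.5, 0.75), scale=0.05, size=(1024, 2))
                   for t in final_times}

means = np.array([samples_by_time[t].mean(axis=0) for t in final_times])
stds = np.array([samples_by_time[t].std(axis=0) for t in final_times])

fig, ax = plt.subplots()
for i, true_val in enumerate((0.5, 0.75)):
    ax.errorbar(final_times, means[:, i], yerr=stds[:, i], label=f"parameter {i + 1}")
    ax.axhline(true_val, linestyle="--")
ax.set_xlabel("final integration time")
ax.set_ylabel("ensemble mean ± std")
ax.legend()
fig.savefig("convergence.png")
```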

Overall I suggest that everybody thinks of good visual examples as well as benchmarks to run and we'll discuss these on Thursday.

ThomasWarford commented 2 years ago

I created this GIF the other day (code is in my `gif` branch): [HMC animation]

That convergence plot sounds good to me

ThomasWarford commented 2 years ago

When profiling on CPU, do you suggest I change the number of threads used via MPI?

ThomasWarford commented 2 years ago

Number of chains running on GPU vs time could be a good plot. Or number of chains running on GPU vs time per chain.

ThomasWarford commented 2 years ago

@Anton-Le I'd be interested to know where to find the COVID-19 model that you talked about. I feel like it would fit well as an example of what HMC is used for and as a demonstration of the need for speedup (since it took so long to run). Any other large and slow models are of interest too.

Thanks

Anton-Le commented 2 years ago

Profiling

When profiling on the CPU, as a first step I suggest profiling the non-MPI version of the code. The reason is that any problem at the node level will also be present at the multi-node level. Furthermore, consider running only a few steps with 2 particles for profiling, and restricting execution to one CPU only (via the XLA flag).

Afterwards, increase the number of CPUs via the XLA flag and increase the number of particles to ensure that you run at least 1 particle per CPU.
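A minimal sketch of setting the flag (with the usual caveat that it has to be set before jax is imported for the first time):

```python
import os

# Expose N logical CPU devices to XLA; must be set before importing jax.
num_devices = 1  # 1 for the first profiling step, then increase
os.environ["XLA_FLAGS"] = f"--xla_force_host_platform_device_count={num_devices}"

import jax  # noqa: E402  (imported after setting the flag on purpose)

print(jax.local_device_count())  # should report num_devices CPU devices
```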

I would also advise using the ULA model for profiling, since it requires more computation as the amount of data is larger.

MPI profiling, while ultimately necessary, will likely only produce a graph that shows the features of an "embarrassingly parallel" problem: blocking communication at the beginning and the end.

Performance plots

In a parallel setting there are two common plots that can be generated:

  1. Strong scaling: Keep the work constant but increase the number of resources (e.g. CPUs, GPUs) and plot time over resources.
  2. Weak scaling: Keep the time constant and scale the work and resources up, then plot work over resources.

Strong scaling is what one commonly expects to see; it answers the question "How long will I have to wait for my task to complete were I to use a given number of CPUs/GPUs?" Unfortunately there is generally a limit to the number of resources that can be used in the strong-scaling case. In our case, once we get to 1 particle per thread that's it - using more resources would just be a waste of energy.

Weak scaling is more common in HPC environments and is generally used to answer the question: "How big of a problem can I tackle given the requested resources and time-constraints of the system?" Weak scaling requires one to determine a problem size (number of steps and number of particles in our case) that utilises one CPU/GPU fully (to avoid measuring overhead) and then scale this up when going to multiple resources.
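As a toy illustration of the two quantities (the numbers below are made up, not measurements):

```python
import numpy as np

resources = np.array([1, 2, 4, 8])               # CPUs or GPUs

# Strong scaling: fixed total work, wall-clock time per resource count.
t_strong = np.array([100.0, 52.0, 27.0, 16.0])   # seconds (made up)
speedup = t_strong[0] / t_strong                  # S(p) = T(1) / T(p)
efficiency = speedup / resources                  # E(p) = S(p) / p

# Weak scaling: work grows proportionally to resources, ideal is constant time.
t_weak = np.array([100.0, 103.0, 108.0, 118.0])  # seconds (made up)
weak_efficiency = t_weak[0] / t_weak

print(speedup, efficiency, weak_efficiency)
```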

@ThomasWarford the first of your suggestions falls into the strong scaling category (increase the workload until the GPU is fully occupied and the influence of the overhead is minimized). Time per chain is more of a weak scaling analysis.

Models

To my knowledge the COVID-19 model was a research model and I am currently not sure whether it has been published. I'll have a look later and provide a link, should it be in the public domain. But my guess is that the particular model I was referring to is being kept under lock and key.

ThomasWarford commented 2 years ago

Profiling

One CPU core, one particle, with `jit` and `vmap` removed for "transparency":

[profiler screenshot]

Trace: 1cpu1particle.json.gz

A lot of time is spent calculating gradients, as expected. A lot falls under `fori_loop` as well, which is slightly confusing as I'd assume `grad` is calculated inside a `fori_loop`.
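In case it helps reproduce the trace, a minimal sketch of capturing one with JAX's built-in profiler (the actual HMC call from the repo would go where `work()` is):

```python
import jax
import jax.numpy as jnp

def work():
    # Stand-in for the HMC run being profiled.
    x = jnp.ones((1000, 1000))
    return (x @ x).block_until_ready()

# Writes a trace under /tmp/jax-trace that can be inspected in TensorBoard.
with jax.profiler.trace("/tmp/jax-trace"):
    work()
```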

ThomasWarford commented 2 years ago

I'm unable to get a time vs. num_cores plot, as `os.environ["XLA_FLAGS"] = f'--xla_force_host_platform_device_count={num_cores}'` doesn't appear to change the number of cores used at the moment. I'd like to try to get around this tomorrow. If I can't get it to work I will make a convergence plot.
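One thing that may be worth checking: as far as I know the flag is only read when jax initialises its CPU backend, so it has no effect if jax (or anything that imports jax, e.g. numpyro) was already imported before the `os.environ` line runs. A quick sanity check:

```python
import jax

# If the flag took effect, this should report the requested number of CPU devices.
print(jax.local_device_count(), jax.devices())
```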

ThomasWarford commented 2 years ago

Working on convergence plots, it looks as if the estimator for c2 is biased: it converges to a value higher than 0.75. I wonder if this is because the mean of the coin-flip posterior is higher than the mode of the posterior. If I set the temperature to a low value the estimator converges to 0.75 (as expected, since this is then effectively gradient descent).

I've just looked at the results predicted by NumPyro and they seem to underestimate parameter 2, whilst our kernel overestimates it. I'm not sure what to make of this.

Anton-Le commented 2 years ago

Here's the dump of the trivial timings I got on my system. run_1_1CPU_ParticleSequence.log run_1_1GPU_ParticleSequence.log run_1_2CPU_ParticleSequence.log run_1_2GPU_ParticleSequence.log run_1_4CPU_ParticleSequence.log run_2_1CPU_ParticleSequence.log run_2_1GPU_ParticleSequence.log run_2_2CPU_ParticleSequence.log run_2_2GPU_ParticleSequence.log run_2_4CPU_ParticleSequence.log run_3_1CPU_ParticleSequence.log run_3_1GPU_ParticleSequence.log run_3_2CPU_ParticleSequence.log run_3_2GPU_ParticleSequence.log

The number of particles scales exponentially from 2 to 2^15. The times you're interested in are the real times (i.e. wall-clock times): you can collect them via `cat <...>.log | grep real`.
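A small sketch for collecting those times and converting them straight to seconds (assuming the logs contain the usual bash `time` output, i.e. lines like `real 4m56.123s`):

```python
import glob
import re

# Matches e.g. "real    4m56.123s" and converts it to seconds.
pattern = re.compile(r"real\s+(\d+)m([\d.]+)s")

for path in sorted(glob.glob("run_*_ParticleSequence.log")):
    with open(path) as f:
        match = pattern.search(f.read())
    if match:
        minutes, seconds = match.groups()
        print(f"{path}: {int(minutes) * 60 + float(seconds):.3f} s")
```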

ThomasWarford commented 2 years ago

Thank you. I'm going to pay attention to the GPU results because the number of cores appears to have no effect on runtime.

ThomasWarford commented 2 years ago

Timing data in a spreadsheet: https://docs.google.com/spreadsheets/d/1hDMbZMP9fm2s0-8E1_ttEjX_vdjHreD8S11XHmLgxxw/edit?usp=sharing

[scaling plot]

ThomasWarford commented 2 years ago

Graph from the latest round of timing:

CPU:

[CPU timing plot]

The behavior for 2^15 and 2^14 particles seems a little unusual; maybe other tasks were running.

GPU Linear Acceleration (JAXAVG):

[GPU timing plot]

I'm not sure why this is so fast compared to the CT model on GPU in the last round of timing. The CT model took ~3 minutes to run with 2^15 particles, but now it takes ~8 seconds.

Note: it would be easier to process the times if they were reported as a number of seconds.

ThomasWarford commented 2 years ago

Spreadsheet with times: timings2.ods

ThomasWarford commented 2 years ago

It's also interesting that 4 CPU cores with MPI appear faster than 4 CPU cores without MPI. Maybe JAX doesn't exploit all the cores as well as I thought it would, or perhaps JAX only used one core because mpirun was used to run the file.

[timing comparison plot]
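A quick way to check the second guess, assuming mpi4py is available in the environment (`os.sched_getaffinity` is Linux-only): if mpirun binds each rank to a single core, JAX inside that rank can only ever use one core.

```python
import os
import jax
from mpi4py import MPI

# Print each rank's CPU affinity and what JAX sees; a single-element affinity set
# would mean the launcher pinned the rank to one core.
rank = MPI.COMM_WORLD.Get_rank()
print(f"rank {rank}: affinity={sorted(os.sched_getaffinity(0))}, "
      f"jax devices={jax.local_device_count()}")
```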

ThomasWarford commented 2 years ago

@brunoroca260894 @Anton-Le Here is the latest speedup chart: [GPU_speedup plot]

Data: GPU_data.ods

Let me know your thoughts. I will be leaving in ~1 h, but I can try to create this graph in Python with error bars before then if desired. I will be quite busy for the next few days, so someone else may have to do this if I'm unable to.

The improved scaling going from 8 to 2^16 particles does show that overhead is contributing a lot.

Anton-Le commented 2 years ago

@ThomasWarford TBH this is not what I was imagining, at least w.r.t. scaling. The plot is full of information, but not in an accessible way. It indicates that the problem scales without a dip for 2^16 particles, but that is not particularly noticeable - mostly because the bars are so low compared to the 8-particle case.

At the very least I would add a red line indicating the "1x" speed-up (the baseline) so that the plot becomes easier to interpret.

Next would be the comparison of the 3 workloads: you use 2^16 particles, 1 GPU as the base case and compute speed-up w.r.t. this - this is similar to what I did in the papers, but there I was comparing timings for different architectures for the same workload. Here you're comparing different workloads for the same architecture, and the plot requires the viewer to compute high powers of 2 to compare the expected and achieved speed-up when varying the problem size. While possible, this is not a good approach for an accessible publication.

Here it'd be better, IMHO, to scale each workload to the 1GPU case and plot the resulting speed-ups together. The 1-GPU group will then be uniformly 1x, but the rest should show a more interesting variation.

Workload vs. speed-up I would put in a separate plot and use either the 1GPU timings or the 8-GPU timings for it (using the 2^16 time as base). This should yield a plot that immediately shows the effect of overhead/too small a problem.
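Schematically, the re-normalisation amounts to something like this (the numbers below are placeholders, not the measured values from the spreadsheet):

```python
# times[workload][n_gpus]: measured wall-clock times (placeholder values).
times = {
    "8 particles":    {1: 1.0, 2: 0.9, 4: 0.9, 8: 1.0},
    "2^16 particles": {1: 80.0, 2: 42.0, 4: 23.0, 8: 13.0},
}

for workload, t in times.items():
    # Speed-up of each GPU count relative to the 1-GPU run of the *same* workload.
    speedup = {g: t[1] / t[g] for g in sorted(t)}
    print(workload, {g: f"{s:.2f}x" for g, s in speedup.items()})
```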

ThomasWarford commented 2 years ago

OK, I'll rescale this so each workload has a speedup of 1 for 1 GPU.

Then I'll make another plot showing time vs. number of particles on 1 GPU - the message here is that overhead is high / the problem is too small. I can see the argument for doing this plot with 8 GPUs, but I'll do it with 1, as this shows that all speedup plots should be taken with a pinch of salt (due to high overhead relative to useful computation).

Here are the plots: [gpu_speedup] [workload_speedup]