mllam / neural-lam

Neural Weather Prediction for Limited Area Modeling

Parallelize parameter weight computation using PyTorch Distributed #22

Closed sadamov closed 3 weeks ago

sadamov commented 1 month ago

Description

This PR introduces parallelization to the create_parameter_weights.py script using PyTorch Distributed. The main changes include:

  1. Added functions get_rank(), get_world_size(), setup(), and cleanup() to initialize and manage the distributed process group (see the setup sketch after this list).

    • get_rank() retrieves the rank of the current process in the distributed group.
    • get_world_size() retrieves the total number of processes in the distributed group.
    • setup() initializes the distributed process group using NCCL (for GPU) or gloo (for CPU) backend.
    • cleanup() destroys the distributed process group.
  2. Modified the main() function to take rank and world_size as arguments and set up the distributed environment.

    • The device is set based on the rank and available GPUs.
    • The dataset is adjusted to ensure its size is divisible by (world_size * batch_size) using the adjust_dataset_size() function.
    • A DistributedSampler is used to partition the dataset among the processes.
  3. Parallelized the computation of means and squared values across the dataset (see the aggregation sketch below).

    • Each process computes the means and squared values for its assigned portion of the dataset.
    • The results are gathered from all processes using dist.all_gather_object().
    • The root process (rank 0) computes the final mean, standard deviation, and flux statistics using the gathered results.
  4. Parallelized the computation of one-step difference means and squared values.

    • Similar to step 3, each process computes the difference means and squared values for its assigned portion of the dataset.
    • The results are gathered from all processes using dist.all_gather_object().
    • The final difference mean and standard deviation are computed using the gathered results.
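
For reference, the helpers in item 1 are essentially thin wrappers around torch.distributed. A minimal sketch of what they typically look like (the MASTER_ADDR/MASTER_PORT defaults are placeholders, and the PR's actual implementation may differ):

```python
import os

import torch
import torch.distributed as dist


def setup(rank, world_size):
    """Initialize the process group: NCCL for GPU runs, Gloo for CPU runs."""
    os.environ.setdefault("MASTER_ADDR", "localhost")  # placeholder rendezvous address
    os.environ.setdefault("MASTER_PORT", "12355")  # placeholder rendezvous port
    backend = "nccl" if torch.cuda.is_available() else "gloo"
    dist.init_process_group(backend, rank=rank, world_size=world_size)


def get_rank():
    """Rank of the current process within the distributed group."""
    return dist.get_rank()


def get_world_size():
    """Total number of processes in the distributed group."""
    return dist.get_world_size()


def cleanup():
    """Destroy the distributed process group."""
    dist.destroy_process_group()
```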

These changes enable the script to leverage multiple processes/GPUs to speed up the computation of parameter weights, means, and standard deviations. The dataset is partitioned among the processes, and the results are gathered and aggregated by the root process.
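
The gather-and-reduce step in items 3 and 4 follows the usual mean-of-means / mean-of-squares recipe. A hedged sketch of that pattern (function and variable names are illustrative, not necessarily the PR's):

```python
import torch
import torch.distributed as dist


def gather_and_reduce(local_means, local_squares):
    """Gather per-process lists of per-sample means and mean squared values,
    then let rank 0 reduce them to a global mean and standard deviation.

    local_means / local_squares: lists of tensors of shape (num_vars,),
    computed over this process's shard of the dataset.
    """
    world_size = dist.get_world_size()
    means_list = [None] * world_size
    squares_list = [None] * world_size
    # all_gather_object collects arbitrary picklable Python objects from all ranks
    dist.all_gather_object(means_list, local_means)
    dist.all_gather_object(squares_list, local_squares)

    if dist.get_rank() == 0:
        all_means = torch.stack([m for sub in means_list for m in sub])
        all_squares = torch.stack([s for sub in squares_list for s in sub])
        mean = torch.mean(all_means, dim=0)
        # Var[x] = E[x^2] - E[x]^2 (samples assumed equally weighted)
        std = torch.sqrt(torch.mean(all_squares, dim=0) - mean**2)
        return mean, std
    return None, None
```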

The script can be launched in distributed mode via Slurm.

Please review the changes and provide any feedback or suggestions.

joeloskarsson commented 1 month ago

Have you done any testing of how long this takes on CPU vs. GPU? I am curious whether the gains from GPU acceleration make up for the time needed to shuffle the data over to it for these computations.

sadamov commented 4 weeks ago

> Have you done any testing of how long this takes on CPU vs. GPU? I am curious whether the gains from GPU acceleration make up for the time needed to shuffle the data over to it for these computations.

I have run tests with my 7 TB (436,524 samples) COSMO training data (zarr-based). As this solution scaled rather well, I could reduce the runtime of the create_parameter_weights.py script from roughly 50 h to 1 h. I don't remember whether CPU vs. GPU was relevant; I should do a proper assessment.

joeloskarsson commented 4 weeks ago

I was thinking that once the data is in a zarr, won't this script be replaced by just a few xarray .mean calls, which are already parallelized? So it seems good to know whether we should put in the effort to do this ourselves on GPU, or just rely on xarray for that later.
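
For comparison, a rough sketch of what such an xarray-based replacement could look like once the data lives in zarr (the path, chunking, and variable/dimension names below are assumptions, not the actual data layout):

```python
import xarray as xr

# Assumed zarr layout: a "state" variable with dims (time, grid_node, variable)
ds = xr.open_zarr("path/to/training_data.zarr", chunks={"time": 100})
state = ds["state"]

# Per-variable mean/std over time and space; dask parallelizes the reductions
mean = state.mean(dim=("time", "grid_node")).compute()
std = state.std(dim=("time", "grid_node")).compute()

# One-step difference statistics
diff = state.diff(dim="time")
diff_mean = diff.mean(dim=("time", "grid_node")).compute()
diff_std = diff.std(dim=("time", "grid_node")).compute()
```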

sadamov commented 4 weeks ago

TLDR

I suggest merging this into main for now, for the ~10x speedup on CPU with Gloo and the newly introduced --distributed flag. GPU has no real benefit, so we can remove that if preferred. The reviewer's suggestions were implemented; feel free to double-check. This should be replaced with an xarray-based solution in the future.

So the smart thing would have been to wait for xarray, but instead I decided to do a little pseudo-scientific study :face_exhaling: The script now works on CPU/GPU, with Slurm or locally. No data is lost any longer thanks to an improved Dataset class that introduces padding (I had the same idea as you, @joeloskarsson). In the following I cover performance (1. Benchmarks) and robustness (2. Robustness).
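
For illustration, padding the dataset so that its length is divisible by world_size * batch_size could look roughly like this (class and attribute names are mine, not the PR's):

```python
from torch.utils.data import Dataset


class PaddedDataset(Dataset):
    """Wrap a dataset and wrap indices around so that its length is
    divisible by world_size * batch_size (no samples are dropped)."""

    def __init__(self, base, world_size, batch_size):
        self.base = base
        chunk = world_size * batch_size
        remainder = len(base) % chunk
        self.padding = 0 if remainder == 0 else chunk - remainder

    def __len__(self):
        return len(self.base) + self.padding

    def __getitem__(self, idx):
        # Indices past the original end wrap around to the start,
        # i.e. a few samples are seen twice instead of any being dropped
        return self.base[idx % len(self.base)]
```

With something like this, the DistributedSampler hands every rank the same number of full batches; the handful of duplicated samples should have a negligible effect on statistics computed over thousands of samples.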

1 Benchmarks

The script was evaluated in three different modes. Note that multi-node is currently not supported (yes, I had to stop myself at some point).

For the benchmarks I arbitrarily iterated 500 times over the meps_example, resulting in a dataset of 2000 samples. I tracked the exact time required to execute the full script, either in Slurm via sacct or in the local terminal.

The results:

| Mode | Runtime |
| --- | --- |
| Local | 55:23 min |
| Slurm-GPU | 5:56 min |
| Slurm-CPU | 5:38 min |

So, as Joel already suspected, it was not worth porting these simple calculations to GPU. We can also remove the NCCL option from the script if preferred.

2 Robustness

The script must produce the same statistics as the old script from main, both in distributed mode and in single-task mode. To verify this, all stats were produced with the current script from main (called old), the new script in single-task mode (no label), and the new script in distributed mode (called distributed). In the following terminal stdout we can see that all stats match within a tolerance of 1e-5 between the new single-task and distributed runs on CPU. That being said, the GPU-NCCL runs sometimes only reach atol=1e-2 accuracy; this could probably be adjusted via the floating-point operation settings, but it is not really worth it for now.

print("diff_std == diff_distributed_std:", torch.allclose(diff_std, diff_distributed_std, atol=1e-5))
...
------------------------------------------------------------------------------
diff_std == diff_distributed_std: True
diff_mean == diff_distributed_mean: True
flux_stats == flux_distributed_stats: True
parameter_mean == parameter_distributed_mean: True
parameter_std == parameter_distributed_std: True
------------------------------------------------------------------------------
diff_std == diff_old_std: True
diff_mean == diff_old_mean: True
flux_stats == flux_old_stats: True
parameter_mean == parameter_old_mean: True
parameter_std == parameter_old_std: True
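
The checks behind this output boil down to torch.load plus torch.allclose over the three sets of saved statistics; a minimal sketch (the directory and file-name suffixes are assumptions about how the files were labeled, and the printed format differs slightly from the output above):

```python
import torch

STATIC_DIR = "data/meps_example/static"  # assumed output directory
NAMES = ["diff_std", "diff_mean", "flux_stats", "parameter_mean", "parameter_std"]

for name in NAMES:
    new = torch.load(f"{STATIC_DIR}/{name}.pt")
    distributed = torch.load(f"{STATIC_DIR}/{name}_distributed.pt")  # assumed suffix
    old = torch.load(f"{STATIC_DIR}/{name}_old.pt")  # assumed suffix
    print(f"{name}: new vs distributed:", torch.allclose(new, distributed, atol=1e-5))
    print(f"{name}: new vs old:", torch.allclose(new, old, atol=1e-5))
```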

Notes

My colleagues (generation TikTok) did actually complain that they cannot start training the model while waiting 50 h for the stats, so having a faster script is certainly nice. If you actually want to use the --distributed feature, you will need a scheduler. I am personally using Slurm as follows (I will add a note about this to the README.md if we merge):

#!/bin/bash -l
#SBATCH --job-name=NeurWP
#SBATCH --account=s83
#SBATCH --time=02:00:00
#SBATCH --nodes=1
#SBATCH --ntasks=16
#SBATCH --partition=postproc
#SBATCH --mem=444G
#SBATCH --no-requeue
#SBATCH --exclusive
#SBATCH --output=lightning_logs/neurwp_param_out.log
#SBATCH --error=lightning_logs/neurwp_param_err.log

# Activate the conda environment
conda activate neural-lam

srun -ul python create_parameter_weights.py --batch_size 16 --distributed
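
When launched with srun like this, each task can derive its rank and the world size from variables that Slurm exports, instead of spawning processes itself; a sketch of that glue (a plausible approach, not necessarily the PR's exact code):

```python
import os


def slurm_rank_and_world_size():
    """Read the distributed rank and world size from variables set by srun."""
    rank = int(os.environ.get("SLURM_PROCID", 0))
    world_size = int(os.environ.get("SLURM_NTASKS", 1))
    return rank, world_size
```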

Okay, I hope I will find pieces of this useful in other PRs; now please, someone stop me from wasting more time here :rofl: