lindermanlab / S5

MIT License

Multi-GPU training #12

Closed: lthilnklover closed this issue 6 months ago

lthilnklover commented 6 months ago

First of all, thank you for the well-organized repo! Apart from the JAX installation you mentioned, it is very straightforward to run the experiments.

However, since I am new to JAX, it is not clear to me how to run multi-GPU training. With the provided script, only one GPU seems to be doing any work; the other GPUs show minimal memory usage.

Is there anything additional I need to do to run multi-GPU training?

Thanks in advance!

jimmysmith1919 commented 6 months ago

Hi, thanks for reaching out!

For multi-GPU training you will need to use JAX's parallelism features.
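As a quick illustration of why only one GPU appears busy, here is a minimal sketch (not code from the S5 repo): JAX can see every visible GPU, but without pmap or explicit sharding, arrays and jitted computations are placed on the default device, which is the first one unless configured otherwise.

```python
import jax
import jax.numpy as jnp

devices = jax.devices()
print(f"backend={jax.default_backend()!r}, visible devices={len(devices)}")

# No pmap and no sharding: the input is placed on the default device only,
# and the jitted computation runs there too, so any other GPUs stay idle.
x = jnp.ones((1024, 1024))
y = jax.jit(lambda a: a @ a)(x)
print("x on:", x.devices(), "| y on:", y.devices())
```

If `len(devices)` already reports all of your GPUs, the installation is fine and the remaining step is to distribute the work explicitly.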

If you simply want data parallelism, you can use jax.pmap as explained in this tutorial: https://jax.readthedocs.io/en/latest/jax-101/06-parallelism.html. For the experiments in our paper we only used a single GPU, so that is all the main branch is set up for.
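A minimal sketch of the data-parallel pattern with `jax.pmap`, assuming a hypothetical toy linear-regression model in place of the real S5 model: parameters are replicated along a leading device axis, each device computes gradients on its own slice of the batch, and `jax.lax.pmean` averages the gradients so every replica applies the same update.

```python
from functools import partial

import jax
import jax.numpy as jnp

# Hypothetical toy model (linear regression) standing in for the real model.
def loss_fn(params, x, y):
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)

@partial(jax.pmap, axis_name="batch")
def train_step(params, x, y):
    # Each device computes the loss and gradients on its own slice of data.
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    # All-reduce over the "batch" device axis: every replica gets the mean
    # gradient, so the parameter copies stay identical across devices.
    grads = jax.lax.pmean(grads, axis_name="batch")
    loss = jax.lax.pmean(loss, axis_name="batch")
    new_params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)
    return new_params, loss

n_dev = jax.local_device_count()

# Replicate parameters: add a leading axis of size n_dev (one copy per device).
params = {"w": jnp.zeros((3,)), "b": jnp.zeros(())}
params = jax.tree_util.tree_map(
    lambda p: jnp.broadcast_to(p, (n_dev,) + p.shape), params
)

# Synthetic data already split as (n_dev, per_device_batch, features).
x = jax.random.normal(jax.random.PRNGKey(0), (n_dev, 8, 3))
y = x @ jnp.array([1.0, -2.0, 0.5]) + 0.3

losses = []
for step in range(100):
    params, loss = train_step(params, x, y)
    losses.append(loss)

print("first loss:", losses[0][0], "last loss:", losses[-1][0])
```

The returned loss has one entry per device; after `pmean` all entries are equal, which is a cheap sanity check that the replicas are in sync.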

However, for the language modeling setup on the development branch of our repo (https://github.com/lindermanlab/S5/tree/development), we did use pmap to distribute the data across devices, so you can look at that branch for an example implementation. For instance, this line uses jax.pmap: https://github.com/lindermanlab/S5/blob/008bd547890a17d6fce059f5de104c0d578b101b/train.py#L130
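Distributing the data for `pmap` mostly comes down to giving every array a leading device axis. The helpers below (`shard_batch`, `unreplicate`) are hypothetical names for a minimal sketch of that step; they are not taken from the development branch, whose actual code may differ.

```python
import jax
import jax.numpy as jnp
import numpy as np

def shard_batch(batch, n_devices):
    """Reshape every array in `batch` from (global_batch, ...) to
    (n_devices, global_batch // n_devices, ...) so pmap maps over axis 0."""
    def _split(x):
        if x.shape[0] % n_devices != 0:
            raise ValueError(
                f"batch size {x.shape[0]} is not divisible by {n_devices} devices"
            )
        return x.reshape((n_devices, x.shape[0] // n_devices) + x.shape[1:])
    return jax.tree_util.tree_map(_split, batch)

def unreplicate(tree):
    """Keep only the first device's copy, e.g. for logging or checkpointing."""
    return jax.tree_util.tree_map(lambda x: x[0], tree)

n_dev = jax.local_device_count()
global_batch = 4 * n_dev
# Hypothetical batch: sequences of length 16 with 2 input features, plus labels.
batch = {
    "inputs": np.random.randn(global_batch, 16, 2).astype(np.float32),
    "labels": np.ones((global_batch,), dtype=np.float32),
}
sharded = shard_batch(batch, n_dev)

# Each device sums its own labels; psum adds the partial sums across devices.
count_examples = jax.pmap(
    lambda b: jax.lax.psum(jnp.sum(b["labels"]), "batch"), axis_name="batch"
)
totals = count_examples(sharded)
print(sharded["inputs"].shape, unreplicate(totals))
```

Choosing a global batch size that is a multiple of the device count keeps the reshape exact; otherwise the last partial batch has to be padded or dropped.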

If you instead want to shard both the data and the model across devices, JAX offers several options: distributed arrays and automatic parallelization (https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html), shard_map (https://jax.readthedocs.io/en/latest/notebooks/shard_map.html), and xmap (https://jax.readthedocs.io/en/latest/notebooks/xmap_tutorial.html).
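For the first of these options, a minimal sketch under simple assumptions: a one-dimensional device mesh with a hypothetical axis name `"data"`, a batch sharded along its first axis, and a small weight matrix replicated on every device. `jax.jit` then compiles the function for the sharded inputs and inserts any needed communication automatically. Sharding the model weights themselves would use a second mesh axis and a different `PartitionSpec`, which this sketch does not cover.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

n = jax.device_count()
# One-dimensional mesh over all devices; "data" is an arbitrary axis name.
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))

batch_sharding = NamedSharding(mesh, P("data"))  # split rows across devices
replicated = NamedSharding(mesh, P())            # full copy on every device

x_host = np.arange(8 * n * 4, dtype=np.float32).reshape(8 * n, 4) / (8 * n * 4)
x = jax.device_put(x_host, batch_sharding)
w = jax.device_put(jnp.ones((4, 2)), replicated)

@jax.jit
def forward(x, w):
    # Written as single-device code; the compiler partitions it using the
    # shardings attached to the inputs.
    return jnp.tanh(x @ w)

out = forward(x, w)
print("per-device shard of x:", x.addressable_shards[0].data.shape)
print("output sharding:", out.sharding)
```

Compared with `pmap`, the function body stays ordinary single-device code; where the data lives is decided entirely by the shardings passed to `jax.device_put`.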

lthilnklover commented 6 months ago

Thank you for the detailed reply!