huggingface / nanotron

Minimalistic large language model 3D-parallelism training
Apache License 2.0

[Feature] Refactor `ParallelContext.world_rank_matrix` #77

Open NouamaneTazi opened 9 months ago

NouamaneTazi commented 9 months ago

For now, we store the global ranks in the `world_rank_matrix` attribute, a NumPy array of shape `(expert_parallel_size, pipeline_parallel_size, data_parallel_size, tensor_parallel_size)`.
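To make the layout concrete, here is a minimal sketch. It assumes the matrix holds consecutive global ranks in row-major order (built here with `np.arange(...).reshape(...)`), and the parallelism sizes are hypothetical. Under that assumption, a process's global rank is just the entry at its `(ep, pp, dp, tp)` coordinates, but the caller has to remember the positional axis order.

```python
import numpy as np

# Hypothetical parallelism sizes; world_size is their product.
expert_parallel_size = 1
pipeline_parallel_size = 2
data_parallel_size = 2
tensor_parallel_size = 2
world_size = (
    expert_parallel_size * pipeline_parallel_size * data_parallel_size * tensor_parallel_size
)

# Assumption: global ranks 0..world_size-1 laid out row-major over
# (expert, pipeline, data, tensor), so the tensor-parallel axis varies fastest.
world_rank_matrix = np.arange(world_size).reshape(
    (expert_parallel_size, pipeline_parallel_size, data_parallel_size, tensor_parallel_size)
)

# Positional indexing: global rank of the process at ep=0, pp=1, dp=0, tp=1.
global_rank = int(world_rank_matrix[0, 1, 0, 1])
print(global_rank)  # 5
```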

So, to look up a process's global rank through `world_rank_matrix`, we currently do the following: https://github.com/huggingface/nanotron/blob/7c01d0f03dff537bbec79a380d861cb9934ba583/tests/test_parameters_accumulate_gradient_in_fp32.py#L346-L351

It would be cleaner to make this a method call instead, for example:

```python
parallel_context.get_global_rank(expert_parallel_rank=0, pipeline_parallel_rank=0, data_parallel_rank=0, tensor_parallel_rank=0)
```
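One possible shape for such a method is sketched below. Several assumptions apply. `RankLayout` is a hypothetical stand-in for `ParallelContext`, not nanotron's actual class. The matrix is assumed to hold consecutive ranks in row-major order. The bounds check is an addition that the issue does not propose. Making the arguments keyword-only means a caller cannot silently swap, for example, the data- and tensor-parallel ranks.

```python
from dataclasses import dataclass, field

import numpy as np


@dataclass
class RankLayout:
    """Hypothetical stand-in for ParallelContext; not nanotron's actual class."""

    expert_parallel_size: int
    pipeline_parallel_size: int
    data_parallel_size: int
    tensor_parallel_size: int
    world_rank_matrix: np.ndarray = field(init=False, repr=False, compare=False)

    def __post_init__(self) -> None:
        shape = (
            self.expert_parallel_size,
            self.pipeline_parallel_size,
            self.data_parallel_size,
            self.tensor_parallel_size,
        )
        # Assumption: consecutive global ranks laid out in row-major order.
        self.world_rank_matrix = np.arange(int(np.prod(shape))).reshape(shape)

    def get_global_rank(
        self,
        *,
        expert_parallel_rank: int,
        pipeline_parallel_rank: int,
        data_parallel_rank: int,
        tensor_parallel_rank: int,
    ) -> int:
        """Return the global rank for the given per-dimension ranks."""
        ranks = (
            expert_parallel_rank,
            pipeline_parallel_rank,
            data_parallel_rank,
            tensor_parallel_rank,
        )
        for rank, size in zip(ranks, self.world_rank_matrix.shape):
            # Reject out-of-range ranks, including negatives NumPy would wrap around.
            if not 0 <= rank < size:
                raise ValueError(f"rank {rank} out of range for dimension of size {size}")
        return int(self.world_rank_matrix[ranks])


layout = RankLayout(1, 2, 2, 2)
rank = layout.get_global_rank(
    expert_parallel_rank=0, pipeline_parallel_rank=1, data_parallel_rank=0, tensor_parallel_rank=1
)
print(rank)  # 5
```

The explicit range check matters because plain NumPy indexing would accept a negative rank such as `-1` and quietly return the last entry along that axis, which hides caller bugs.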