aws-samples / awsome-distributed-training

Collection of best practices, reference architectures, model training examples and utilities to train large models on AWS.

Assigning Different Microbatches to Each Rank #425

Open purefall opened 2 months ago

purefall commented 2 months ago

Context:

We are following the FSDP example and trying to understand how different microbatches are assigned to each rank during training, and specifically what role the global_rank variable plays in this process.

In the code, it appears that global_rank is used as a seed for dataset shuffling, as shown below:

data = load_dataset(dataset, name=name, streaming=True, split=split, trust_remote_code=True).shuffle(42 + global_rank)

However, we encountered a few uncertainties regarding the initialization of global_rank and how it ensures non-overlapping data across ranks.

Questions:

  1. Initialization of global_rank:

    • Is global_rank meant to be passed as an argument, or is it inferred from the environment (e.g., the rank in distributed training)?
  2. Shuffling and Data Partitioning:

    • How does shuffling with global_rank ensure that different ranks receive different, non-overlapping samples? While the shuffle call offsets the random seed by global_rank, it is unclear how this alone guarantees distinct data across ranks without overlap (see the toy sketch after this list).
  3. Use of DistributedSampler: In the current example, the DataLoader does not use a DistributedSampler, which is typically used to partition a dataset across ranks. The DataLoader setup looks like this:

    train_dataloader = DataLoader(train_concat_dataset,
                                 batch_size=batch_size,
                                 num_workers=workers,
                                 pin_memory=True,
                                 prefetch_factor=4,
                                 timeout=600)
    • Is there any additional mechanism beyond shuffling (e.g., use of a DistributedSampler) that ensures non-overlapping data across ranks? Should we consider adding a DistributedSampler in this case?
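
To make question 2 concrete, here is a toy sketch (illustrative only, not code from the repository) of why shuffling with per-rank seeds reorders the data but does not partition it:

import random

# Toy illustration: shuffling the *same* underlying data with per-rank seeds
# only changes the order each rank sees; every rank still iterates over
# every sample.
data = list(range(10))

rank0_stream = data.copy()
random.Random(42 + 0).shuffle(rank0_stream)  # what "rank 0" would see
rank1_stream = data.copy()
random.Random(42 + 1).shuffle(rank1_stream)  # what "rank 1" would see

print(rank0_stream)                           # a permutation of 0..9
print(rank1_stream)                           # a different permutation of the same 0..9
print(set(rank0_stream) & set(rank1_stream))  # full overlap: all ten samples

Both "ranks" end up iterating over the same ten samples, just in different orders.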

Request:

Could you provide clarification on the questions above? Any guidance on how to avoid potential overlap in samples across different ranks would be greatly appreciated.

pbelevich commented 2 months ago

@purefall thanks for reporting the issue; we are working on improving this example. Currently, the data-loading code in def create_streaming_dataloader is a mock that is not designed for production use. To answer your questions: in general, an FSDP data-loading setup should look like this:

import os

from torch.utils.data import DataLoader, DistributedSampler

# Ranks and world size are exported by torchrun for every worker process.
local_rank = int(os.environ['LOCAL_RANK'])
rank = int(os.environ['RANK'])
world_size = int(os.environ['WORLD_SIZE'])

# DistributedSampler partitions the dataset so that each rank draws a
# disjoint subset of the samples.
sampler = DistributedSampler(your_dataset, rank=rank, num_replicas=world_size, shuffle=True)

train_dataloader = DataLoader(your_dataset,
                              sampler=sampler,
                              batch_size=batch_size,
                              num_workers=workers,
                              pin_memory=True,
                              prefetch_factor=4,
                              timeout=600)
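
As a side note (a minimal sketch, not part of the snippet above, assuming num_epochs is defined elsewhere): when shuffle=True is used with DistributedSampler, the sampler should be re-seeded at the start of every epoch via set_epoch, otherwise every epoch reuses the same ordering:

for epoch in range(num_epochs):
    sampler.set_epoch(epoch)        # re-seed so each epoch draws a new permutation
    for batch in train_dataloader:
        ...                         # forward / backward / optimizer step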

In the meantime, please refer to the PyTorch FSDP example while we work on improving ours. Thank you!

maxschmitt commented 2 months ago

Will

rank = int(os.environ['RANK'])

lead to the same result as

import torch.distributed as dist
rank = dist.get_rank()

?
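
For reference, a minimal sketch (an illustration, not an authoritative answer) comparing the two, assuming the job is launched with torchrun and the process group is initialized from the environment:

import os
import torch.distributed as dist

# torchrun exports RANK, LOCAL_RANK and WORLD_SIZE for every worker process.
env_rank = int(os.environ['RANK'])

# dist.get_rank() reads the rank from the default process group, so it is
# only valid after init_process_group() has been called.
dist.init_process_group(backend='nccl')
pg_rank = dist.get_rank()

# With torchrun's default env:// initialization, the two values should agree.
assert env_rank == pg_rank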