huggingface / optimum-neuron

Easy, fast and very cheap training and inference on AWS Trainium and Inferentia chips.
Apache License 2.0

Enable use of IterableDataset when training with DDP #681

Open syl-taylor-aws opened 3 months ago

syl-taylor-aws commented 3 months ago

Feature request

Enable use of IterableDataset when training with NeuronTrainer and DDP. Or is there a design limitation that prevents this?

I can't share the project code, but below is a simpler example that reproduces the same issue. `DistributedSampler` expects a dataset with a known length, which an `IterableDataset` does not have by design.
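To illustrate why the length matters, here is a minimal pure-Python sketch of how `DistributedSampler` assigns indices when `shuffle=False` and `drop_last=False`, modeled on the PyTorch source. `distributed_indices` and `Stream` are hypothetical names used only for illustration, not PyTorch APIs. Every step starts from `len(dataset)`, so there is nothing to compute for a stream of unknown length.

```python
import math
from collections.abc import Sized


def distributed_indices(dataset: Sized, num_replicas: int, rank: int) -> list[int]:
    """Hypothetical helper: indices for `rank`, mirroring DistributedSampler(shuffle=False, drop_last=False)."""
    n = len(dataset)  # raises TypeError for objects without __len__, e.g. a plain IterableDataset
    num_samples = math.ceil(n / num_replicas)
    total_size = num_samples * num_replicas

    indices = list(range(n))
    # Pad by repeating indices from the start so every rank gets the same number of samples.
    padding_size = total_size - n
    if padding_size <= n:
        indices += indices[:padding_size]
    else:
        indices += (indices * math.ceil(padding_size / n))[:padding_size]

    # Rank r takes every num_replicas-th index, starting at r.
    return indices[rank:total_size:num_replicas]


class Stream:
    """Hypothetical stand-in for an IterableDataset: iterable, but with no __len__."""

    def __iter__(self):
        while True:
            yield 0


print(distributed_indices(range(5), num_replicas=2, rank=0))  # [0, 2, 4]
print(distributed_indices(range(5), num_replicas=2, rank=1))  # [1, 3, 0]

try:
    distributed_indices(Stream(), num_replicas=2, rank=0)
except TypeError as err:
    print(err)  # object of type 'Stream' has no len()
```

The final call fails with the same kind of `TypeError` as the traceback in this issue.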

Setup

OS: Ubuntu 22.04.4 LTS (kernel 6.5.0-1023-aws)

apt packages:
- aws-neuronx-collectives/unknown,now 2.21.46.0-69b77134b amd64 [installed]
- aws-neuronx-dkms/unknown,now 2.17.17.0 amd64 [installed]
- aws-neuronx-runtime-lib/unknown,now 2.21.41.0-fb1705f5f amd64 [installed]
- aws-neuronx-tools/unknown,now 2.18.3.0 amd64 [installed]
pip packages:
- aws-neuronx-runtime-discovery==2.9
- neuronx-cc==2.14.227.0+2d4f85be
- libneuronxla==2.0.2335
- torch==2.1.2
- torch-neuronx==2.1.2.2.1.0
- torch-xla==2.1.2
- transformers==4.41.1
- accelerate==0.29.2
- optimum-neuron==0.0.24 (also tested 0.0.25.dev0)

Command: `torchrun --nproc_per_node=2 issue.py`

Code (issue.py)

```python
import torch
from transformers import RobertaForCausalLM
from optimum.neuron import NeuronTrainer as Trainer
from optimum.neuron import NeuronTrainingArguments as TrainingArguments


class CustomIterator:
    def __next__(self):
        return {
            "input_ids": torch.randint(0, 50265, (512,)),
            "labels": torch.randint(0, 50265, (512,))
        }


class CustomDataset(torch.utils.data.IterableDataset):
    def __iter__(self):
        return CustomIterator()


dataset = CustomDataset()
model = RobertaForCausalLM.from_pretrained("roberta-base")
training_args = TrainingArguments(output_dir="./model", max_steps=100)
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset
)
trainer.train()
```
Issue

```
Traceback (most recent call last):
  File "/home/ubuntu/issue.py", line 29, in <module>
    result = trainer.train()
  File "/home/ubuntu/aws_neuron_venv_pytorch/lib/python3.10/site-packages/optimum/neuron/trainers.py", line 1414, in train
    result = super().train(
  File "/home/ubuntu/aws_neuron_venv_pytorch/lib/python3.10/site-packages/transformers/trainer.py", line 1885, in train
    return inner_training_loop(
  File "/home/ubuntu/aws_neuron_venv_pytorch/lib/python3.10/site-packages/optimum/neuron/utils/require_utils.py", line 51, in wrapper
    return func(*args, **kwargs)
  File "/home/ubuntu/aws_neuron_venv_pytorch/lib/python3.10/site-packages/optimum/neuron/trainers.py", line 686, in _inner_training_loop
    train_dataloader = self.get_train_dataloader()
  File "/home/ubuntu/aws_neuron_venv_pytorch/lib/python3.10/site-packages/transformers/trainer.py", line 897, in get_train_dataloader
    return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params))
  File "/home/ubuntu/aws_neuron_venv_pytorch/lib/python3.10/site-packages/accelerate/accelerator.py", line 1274, in prepare
    result = tuple(
  File "/home/ubuntu/aws_neuron_venv_pytorch/lib/python3.10/site-packages/accelerate/accelerator.py", line 1275, in <genexpr>
    self._prepare_one(obj, first_pass=True, device_placement=d) for obj, d in zip(args, device_placement)
  File "/home/ubuntu/aws_neuron_venv_pytorch/lib/python3.10/site-packages/accelerate/accelerator.py", line 1149, in _prepare_one
    return self.prepare_data_loader(obj, device_placement=device_placement)
  File "/home/ubuntu/aws_neuron_venv_pytorch/lib/python3.10/site-packages/optimum/neuron/accelerate/accelerator.py", line 223, in prepare_data_loader
    data_loader = self._prepare_data_loader_for_distributed(
  File "/home/ubuntu/aws_neuron_venv_pytorch/lib/python3.10/site-packages/optimum/neuron/accelerate/accelerator.py", line 191, in _prepare_data_loader_for_distributed
    sampler = DistributedSampler(data_loader.dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle)
  File "/home/ubuntu/aws_neuron_venv_pytorch/lib/python3.10/site-packages/torch/utils/data/distributed.py", line 91, in __init__
    self.num_samples = math.ceil(len(self.dataset) / self.num_replicas)  # type: ignore[arg-type]
TypeError: object of type 'CustomDataset' has no len()
```

Motivation

I have a project for distributed training on Trainium with DDP that requires Hugging Face's `IterableDataset` (what `load_dataset()` in `load.py` returns when `streaming=True`, from `datasets==2.19.0`).

Your contribution

N/A. I noticed that on NVIDIA A100 GPUs, the transformers `Trainer` uses `accelerate.data_loader.DataLoaderDispatcher` and does not use `DistributedSampler`.
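For context, accelerate describes `DataLoaderDispatcher` as iterating on process 0 only and then dispatching to each process its part of the batch, which is why it never needs `len()`. The sketch below is a single-process, pure-Python simulation of that idea only. `dispatched_batches` is a hypothetical helper; the real class moves tensors across processes, which is not modeled here, and dropping an incomplete final batch is an assumption of this sketch.

```python
from itertools import count, islice
from typing import Iterable, Iterator, List


def dispatched_batches(stream: Iterable[int], batch_size: int, world_size: int, rank: int) -> Iterator[List[int]]:
    """Hypothetical simulation: read batch_size * world_size items, yield this rank's slice."""
    it = iter(stream)
    global_size = batch_size * world_size
    while True:
        # Only the "main process" reads from the stream; no len() call is needed.
        global_batch = list(islice(it, global_size))
        if len(global_batch) < global_size:
            return  # sketch assumption: drop an incomplete final batch
        start = rank * batch_size
        yield global_batch[start:start + batch_size]


print(list(dispatched_batches(range(10), batch_size=2, world_size=2, rank=0)))  # [[0, 1], [4, 5]]
# Also works on an endless stream, since the length is never queried.
print(list(islice(dispatched_batches(count(), batch_size=2, world_size=2, rank=1), 2)))  # [[2, 3], [6, 7]]
```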

HuggingFaceDocBuilderDev commented 2 months ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Thank you!

unography commented 2 weeks ago

Same issue: DDP breaks when using an `IterableDataset`.

`DistributedSampler` doesn't seem to work with an `IterableDataset`. A possible fix would be to use `split_dataset_by_node` at this line:

https://github.com/huggingface/optimum-neuron/blob/v0.0.25/optimum/neuron/accelerate/accelerator.py#L191

Instead of using `DistributedSampler`, do the following:

```python
from datasets.distributed import split_dataset_by_node

dataset_on_curr_node = split_dataset_by_node(data_loader.dataset, rank=rank, world_size=num_replicas)
```

Then pass this dataset to the `DataLoader` without any sampler.
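For reference, the `datasets` documentation says that for an iterable dataset whose number of shards is not a multiple of `world_size`, `split_dataset_by_node` falls back to each node keeping one example out of every `world_size`. The pure-Python sketch below illustrates only that fallback using `itertools.islice`. `take_every_nth_for_rank` is a hypothetical helper; it does not model the preferred shard-based assignment used when `n_shards % world_size == 0`, and note that `split_dataset_by_node` is documented for Hugging Face `datasets` objects rather than arbitrary `torch.utils.data.IterableDataset` subclasses. Because each rank just skips elements, no `len()` is required, which is why no sampler is needed afterwards.

```python
from itertools import count, islice
from typing import Iterable, Iterator, TypeVar

T = TypeVar("T")


def take_every_nth_for_rank(stream: Iterable[T], rank: int, world_size: int) -> Iterator[T]:
    """Hypothetical helper: keep elements rank, rank + world_size, rank + 2 * world_size, ..."""
    if not 0 <= rank < world_size:
        raise ValueError("rank must be in [0, world_size)")
    return islice(stream, rank, None, world_size)


# Works on an endless stream because nothing calls len().
print(list(islice(take_every_nth_for_rank(count(), rank=0, world_size=2), 4)))  # [0, 2, 4, 6]
print(list(islice(take_every_nth_for_rank(count(), rank=1, world_size=2), 4)))  # [1, 3, 5, 7]
```

The per-rank streams are disjoint and together cover the whole input, at the cost of every rank still reading (and discarding) the examples that belong to other ranks.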