Closed shimizust closed 1 month ago
@shimizust thanks for putting up the request. I quickly read through accelerate's prepare_data_loader function, and it seems to support both iterable-style and map-style datasets, so StreamingDataset should work with accelerate.
If you use accelerate launch or torchrun, they will set the distributed env vars for you. Have you tried that? What error do you see?
Feel free to post the findings/error message here, happy to help with troubleshooting.
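For anyone checking whether the launcher actually set those env vars, here is a minimal sketch. It assumes the standard variables that torchrun exports (RANK, WORLD_SIZE, LOCAL_RANK). The helper name `read_dist_env` and the single-process fallback values are made up for illustration and are not part of either library.

```python
import os


def read_dist_env(env=None):
    """Return rank info from launcher-set env vars (hypothetical helper).

    Falls back to single-process values when a variable is missing;
    that fallback is an assumption of this sketch.
    """
    env = os.environ if env is None else env
    return {
        "rank": int(env.get("RANK", 0)),
        "world_size": int(env.get("WORLD_SIZE", 1)),
        "local_rank": int(env.get("LOCAL_RANK", 0)),
    }


if __name__ == "__main__":
    # Run under torchrun or accelerate launch and compare the output per process.
    print(read_dist_env())
```

If every process prints rank 0 and world_size 1 under a multi-GPU launch, the variables are probably not reaching the training script.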
In my code, if I use accelerate to wrap the dataloader again, it causes a deadlock. I think this is because the streaming dataset is already split across the individual GPUs, and if I use accelerate to wrap it again, it creates an additional dataloader for each GPU on top of that, resulting in (number of GPUs)^2 dataloaders.
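To make the "split twice" reasoning concrete, here is a toy sketch in plain Python. It does not reproduce the deadlock and does not use either library. The round-robin `shard` function is a stand-in chosen for illustration and is not how StreamingDataset or accelerate actually partition data.

```python
def shard(samples, rank, world_size):
    """Toy round-robin split: each rank keeps every world_size-th sample."""
    return samples[rank::world_size]


world_size = 4
samples = list(range(16))

# Split once: every rank gets a disjoint quarter of the data.
once = [shard(samples, r, world_size) for r in range(world_size)]

# Split again on top of the first split: every rank keeps only a
# quarter of its own quarter, so most samples are never seen.
twice = [shard(once[r], r, world_size) for r in range(world_size)]

if __name__ == "__main__":
    print("once :", once)
    print("twice:", twice)
```

In this toy model, 16 samples are covered after one split but only 4 after two, which is why a dataset that already shards itself per rank should not be sharded again by a wrapper.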
@wangyanhui666 so is training successful if you don't wrap the dataloader, as mentioned in some previous issues?
Yes, training is successful. I trained on 1 node with 4 GPUs; I have not tested multi-node.
Okay, perfect. Going to close out this issue then, thanks!
@XiaohanZhangCMU adding a note for us to update FAQs in docs for hf accelerate since this has come up multiple times
🚀 Feature Request
Provide a guide on using StreamingDataset with HuggingFace accelerate and transformers.Trainer, if this is supported.
Motivation
First, thanks for the great work on this! I attended your session at the PyTorch Conference. I wanted to try this out, but I'm having trouble figuring out whether it is compatible with the HuggingFace ecosystem (e.g., accelerate for distributed training and the transformers Trainer), which is used pretty widely.
My understanding of HF-based training jobs is that a torch Dataset or IterableDataset is passed in to the Trainer. If accelerate is available, the Trainer uses it to prepare the dataloader for distributed training. In the IterableDataset case, data loading occurs on process 0 only: it fetches the batches for all processes and broadcasts them to each process.
I'm not sure if this is compatible with how StreamingDataset works.
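To illustrate the IterableDataset behavior described above, here is a toy sketch of the "one process reads, every process receives a batch" pattern. It is plain Python and not accelerate's implementation. The function name `dispatch_batches` and the choice to drop an incomplete final step are assumptions made for this example.

```python
from itertools import islice


def dispatch_batches(dataset, batch_size, world_size):
    """Yield, per step, one batch for each process, all read by a single reader.

    Toy model only: an incomplete final step is dropped for simplicity.
    """
    it = iter(dataset)
    while True:
        # The single reader pulls world_size consecutive batches per step.
        step = [list(islice(it, batch_size)) for _ in range(world_size)]
        if not all(len(batch) == batch_size for batch in step):
            return
        yield step


if __name__ == "__main__":
    for step in dispatch_batches(range(10), batch_size=2, world_size=2):
        # step[r] is the batch that would be sent to process r.
        print(step)
```

The point of the sketch is that only one reader ever iterates the dataset here, which is a different model from a dataset that expects every rank to iterate its own partition.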