YanshekWoo opened 5 months ago
When I test my custom sampler, it seems that the Hugging Face Trainer and Accelerator replace it with a new object. Please refer to the code: the `get_train_dataloader` function in Trainer and the `prepare_data_loader` function in Accelerate.
When I print the sampler class of the dataloader before and after `self.accelerator.prepare()`, I get the following output:

```
<finetune.data.InTaskRandomSampler object at 0x7ff1a4b7c310> <torch.utils.data.sampler.SequentialSampler object at 0x7ff1a4131c00>
```
The same issue is reported at https://discuss.huggingface.co/t/accelerator-prepare-replaces-custom-dataloader-sampler/43392.
A possible solution is to write a distributed-aware version of the custom sampler, along the lines of `torch.utils.data.distributed.DistributedSampler`, and to avoid calling `self.accelerator.prepare` on the dataloader in Trainer. Of course, this also requires overriding the `get_train_dataloader` function in Trainer.
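To make the proposed workaround concrete, here is a minimal sketch of a distributed-aware custom sampler, written in plain Python so the sharding logic is visible. The issue's `InTaskRandomSampler` is not available, so `base_order` is a hypothetical stand-in (a plain seeded shuffle) for whatever ordering the real custom sampler produces; the padding and rank-slicing mirror the behaviour of `torch.utils.data.distributed.DistributedSampler`.

```python
import math
import random
from typing import Iterator, List


class DistributedCustomSampler:
    """Shards a custom index ordering across processes, mimicking
    torch.utils.data.distributed.DistributedSampler."""

    def __init__(self, dataset_size: int, num_replicas: int, rank: int,
                 seed: int = 0) -> None:
        if not 0 <= rank < num_replicas:
            raise ValueError("rank must be in [0, num_replicas)")
        self.dataset_size = dataset_size
        self.num_replicas = num_replicas
        self.rank = rank
        self.seed = seed
        self.epoch = 0
        # Pad so every replica receives the same number of samples.
        self.num_samples = math.ceil(dataset_size / num_replicas)
        self.total_size = self.num_samples * num_replicas

    def base_order(self) -> List[int]:
        # Stand-in for the custom ordering (e.g. in-task random shuffling);
        # here it is simply a seeded shuffle over all indices.
        indices = list(range(self.dataset_size))
        random.Random(self.seed + self.epoch).shuffle(indices)
        return indices

    def set_epoch(self, epoch: int) -> None:
        # Call once per epoch so each epoch reshuffles differently.
        self.epoch = epoch

    def __iter__(self) -> Iterator[int]:
        indices = self.base_order()
        # Pad by repeating from the front, then take this rank's strided slice.
        indices += indices[: self.total_size - len(indices)]
        return iter(indices[self.rank : self.total_size : self.num_replicas])

    def __len__(self) -> int:
        return self.num_samples
```

With a sampler like this in place, the overridden `get_train_dataloader` can build the DataLoader directly with it and return it without passing it through `self.accelerator.prepare`, so the custom ordering survives.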
Sure, feel free to open a PR.