ContextualAI / gritlm

Generative Representational Instruction Tuning
https://arxiv.org/abs/2402.09906
MIT License

CustomRandomSampler not working in huggingface Trainer and Accelerator #31

Open YanshekWoo opened 5 months ago

YanshekWoo commented 5 months ago

Issue

When I test the custom sampler, it seems that the huggingface Trainer and Accelerator replace the sampler with a new object. Please refer to the code: the get_train_dataloader function in Trainer and the prepare_data_loader function in Accelerate.

When I print the sampler class of the dataloader before and after self.accelerator.prepare(), I get the following output:

<finetune.data.InTaskRandomSampler object at 0x7ff1a4b7c310>
<torch.utils.data.sampler.SequentialSampler object at 0x7ff1a4131c00>

The same issue is described in https://discuss.huggingface.co/t/accelerator-prepare-replaces-custom-dataloader-sampler/43392.
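The replacement can be reproduced in miniature without torch or accelerate. The classes below are stand-ins (not the real library code): a toy prepare() that, like Accelerate's prepare_data_loader appears to do here, rebuilds the dataloader with a default sampler and silently drops the custom one.

```python
import random

class SequentialSampler:
    """Stand-in for torch's default sampler: indices 0..n-1 in order."""
    def __init__(self, n):
        self.n = n
    def __iter__(self):
        return iter(range(self.n))

class InTaskRandomSampler:
    """Stand-in for the custom sampler from finetune.data."""
    def __init__(self, n, seed=0):
        self.order = list(range(n))
        random.Random(seed).shuffle(self.order)
    def __iter__(self):
        return iter(self.order)

class DataLoader:
    def __init__(self, n, sampler):
        self.n = n
        self.sampler = sampler

def prepare(loader):
    # Rebuilds the dataloader from scratch with a default sampler;
    # whatever sampler the caller installed is lost at this point.
    return DataLoader(loader.n, SequentialSampler(loader.n))

loader = DataLoader(8, InTaskRandomSampler(8))
print(type(loader.sampler).__name__)   # InTaskRandomSampler
loader = prepare(loader)
print(type(loader.sampler).__name__)   # SequentialSampler
```

This mirrors the printout above: the custom sampler class goes in, a SequentialSampler comes out.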

Solution

A possible solution is to write a custom subclass of torch.utils.data.distributed.DistributedSampler and avoid calling self.accelerator.prepare on the dataloader in the trainer. Of course, it is then also necessary to override the get_train_dataloader function in the trainer.
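The core of such a rewrite is the sharding logic: shuffle indices with a per-epoch seed, then give each rank a disjoint, equally sized slice, so the custom ordering survives without accelerator.prepare. A pure-Python sketch of that logic follows; the class name is illustrative (not gritlm's actual code), and the plain shuffle is a placeholder for the in-task ordering:

```python
import random

class InTaskDistributedSampler:
    """Sketch of a DistributedSampler-style rewrite: per-epoch
    shuffle, then a disjoint strided slice per rank."""

    def __init__(self, dataset_size, num_replicas, rank, seed=0):
        assert 0 <= rank < num_replicas
        self.dataset_size = dataset_size
        self.num_replicas = num_replicas
        self.rank = rank
        self.seed = seed
        self.epoch = 0
        # Pad so every rank draws the same number of samples,
        # as torch's DistributedSampler does by default.
        self.num_samples = -(-dataset_size // num_replicas)  # ceil div
        self.total_size = self.num_samples * num_replicas

    def set_epoch(self, epoch):
        # The trainer calls this each epoch so shuffling differs per epoch.
        self.epoch = epoch

    def __iter__(self):
        indices = list(range(self.dataset_size))
        # The custom (in-task) ordering would go here; a plain
        # seeded shuffle stands in for it.
        random.Random(self.seed + self.epoch).shuffle(indices)
        # Pad by repeating the first few indices, then take a
        # strided slice so ranks never overlap.
        indices += indices[: self.total_size - len(indices)]
        return iter(indices[self.rank : self.total_size : self.num_replicas])

    def __len__(self):
        return self.num_samples

# All ranks together cover every index at least once, in equal shares:
shards = [list(InTaskDistributedSampler(10, 4, r)) for r in range(4)]
print(sorted(set(i for s in shards for i in s)))  # [0, 1, ..., 9]
```

In the real fix, the overridden get_train_dataloader would construct the DataLoader with this sampler directly and return it without passing it through self.accelerator.prepare.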

Muennighoff commented 5 months ago

Sure, feel free to open a PR.