huggingface / accelerate

🚀 A simple way to launch, train, and use PyTorch models on almost any device and distributed configuration, automatic mixed precision (including fp8), and easy-to-configure FSDP and DeepSpeed support
https://huggingface.co/docs/accelerate
Apache License 2.0

[`data_loader`] Optionally also propagate set_epoch to batch sampler #3246

Closed · tomaarsen closed this 1 day ago

tomaarsen commented 3 days ago

What does this PR do?

When training with the transformers `Trainer` or related libraries (Sentence Transformers, SpanMarker, SetFit, etc.), `set_epoch` is called on the dataloader. This call is propagated down to `dataloader.batch_sampler.sampler` if that sampler has a `set_epoch` method, but not to `dataloader.batch_sampler` itself.

This prevents epoch-specific generator seeding in custom batch samplers, such as the ones that are common in Sentence Transformers: https://sbert.net/docs/package_reference/sentence_transformer/sampler.html
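For reference, a custom batch sampler of this kind might look roughly like the sketch below (a hypothetical class, not one of the actual Sentence Transformers samplers): it reseeds its shuffling generator from the epoch, so if `set_epoch` never reaches it, every epoch yields identical batches.

```python
import torch
from torch.utils.data import BatchSampler, SequentialSampler


class EpochSeededBatchSampler(BatchSampler):
    """Shuffles batches with a generator seeded by the current epoch."""

    def __init__(self, sampler, batch_size, drop_last=False):
        super().__init__(sampler, batch_size, drop_last)
        self.epoch = 0

    def set_epoch(self, epoch: int):
        # Without this PR, set_epoch is only forwarded to
        # batch_sampler.sampler, so this method is never called.
        self.epoch = epoch

    def __iter__(self):
        # Reseed per epoch so batch composition differs between epochs.
        generator = torch.Generator().manual_seed(self.epoch)
        indices = torch.randperm(len(self.sampler), generator=generator).tolist()
        for start in range(0, len(indices), self.batch_size):
            batch = indices[start : start + self.batch_size]
            if len(batch) == self.batch_size or not self.drop_last:
                yield batch
```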

See also https://github.com/UKPLab/sentence-transformers/issues/3069 by @antigregory, which showed that `set_epoch` on my batch samplers was not being called as I expected.

Before submitting

Who can review?

@muellerzr (p.s. get better soon!)

HuggingFaceDocBuilderDev commented 3 days ago

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

tomaarsen commented 2 days ago

> Can you confirm it fixes your issue?

It does indeed. I think I can write a test to go alongside this PR; it doesn't need to be ST-specific.

tomaarsen commented 2 days ago

This test roughly mirrors my use case and should be an effective check that ST keeps working. Feel free to update it if e.g. you don't normally use `Accelerator` in the `data_loader` tests.
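For readers of the thread, a sketch of roughly that shape (illustrative names, not the exact test in the PR diff; it assumes the original batch sampler object is reused by `prepare` in the single-process case):

```python
import torch
from torch.utils.data import BatchSampler, DataLoader, SequentialSampler, TensorDataset

from accelerate import Accelerator


class EpochTrackingBatchSampler(BatchSampler):
    """Records the last epoch it was told about via set_epoch."""

    def __init__(self, sampler, batch_size, drop_last=False):
        super().__init__(sampler, batch_size, drop_last)
        self.epoch = 0

    def set_epoch(self, epoch: int):
        self.epoch = epoch


def test_set_epoch_reaches_batch_sampler():
    dataset = TensorDataset(torch.arange(16).float())
    batch_sampler = EpochTrackingBatchSampler(SequentialSampler(dataset), batch_size=4)
    dataloader = DataLoader(dataset, batch_sampler=batch_sampler)

    accelerator = Accelerator()
    prepared = accelerator.prepare(dataloader)

    # Before this PR, only batch_sampler.sampler.set_epoch would be reached.
    prepared.set_epoch(2)
    assert batch_sampler.epoch == 2
```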