huggingface / accelerate

🚀 A simple way to launch, train, and use PyTorch models on almost any device and distributed configuration, automatic mixed precision (including fp8), and easy-to-configure FSDP and DeepSpeed support
https://huggingface.co/docs/accelerate
Apache License 2.0

Is it possible to Reprepare a Dataloader after each Epoch? #3015

Open Reapor-Yurnero opened 1 month ago

Reapor-Yurnero commented 1 month ago

Hello Folks,

I'm new here. I'm currently trying to set up a non-standard training process where, after each epoch, the dataset is updated rather than the model. So basically we need to:

  1. re-prepare the dataset with accelerate after each epoch
  2. turn off the model's gradients, since we don't need them in this setting

I was not able to find either of these in the documentation.

For the first item I've tried the trivial way: calling accelerator.prepare again on the dataloader after each epoch (roughly as sketched below), but it doesn't work. Would love to hear any input. Really appreciate it!
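
For context, a minimal sketch of that trivial approach (accelerator, model, and num_epochs are assumed to already exist; build_dataset is a hypothetical helper that returns the updated dataset for the next epoch):

    import torch

    for epoch in range(num_epochs):
        train_dataset = build_dataset(epoch)  # hypothetical helper: returns the updated dataset
        train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=8)
        train_dataloader = accelerator.prepare(train_dataloader)  # called again every epoch
        ...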

minienglish1 commented 1 week ago

I've done something similar before with no problems.

After each epoch I basically update the dataset, bucket_batch_sampler, and dataloader, then call accelerator.prepare:

    train_dataset = CachedImageDataset(...)
    train_bucket_batch_sampler = BucketBatchSampler(...)
    train_dataloader = torch.utils.data.DataLoader(...)
    train_dataloader = accelerator.prepare(train_dataloader)

Of course, add the new dataset, proper arguments, and such. It worked fine, no problems at all.
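
For anyone who wants a runnable version of this recipe, here is a self-contained sketch: torch.utils.data.TensorDataset stands in for CachedImageDataset, a plain BatchSampler stands in for BucketBatchSampler, and only the dataloader is rebuilt and re-prepared each epoch.

    import torch
    from torch.utils.data import BatchSampler, DataLoader, SequentialSampler, TensorDataset
    from accelerate import Accelerator

    accelerator = Accelerator()
    model = accelerator.prepare(torch.nn.Linear(4, 2))  # toy model, prepared once

    for epoch in range(3):
        # Rebuild the dataset with whatever was produced during the previous epoch
        # (random placeholder data here).
        data = torch.randn(32, 4)
        train_dataset = TensorDataset(data)
        batch_sampler = BatchSampler(SequentialSampler(train_dataset), batch_size=8, drop_last=False)
        train_dataloader = DataLoader(train_dataset, batch_sampler=batch_sampler)
        train_dataloader = accelerator.prepare(train_dataloader)  # re-prepare the new dataloader

        model.eval()
        with torch.no_grad():  # gradients are not needed in this setting
            for (batch,) in train_dataloader:
                outputs = model(batch)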

by "turn off model gradient", I assume you mean using with torch.no_grad() or just don't enable model.train()?