Open Reapor-Yurnero opened 1 month ago
I've done something similar before with no problems.
After each epoch I basically rebuild the dataset, bucket_batch_sampler, and dataloader, then call `accelerator.prepare` again:

```python
train_dataset = CachedImageDataset(...)
train_bucket_batch_sampler = BucketBatchSampler(...)
train_dataloader = torch.utils.data.DataLoader(...)
train_dataloader = accelerator.prepare(train_dataloader)
```
Of course, add the new dataset, proper arguments, and such. Worked fine, no problems at all.
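A minimal sketch of that per-epoch rebuild pattern, using a plain `TensorDataset` as a stand-in for `CachedImageDataset` (the `make_dataset` helper and its contents are placeholders, not part of the original code; with Accelerate you would also re-wrap the fresh loader with `accelerator.prepare` each epoch, as shown in the comment):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

def make_dataset(epoch: int) -> TensorDataset:
    # Placeholder for rebuilding CachedImageDataset with refreshed data.
    # Here the data simply encodes the epoch number so the update is visible.
    data = torch.full((8, 3), float(epoch))
    return TensorDataset(data)

for epoch in range(3):
    # Rebuild dataset -> sampler -> dataloader from scratch each epoch.
    train_dataset = make_dataset(epoch)
    train_dataloader = DataLoader(train_dataset, batch_size=4, shuffle=True)
    # With Accelerate, re-wrap the fresh loader here:
    # train_dataloader = accelerator.prepare(train_dataloader)
    for (batch,) in train_dataloader:
        # Every sample in this epoch carries the updated value.
        assert bool((batch == float(epoch)).all())
```

The key point is that the dataloader object itself is thrown away and recreated each epoch; nothing is mutated in place.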
By "turn off model gradient", do you mean using `with torch.no_grad()`, or just not calling `model.train()`?
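For reference, these options behave differently: `torch.no_grad()` disables autograd graph construction for everything inside the context, freezing parameters with `requires_grad_(False)` stops gradients for those weights only, and `model.eval()` merely switches layers like dropout/batchnorm to inference behavior without stopping gradients at all. A small sketch (the `Linear` model here is just an illustration):

```python
import torch

model = torch.nn.Linear(4, 2)

# Option 1: no_grad() -- no autograd graph is built inside the context,
# so nothing computed here can be backpropagated through.
with torch.no_grad():
    y = model(torch.randn(1, 4))
assert y.requires_grad is False

# Option 2: freeze the parameters -- the forward pass runs normally,
# but the frozen weights will never receive gradients.
for p in model.parameters():
    p.requires_grad_(False)
y = model(torch.randn(1, 4))
assert all(p.grad is None for p in model.parameters())

# Note: model.eval() alone does NOT turn off gradients; it only changes
# the behavior of layers such as dropout and batch norm.
```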
Hello Folks,
I'm new here. I'm currently trying to set up a non-standard training process where, after each epoch, the dataset is updated rather than the model. So basically we need to:

1. rebuild the dataloader with the updated dataset after each epoch, and
2. turn off the model gradient so the model itself isn't updated.
For the first item I've tried the trivial way: call `accelerator.prepare` again on the dataloader after each epoch, but it doesn't work. Would love to hear any input. Really appreciate it!