ag14774 closed this issue 5 years ago
Here is a quick implementation of iteration-based training, which replaces lines 51 to 86 of trainer.py.
batch_per_epoch = 10000
total_loss = 0
total_metrics = np.zeros(len(self.metrics))
for batch_idx, (data, target) in enumerate(itertools.cycle(self.data_loader)):
    data, target = data.to(self.device), target.to(self.device)

    self.optimizer.zero_grad()
    output = self.model(data)
    loss = self.loss(output, target)
    loss.backward()
    self.optimizer.step()

    # accumulate as in the epoch-based loop
    total_loss += loss.item()
    total_metrics += self._eval_metrics(output, target)

    if batch_idx % self.log_step == 0:
        self.logger.debug('Train Epoch: {} [{}/{} ({:.0f}%)] Loss: {:.6f}'.format(...))

    # validation
    if batch_idx % batch_per_epoch == 0 and self.do_validation:
        log = {
            'loss': total_loss / len(self.data_loader),
            'metrics': (total_metrics / len(self.data_loader)).tolist()
        }
        val_log = self._valid_epoch(epoch)
        log = {**log, **val_log}
        if self.lr_scheduler is not None:
            self.lr_scheduler.step()

return log
As you may have noticed, the structure of this training loop differs substantially from that of epoch-based training. I'm not currently working on iteration-based training, because I haven't found a clean and simple enough way to support both iteration-based and epoch-based training.
Moreover, since a pytorch dataset and data_loader are basically finite sequences, we need to make the data_loader loop infinitely. (I used itertools.cycle here for that, but this is not correct, since the order of sampled batches is now fixed across epochs.) As far as I know, the best way to do this is to implement a custom data loader as something like a generator object, which I think is too much to have in this simple MNIST example.
Here is how I implemented it. I added a step(self, epoch) function in BaseDataLoader. This step() can be used to store some internal state in the dataloader object, so that it does something different in each epoch. A simple example would be a dataset that implements __getitem__(self, key) as follows:
def __getitem__(self, key):
    if key >= len(self):
        raise IndexError(key)
    np.random.seed(key + self.seed_offset)
    return np.random.uniform(0, 100)
So let's say we have 5000 samples. dataset[i] will seed the generator with i + seed_offset and output a random number. I defined step in the dataloader to be:
def step(self, epoch):
    self.dataset.seed_offset = epoch * len(self.dataset)
Then, before the training loop in trainer.py, we call self.data_loader.step(epoch). In the config file you set epochs to something very high, and that's it. This is of course just an example, but step() can be used to do other things in the dataloader as well. So it is just an abstract method, and it is up to the person implementing the DataLoader to override it.
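Put together, the whole pattern could look like the sketch below. This is not the template's actual code: the class names mirror the comment above, the epoch loop is a hypothetical simplification of the trainer, and the stdlib random.Random is used in place of np.random.seed so the example has no dependencies:

```python
import random

class BaseDataLoader:
    """Minimal stand-in for the template's BaseDataLoader."""
    def __init__(self, dataset):
        self.dataset = dataset

    def step(self, epoch):
        """Per-epoch hook; a no-op by default, subclasses may override."""
        pass

    def __iter__(self):
        return (self.dataset[i] for i in range(len(self.dataset)))

class SeededDataset:
    """Deterministic 'random' dataset: sample i is a pure function of
    (i + seed_offset), as in the __getitem__ example above."""
    def __init__(self, size):
        self.size = size
        self.seed_offset = 0

    def __len__(self):
        return self.size

    def __getitem__(self, key):
        if key >= len(self):
            raise IndexError(key)
        rng = random.Random(key + self.seed_offset)  # stdlib stand-in for np.random.seed
        return rng.uniform(0, 100)

class SeededDataLoader(BaseDataLoader):
    def step(self, epoch):
        # Shift the seed range so each epoch draws different samples.
        self.dataset.seed_offset = epoch * len(self.dataset)

loader = SeededDataLoader(SeededDataset(5))
for epoch in range(2):   # the trainer would call step() before each epoch
    loader.step(epoch)
    batch = list(loader)
```

Because each sample is a pure function of its seed, epochs are different from one another yet fully reproducible: calling step(0) again yields exactly the same samples as the first epoch.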
I added another implementation of iteration-based training in my PR #53. Can you check whether this PR works for your case, @ag14774?
Any idea when we can expect iteration-based training to be ready? It would be a very useful feature.