victoresque / pytorch-template

PyTorch deep learning projects made easy.
MIT License

Iteration-based training #50

Closed ag14774 closed 5 years ago

ag14774 commented 5 years ago

Any idea when we can expect iteration-based training to be ready? It would be a very useful feature.

SunQpark commented 5 years ago

Here is a quick implementation of iteration-based training, which replaces lines 51 to 86 of trainer.py.

        # assumes `import itertools` at the top of trainer.py
        batch_per_epoch = 10000
        for batch_idx, (data, target) in enumerate(itertools.cycle(self.data_loader)):
            data, target = data.to(self.device), target.to(self.device)

            self.optimizer.zero_grad()
            output = self.model(data)
            loss = self.loss(output, target)
            loss.backward()
            self.optimizer.step()

            if batch_idx % self.log_step == 0:
                self.logger.debug('Train Epoch: {} [{}/{} ({:.0f}%)] Loss: {:.6f}'.format(...))

            # treat every batch_per_epoch iterations as one "epoch";
            # (batch_idx + 1) avoids returning right after the very first batch
            if (batch_idx + 1) % batch_per_epoch == 0:
                # total_loss and total_metrics are accumulated as in the original loop (omitted here)
                log = {
                    'loss': total_loss / batch_per_epoch,
                    'metrics': (total_metrics / batch_per_epoch).tolist()
                }
                # validation
                if self.do_validation:
                    val_log = self._valid_epoch(epoch)
                    log = {**log, **val_log}

                if self.lr_scheduler is not None:
                    self.lr_scheduler.step()
                return log

As you may have noticed, the structure of this training loop differs considerably from that of epoch-based training. I'm currently not working on iteration-based training, because I could not find a clean and simple enough way to support both iteration-based and epoch-based training.

Moreover, since a PyTorch dataset and data_loader are basically finite sequences, we need to make the data_loader loop infinitely. (I used itertools.cycle here for that, but this is not correct: itertools.cycle replays the batches it saved during the first pass, so the order of sampled batches is now fixed.) As far as I know, the best way to do this is to implement a custom data loader as something like a generator object, which I think is too much to have in this simple MNIST example.
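As a rough illustration of the generator idea, here is a minimal sketch, not the template's actual implementation. The names `ShufflingLoader`, `inf_loop`, and `n_iterations` are hypothetical; `ShufflingLoader` is only a standard-library stand-in for a DataLoader that reshuffles each time it is iterated. The point is that restarting iteration on every pass, rather than using itertools.cycle, lets the loader draw a new batch order each time.

```python
import random
from itertools import islice


class ShufflingLoader:
    """Hypothetical stand-in for a DataLoader with shuffle=True."""

    def __init__(self, batches, seed=0):
        self.batches = list(batches)
        self.rng = random.Random(seed)

    def __len__(self):
        return len(self.batches)

    def __iter__(self):
        # A fresh random order is drawn every time iteration starts.
        order = self.batches[:]
        self.rng.shuffle(order)
        return iter(order)


def inf_loop(loader):
    """Yield batches forever, restarting the loader after each full pass."""
    while True:
        # `yield from` calls iter(loader) again on every pass, so the loader
        # can reshuffle (unlike itertools.cycle, which replays the batches
        # it saved during the first pass).
        yield from loader


loader = ShufflingLoader(range(5))
n_iterations = 12  # hypothetical total number of training iterations
seen = list(islice(inf_loop(loader), n_iterations))
```

In a trainer, the loop header would then look something like `for batch_idx, batch in enumerate(islice(inf_loop(loader), n_iterations))`, with validation triggered every fixed number of iterations.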

ag14774 commented 5 years ago

Here is how I implemented it. I added a step(self, epoch) method to BaseDataLoader. step() can store some internal state in the data loader object so that it does something different in each epoch. A simple example would be a dataset that implements __getitem__(self, key) as follows:

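def __getitem__(self, key):
    # IndexError is what sequence-style iteration expects at the end
    if key >= len(self):
        raise IndexError(key)
    # reseed so that the sample depends only on key and the current offset
    np.random.seed(key + self.seed_offset)
    return np.random.uniform(0, 100)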

So let's say we have 5000 samples. dataset[i] will seed the generator with i + seed_offset and output a random number. I defined step in the data loader to be:

def step(self, epoch):
    self.dataset.seed_offset = epoch * len(self.dataset)

Then, before the training loop in trainer.py, we call self.data_loader.step(epoch). In the config file you set epochs to something very high, and that's it. This is of course just an example; step() can be used to do something else in the data loader. So I made it an abstract method and left it up to the person implementing the DataLoader to override it.
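Putting the pieces above together, a self-contained sketch might look like the following. It is an illustration under assumptions, not the code from the template: `SeededRandomDataset` and `SteppingLoader` are hypothetical names, the loader is a plain Python class rather than a subclass of BaseDataLoader, and a local `random.Random` from the standard library is swapped in for `np.random.seed` so that the example has no dependencies.

```python
import random


class SeededRandomDataset:
    """Hypothetical dataset: sample i depends only on (i + seed_offset)."""

    def __init__(self, n_samples):
        self.n_samples = n_samples
        self.seed_offset = 0

    def __len__(self):
        return self.n_samples

    def __getitem__(self, key):
        if not 0 <= key < len(self):
            raise IndexError(key)
        # Local generator instead of reseeding the global one (an assumption
        # made here; the comment above uses np.random.seed).
        rng = random.Random(key + self.seed_offset)
        return rng.uniform(0, 100)


class SteppingLoader:
    """Hypothetical stand-in for a data loader that exposes step(epoch)."""

    def __init__(self, dataset):
        self.dataset = dataset

    def step(self, epoch):
        # Shift the seed range so each epoch sees a disjoint set of seeds.
        self.dataset.seed_offset = epoch * len(self.dataset)

    def __iter__(self):
        return (self.dataset[i] for i in range(len(self.dataset)))


loader = SteppingLoader(SeededRandomDataset(5))
epochs = []
for epoch in range(3):
    loader.step(epoch)  # called before each epoch's training loop
    epochs.append(list(loader))
```

Each epoch yields different samples, while calling step() again with the same epoch number reproduces the same ones, which keeps runs repeatable.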

SunQpark commented 5 years ago

I added another implementation of iteration-based training in my PR #53. Can you check whether this PR works for your case, @ag14774?