victoresque / pytorch-template

PyTorch deep learning projects made easy.
MIT License

Some problems #67

Open TruongKhang opened 4 years ago

TruongKhang commented 4 years ago

Hi, I used your template and ran into the following problems:

  1. In trainer.py, at line 53, met(output, target) should be changed to met(output, target).item(). Otherwise, memory usage grows after every batch. I tested this with PyTorch 0.4.1 and observed the problem.
  2. In my opinion, the computed evaluation metrics are not correct. Should we pass the batch size to the update function, like this?
    # at line 53
    self.train_metrics.update(met.__name__, met(output, target).item(), n=batch_size)

What do you think?
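A minimal sketch of problem 1, assuming a hypothetical metric tensor_metric that returns a tensor still attached to the autograd graph, and the same metric with .item() called inside it (scalar_metric). The random weight tensor is a stand-in for model parameters. Summing the tensor results across batches keeps every batch's graph (and its saved buffers) referenced, while summing Python floats does not.

```python
import torch


def tensor_metric(output, target):
    # Hypothetical metric that returns a tensor still attached to the graph
    return ((output - target) ** 2).mean()


def scalar_metric(output, target):
    # Same metric, but .item() converts the result to a plain Python float
    return ((output - target) ** 2).mean().item()


weight = torch.randn(4, 1, requires_grad=True)  # stand-in for model parameters
tensor_total, scalar_total = 0, 0.0
for _ in range(3):  # three "batches"
    data = torch.randn(8, 4)
    target = torch.randn(8, 1)
    output = data @ weight  # output requires grad because weight does
    # The tensor sum keeps a reference to every batch's computation graph
    tensor_total = tensor_total + tensor_metric(output, target)
    # The float sum holds no reference to any graph
    scalar_total = scalar_total + scalar_metric(output, target)

print(type(tensor_total).__name__, tensor_total.requires_grad)  # Tensor True
print(type(scalar_total).__name__)  # float
```

In a real training loop the tensor version grows by one batch's graph per iteration until the total is released, which matches the memory growth described in problem 1.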

SunQpark commented 4 years ago

  1. That part of the trainer assumes that metric functions return a Python scalar, not a tensor. If some of your metric functions return a tensor, you may run out of memory (OOM), because the computation graph keeps extending from each returned tensor. Hence, call .item() inside your metric function.
  2. Using n=1 yields a wrong average metric only when the batch size is not constant; if it is constant, n simply cancels out in the calculation. In the example MNIST code, the average metric will be "slightly" wrong, because the last batch of each epoch has a different size from the other batches. Your fix for this part is correct when batch_size comes from the current data, e.g. batch_size=output.shape[0], not from the global configuration, e.g. batch_size=data_loader.batch_size.

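To make point 2 concrete, here is a minimal pure-Python sketch. RunningAverage is a hypothetical stand-in for a tracker with an update(value, n=1) method (it is not the template's class), and plain lists stand in for tensors. With 10 samples and a batch size of 4, the last batch has only 2 samples, so n=1 over-weights it, while taking n from the current batch recovers the whole-dataset accuracy.

```python
class RunningAverage:
    """Hypothetical stand-in for a metric tracker with update(value, n=1)."""

    def __init__(self):
        self.total = 0.0
        self.count = 0

    def update(self, value, n=1):
        # value is the mean over a batch; n is how many samples it covers
        self.total += value * n
        self.count += n

    def avg(self):
        return self.total / self.count


def accuracy(pred, target):
    # Fraction of correct predictions, returned as a Python float
    return sum(p == t for p, t in zip(pred, target)) / len(target)


# 10 samples split into batches of 4, 4 and 2 (N % B != 0)
preds = [1, 1, 1, 1, 0, 0, 0, 0, 1, 1]
targets = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
batch_size = 4

unweighted, weighted = RunningAverage(), RunningAverage()
for start in range(0, len(preds), batch_size):
    p = preds[start:start + batch_size]
    t = targets[start:start + batch_size]
    acc = accuracy(p, t)
    unweighted.update(acc)          # n=1: every batch counts equally
    weighted.update(acc, n=len(t))  # n taken from the current batch

print(unweighted.avg())          # 0.6666666666666666
print(weighted.avg())            # 0.6
print(accuracy(preds, targets))  # 0.6
```

Using len(t) here plays the role of output.shape[0]: it reflects the real size of the last batch, whereas a fixed data_loader.batch_size would count the final batch as 4 samples.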
TruongKhang commented 4 years ago

Yes, I agree.

sunlightsgy commented 4 years ago

I think it should be:

outputs, targets = [], []
with torch.no_grad():
    for batch_idx, (data, target) in enumerate(self.valid_data_loader):
        data, target = data.to(self.device), target.to(self.device)

        output = self.model(data)
        loss = self.criterion(output, target)
        # Collect outputs and targets so metrics can be computed once over the whole validation set
        outputs.append(output.detach())
        targets.append(target.detach())
        self.writer.set_step((epoch - 1) * len(self.valid_data_loader) + batch_idx, 'valid')
        # The loss is still tracked per batch
        self.valid_metrics.update('loss', loss.item())

        self.writer.add_image('input', make_grid(data.cpu(), nrow=8, normalize=True))
    # Concatenate along the batch dimension, then compute each metric on the full set
    outputs = torch.cat(outputs, dim=0)
    targets = torch.cat(targets, dim=0)
    for met in self.metric_ftns:
        self.valid_metrics.update(met.__name__, met(outputs, targets))

TruongKhang commented 4 years ago

@sunlightsgy, I get your point. Assume we have a dataset of size N, each image has size M x L, and the batch size is B. If N % B == 0, the two approaches give the same results, right? But that is not true if N % B != 0.

sunlightsgy commented 4 years ago

If the metric is accuracy, you are right. Other metrics may give wrong results whenever len(target) is not fixed.
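As a minimal illustration of why other metrics can break, the pure-Python sketch below uses a hypothetical binary precision helper (plain lists stand in for tensors). Precision divides by the number of predicted positives, not by len(target), so neither the plain mean of per-batch values nor the mean weighted by batch size matches the value computed once on the concatenated predictions, which is what the validation loop posted earlier in this thread does.

```python
def precision(pred, target):
    # Hypothetical binary precision: true positives / predicted positives
    tp = sum(1 for p, t in zip(pred, target) if p == 1 and t == 1)
    predicted_pos = sum(1 for p in pred if p == 1)
    return tp / predicted_pos if predicted_pos else 0.0


# Two batches of different sizes, as (pred, target) pairs
batches = [
    ([1, 1, 0], [1, 0, 0]),  # TP=1, predicted positives=2 -> 0.5
    ([1], [1]),              # TP=1, predicted positives=1 -> 1.0
]

per_batch = [(precision(p, t), len(t)) for p, t in batches]
unweighted = sum(v for v, _ in per_batch) / len(per_batch)
weighted = sum(v * n for v, n in per_batch) / sum(n for _, n in per_batch)

# Concatenate first, then compute the metric once over all samples
all_pred = [x for p, _ in batches for x in p]
all_target = [x for _, t in batches for x in t]
global_precision = precision(all_pred, all_target)

print(unweighted, weighted, global_precision)  # 0.75 0.625 0.6666666666666666
```

Accuracy is the special case where the per-batch denominator is exactly len(target), which is why weighting by batch size is enough for it but not for metrics like precision.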