TruongKhang opened this issue 4 years ago
You can also put `.item()` inside your metric function. Using `n=1` yields a wrong average metric only when the batch size is not constant; if it is constant, `n` just cancels out during the calculation. In the example MNIST code, the average metric computation will be "slightly" wrong because the last batch of each epoch has a different size than the other batches. Your fix on this part is correct, as long as `batch_size` comes from the current data, like `batch_size=output.shape[0]`, not from the global configuration, like `batch_size=data_loader.batch_size`.
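For anyone following along, here is a minimal sketch in the spirit of the tracker being discussed (the class name and details are mine, not the template's); it shows why `n` has to be the size of the current batch rather than the configured batch size:

```python
class RunningAverage:
    """Simplified stand-in for a MetricTracker, tracking a single key."""

    def __init__(self):
        self.total = 0.0   # sum of value * n over all updates
        self.count = 0     # total number of samples seen so far

    def update(self, value, n=1):
        # `value` is assumed to be a per-batch mean; weighting it by n
        # (the number of samples it was averaged over) recovers the
        # per-sample sum, so the final average stays exact even when the
        # last batch is smaller than the others.
        self.total += value * n
        self.count += n

    @property
    def avg(self):
        return self.total / self.count


# Hypothetical epoch: two full batches of 128 samples and a last batch of 44.
tracker = RunningAverage()
for batch_acc, batch_len in [(0.90, 128), (0.80, 128), (0.50, 44)]:
    tracker.update(batch_acc, n=batch_len)   # n taken from the data actually in the batch
print(tracker.avg)   # ~0.799, the true accuracy over all 300 samples

wrong = RunningAverage()
for batch_acc, _ in [(0.90, 128), (0.80, 128), (0.50, 44)]:
    wrong.update(batch_acc, n=1)             # n=1 over-weights the small last batch
print(wrong.avg)     # ~0.733
```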
Yes, I agree.
I think it should be:
```python
outputs, targets = [], []
with torch.no_grad():
    for batch_idx, (data, target) in enumerate(self.valid_data_loader):
        data, target = data.to(self.device), target.to(self.device)
        output = self.model(data)
        pred = torch.argmax(output, dim=1)
        loss = self.criterion(output, target)

        # accumulate this batch's outputs/targets for epoch-level metrics
        outputs.append(output.detach())
        targets.append(target.detach())

        self.writer.set_step((epoch - 1) * len(self.valid_data_loader) + batch_idx, 'valid')
        self.valid_metrics.update('loss', loss.item())
        self.writer.add_image('input', make_grid(data.cpu(), nrow=8, normalize=True))

# compute each metric once over the whole validation set,
# so unequal batch sizes cannot skew the average
outputs = torch.cat(outputs, dim=0)
targets = torch.cat(targets, dim=0)
for met in self.metric_ftns:
    self.valid_metrics.update(met.__name__, met(outputs, targets))
```
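As a rough sketch of a metric function compatible with that epoch-level call (the exact signature and body are assumptions, not necessarily what the template ships):

```python
import torch

def accuracy(output, target):
    # Works the same whether it receives a single batch or the whole
    # epoch's concatenated outputs/targets, since it averages over dim 0.
    with torch.no_grad():
        pred = torch.argmax(output, dim=1)
        return (pred == target).float().mean().item()
```

One trade-off of this approach: appending `output.detach()` keeps every batch's logits on the GPU until the end of validation; appending `output.detach().cpu()` instead keeps GPU memory flat at the cost of a device-to-host copy per batch.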
@sunlightsgy, I get your point. Assume we have a dataset of size N, each image has size M x L, and the batch size is B. If N % B = 0, the results of the two approaches will be the same, right? But that is no longer true if N % B != 0.
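To put numbers on it (a made-up example): with N = 10 and B = 4, the epoch has batches of size 4, 4, and 2. If the per-batch accuracies are 1.0, 1.0, and 0.0, the true accuracy is 8/10 = 0.8, but the unweighted mean of the three batch values is 2/3 ≈ 0.667, and weighting every batch by the configured B = 4 gives the same 8/12 ≈ 0.667; only weighting by the actual sizes 4, 4, 2 recovers 0.8.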
If the metric is accuracy, what you said is right, but other metrics may fail whenever `len(target)` is not fixed across batches.
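A concrete illustration (made-up confusion counts, not from the template): for a metric like precision, even a size-weighted average of per-batch values does not reproduce the value computed over the whole validation set, which is another argument for accumulating outputs and computing metrics once per epoch.

```python
# Hypothetical per-batch counts for a binary problem; precision = TP / (TP + FP).
batches = [
    {"size": 4, "tp": 1, "fp": 1},   # batch precision = 0.50
    {"size": 2, "tp": 2, "fp": 0},   # batch precision = 1.00
]

# Size-weighted average of the per-batch precisions
weighted_avg = sum(
    b["tp"] / (b["tp"] + b["fp"]) * b["size"] for b in batches
) / sum(b["size"] for b in batches)

# Precision computed once over all predictions
tp = sum(b["tp"] for b in batches)
fp = sum(b["fp"] for b in batches)
global_precision = tp / (tp + fp)

print(weighted_avg)       # ~0.667
print(global_precision)   # 0.75 -> per-batch averaging does not recover this
```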
Hi, I used your template and I ran into some problems, as follows: `met(output, target)` should be changed to `met(output, target).item()`. If we don't, memory usage keeps increasing after each batch. I tested with PyTorch 0.4.1 and noticed this problem. I would change the `update` function like this:

What do you think about that?
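The snippet that originally accompanied this comment isn't reproduced above, so the following is only a sketch of the kind of change being described, reusing the `valid_metrics.update` call from the loop earlier in this thread:

```python
# inside the trainer's validation loop, for one batch
for met in self.metric_ftns:
    value = met(output, target)
    # If the metric returns a tensor, convert it to a plain Python number
    # before storing it; otherwise the stored value keeps the tensor (and,
    # if it requires grad, its whole computation graph) alive across batches.
    if torch.is_tensor(value):
        value = value.item()
    self.valid_metrics.update(met.__name__, value)
```

On recent PyTorch versions, wrapping the validation loop in `torch.no_grad()` (as in the snippet earlier in the thread) already prevents graph retention, but `.item()` still avoids accumulating GPU tensors in the tracker.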