Closed: SimpleXP closed this issue 3 months ago.
This part:
for i, data in enumerate(training_data):
    img = data["image"]
    img = img.to(torch.float).to(self.device)
    batch_mean = torch.mean(img, dim=0)
    if i == 0:
        self.c = batch_mean
    else:
        self.c += batch_mean
    # runs on every iteration, so the running sum is divided repeatedly
    self.c /= len(training_data)
should be:
for i, data in enumerate(training_data):
    img = data["image"]
    img = img.to(torch.float).to(self.device)
    batch_mean = torch.mean(img, dim=0)
    if i == 0:
        self.c = batch_mean
    else:
        self.c += batch_mean
# divide once, after all batch means have been summed
self.c /= len(training_data)
right?
Should
self.c /= len(training_data)
be at the same level as the for loop, i.e. run once after the loop has finished?
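A minimal sketch of why the placement matters, assuming plain Python floats stand in for the per-batch mean tensors (the function names and the sample values below are hypothetical, chosen only for illustration). With the division inside the loop, the running sum is shrunk on every iteration, so earlier batches are down-weighted and the result is not the average of the batch means. With the division after the loop, the sum is divided exactly once.

```python
def center_divide_inside_loop(batch_means):
    # Hypothetical reproduction of the first snippet: divides on every iteration.
    c = 0.0
    for i, m in enumerate(batch_means):
        if i == 0:
            c = m
        else:
            c += m
        c /= len(batch_means)
    return c


def center_divide_after_loop(batch_means):
    # Hypothetical reproduction of the second snippet: sum first, divide once.
    c = 0.0
    for i, m in enumerate(batch_means):
        if i == 0:
            c = m
        else:
            c += m
    c /= len(batch_means)
    return c


batch_means = [1.0, 2.0, 3.0]  # made-up per-batch means
print(center_divide_inside_loop(batch_means))
print(center_divide_after_loop(batch_means))
```

For these values the first function returns about 1.26 (34/27), while the second returns 2.0, the actual mean of the three batch means.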