dataloader = DataLoader(dataset, batch_size=self.batch_size,
shuffle=True, drop_last=True,
num_workers=dataloader_workers)
with trange(train_epochs + test_epochs, disable=not verbose) as t:
for epoch in t:
for i, data in enumerate(dataloader):
optim.zero_grad()
generated_data = self.forward()
mmd = self.criterion(generated_data, data)
if not epoch % 200 and i == 0:
t.set_postfix(idx=idx, loss=mmd.item())
mmd.backward()
optim.step()
if epoch >= test_epochs:
self.score.add_(mmd.data)
return self.score.cpu().numpy() / test_epochs
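The condition `if epoch >= test_epochs:` should be changed to `if epoch >= train_epochs:` so that the returned score is the average over the test epochs only. The loop runs `train_epochs + test_epochs` epochs. With the current condition, losses from epoch `test_epochs` onward are accumulated, which is `train_epochs` epochs in total and includes training-phase epochs. The sum is nevertheless divided by `test_epochs`.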
Code:
Change
    if epoch >= test_epochs:
to
    if epoch >= train_epochs:
to calculate the average score over the test epochs.