The previous PR, which tried to fix the "train on validation data" issue, caused `Trainer.forward()` to stop returning the loss during validation:
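```python
class Trainer(nn.Module):
    def forward(self, imgs, **inputs):
        # ...
        if not self.training:
            # Eval mode: only the model outputs are returned,
            # so no loss is available when validating.
            return self.model(imgs, **inps)
        else:
            res = self.model(imgs, **inps)
            if type(res) != list and type(res) != tuple:
                res = [res]
            # Train mode: model outputs followed by the loss terms.
            return list(res) + list(self.calc_loss(*res, **labels))
```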
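A minimal sketch of one possible fix, under stated assumptions: compute the loss in both modes and leave "don't train on validation data" to the caller, which sets `eval()`, wraps validation in `torch.no_grad()`, and never steps the optimizer there. The constructor, `loss_fn`, the `labels` dict argument, and this `calc_loss` body are hypothetical stand-ins, not the repository's actual API.

```python
import torch
from torch import nn


class Trainer(nn.Module):
    """Hypothetical minimal Trainer: wraps a model and a loss function."""

    def __init__(self, model, loss_fn):
        super().__init__()
        self.model = model
        self.loss_fn = loss_fn

    def calc_loss(self, logits, target):
        # Assumed: a single loss term, returned as a tuple so it can be
        # concatenated after the model outputs.
        return (self.loss_fn(logits, target),)

    def forward(self, imgs, labels, **inps):
        # `labels` is assumed to be a dict of keyword arguments for calc_loss.
        res = self.model(imgs, **inps)
        if not isinstance(res, (list, tuple)):
            res = [res]
        # Compute the loss regardless of self.training, so validation
        # returns it too; eval() still switches dropout/BatchNorm behaviour.
        return list(res) + list(self.calc_loss(*res, **labels))


# Validation-style usage: eval mode, no gradient tracking, no optimizer step.
trainer = Trainer(nn.Linear(4, 3), nn.CrossEntropyLoss())
trainer.eval()
with torch.no_grad():
    out = trainer(torch.randn(2, 4), {"target": torch.tensor([0, 2])})
logits, loss = out
```

With this shape, training and validation share one code path, and the only difference is what the caller does with the returned loss.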