princeton-vl / pose-ae-train

Training code for "Associative Embedding: End-to-End Learning for Joint Detection and Grouping"
BSD 3-Clause "New" or "Revised" License

fix bugs #20

Closed: Yishun99 closed this 1 year ago.

Yishun99 commented 6 years ago

The previous PR, which tried to fix "training on validation data", broke validation: `Trainer.forward()` no longer returns the loss when the model is evaluated in validation mode.

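import torch.nn as nn

class Trainer(nn.Module):
    def forward(self, imgs, **inputs):
        # ... (elided: defines `inps` and `labels`, used below)
        if not self.training:
            # Validation/inference path: returns predictions only, so no loss is reported.
            return self.model(imgs, **inps)
        else:
            res = self.model(imgs, **inps)
            if not isinstance(res, (list, tuple)):
                res = [res]
            return list(res) + list(self.calc_loss(*res, **labels))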
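The comment does not show the final fix. The following is only a hypothetical sketch of one way to keep the loss during validation: compute it whenever labels are passed, regardless of `self.training`, and return the raw model output only for pure inference. To stay runnable without PyTorch, a small stand-in class with a `training` flag and `train()`/`eval()` methods replaces `nn.Module`. The names `toy_model`, `toy_loss`, and the `target` label are illustrative assumptions, not part of the repository.

```python
class Trainer:
    """Hypothetical stand-in for the repo's nn.Module-based Trainer."""

    def __init__(self, model, inference_keys, calc_loss):
        self.model = model
        self.keys = set(inference_keys)  # keyword inputs forwarded to the model
        self.calc_loss = calc_loss
        self.training = True  # mirrors nn.Module.training

    def train(self, mode=True):
        self.training = mode
        return self

    def eval(self):
        return self.train(False)

    def forward(self, imgs, **inputs):
        # Split keyword inputs into model inputs and loss labels.
        inps = {k: v for k, v in inputs.items() if k in self.keys}
        labels = {k: v for k, v in inputs.items() if k not in self.keys}

        res = self.model(imgs, **inps)
        # Pure inference (eval mode, no labels): return the raw model output.
        if not self.training and not labels:
            return res
        if not isinstance(res, (list, tuple)):
            res = [res]
        # Training, or validation with labels: append the loss terms.
        return list(res) + list(self.calc_loss(*res, **labels))

    __call__ = forward


# Hypothetical toy model and loss, for illustration only.
def toy_model(imgs):
    return imgs * 2


def toy_loss(pred, target):
    return ((pred - target) ** 2,)


trainer = Trainer(toy_model, inference_keys=[], calc_loss=toy_loss)
print(trainer(3.0, target=5.0))  # training:   [6.0, 1.0]
trainer.eval()
print(trainer(3.0, target=5.0))  # validation: [6.0, 1.0]
print(trainer(3.0))              # inference:  6.0
```

Gating on the presence of labels, rather than on `self.training` alone, leaves inference behaviour unchanged while validation batches that carry labels get the same `[predictions..., loss...]` list as training.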