KaiyangZhou / deep-person-reid

Torchreid: Deep learning person re-identification in PyTorch.
https://kaiyangzhou.github.io/deep-person-reid/
MIT License

how to use TripletLoss_WRT #358

szxczyc opened this issue 4 years ago (status: open)

szxczyc commented 4 years ago

Hi Kaiyang, I want to reproduce the AGW method, and I found that the Weighted Regularized Triplet loss (TripletLoss_WRT) returns two values, loss and correct. Should I write a new image engine following the Torchreid guide?

```python
import torch
import torch.nn as nn


class TripletLoss_WRT(nn.Module):
    """Weighted Regularized Triplet."""

        # ......

        # Keep only anchor-positive and anchor-negative distances
        dist_ap = dist_mat * is_pos
        dist_an = dist_mat * is_neg

        # Softmax weights: larger positive distances and smaller negative
        # distances receive more weight
        weights_ap = softmax_weights(dist_ap, is_pos)
        weights_an = softmax_weights(-dist_an, is_neg)
        furthest_positive = torch.sum(dist_ap * weights_ap, dim=1)
        closest_negative = torch.sum(dist_an * weights_an, dim=1)

        # Target of 1: closest_negative should exceed furthest_positive
        y = torch.ones_like(furthest_positive)
        loss = self.ranking_loss(closest_negative - furthest_positive, y)

        # compute accuracy
        correct = torch.ge(closest_negative, furthest_positive).sum().item()
        return loss, correct
```

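The snippet relies on a softmax_weights helper and on dist_mat, is_pos and is_neg, which are computed in the elided part. Below is a minimal, self-contained sketch of how those pieces could fit together, assuming Euclidean distances between feature vectors and float masks built from identity labels. The helper body, the class name TripletLossWRTSketch and the distance computation are assumptions for illustration, not the AGW authors' exact code. For each anchor, positive distances are summed with weights proportional to exp(distance) and negative distances with weights proportional to exp(-distance), so hard positives and hard negatives dominate without a fixed margin; SoftMarginLoss with target 1 then averages log(1 + exp(-(closest_negative - furthest_positive))).

```python
import torch
import torch.nn as nn


def softmax_weights(dist, mask):
    """Softmax over the entries selected by `mask` (assumed helper, not the original).

    Subtracting the row max keeps exp() from overflowing; the small constant
    added to the normaliser avoids division by zero when a row has an empty mask.
    """
    max_v = torch.max(dist * mask, dim=1, keepdim=True)[0]
    diff = dist - max_v
    z = torch.sum(torch.exp(diff) * mask, dim=1, keepdim=True) + 1e-6
    return torch.exp(diff) * mask / z


class TripletLossWRTSketch(nn.Module):
    """Hypothetical re-implementation of a weighted regularized triplet loss."""

    def __init__(self):
        super().__init__()
        # log(1 + exp(-y * x)), averaged over anchors
        self.ranking_loss = nn.SoftMarginLoss()

    def forward(self, features, pids):
        # Pairwise Euclidean distances; clamp before sqrt keeps gradients finite
        diff = features.unsqueeze(1) - features.unsqueeze(0)
        dist_mat = diff.pow(2).sum(dim=-1).clamp(min=1e-12).sqrt()

        # Float masks of shape [N, N]: same identity vs. different identity
        is_pos = pids.unsqueeze(0).eq(pids.unsqueeze(1)).float()
        is_neg = 1.0 - is_pos

        dist_ap = dist_mat * is_pos
        dist_an = dist_mat * is_neg

        weights_ap = softmax_weights(dist_ap, is_pos)
        weights_an = softmax_weights(-dist_an, is_neg)
        furthest_positive = torch.sum(dist_ap * weights_ap, dim=1)
        closest_negative = torch.sum(dist_an * weights_an, dim=1)

        y = torch.ones_like(furthest_positive)
        loss = self.ranking_loss(closest_negative - furthest_positive, y)

        # Number of anchors whose weighted negative is at least as far as
        # the weighted positive (a Python int, not a tensor)
        correct = torch.ge(closest_negative, furthest_positive).sum().item()
        return loss, correct
```

Calling the module returns a (loss, correct) tuple rather than a single tensor, which is why an engine written for a criterion that returns only a loss needs adjusting.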
KaiyangZhou commented 4 years ago

You can follow the existing engine code (e.g. this) to write your own engine. Simply inherit from the Engine class so you don't need to rewrite everything.
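A rough sketch of what such a subclass might change: in recent Torchreid versions, engine subclasses override forward_backward(self, data) and return a dict of scalars for logging (check the Engine source of your installed version to confirm). Because TripletLoss_WRT returns (loss, correct) instead of a single tensor, the essential change is unpacking that tuple before combining losses. The function below is a hypothetical stand-in for that method body, written as a plain PyTorch function so it runs without Torchreid installed. The name wrt_training_step, the weight_t/weight_x arguments, the acc_t log key, and the assumption that the model returns (logits, features) in training mode are all illustrative assumptions.

```python
import torch
import torch.nn as nn


def wrt_training_step(model, criterion_t, criterion_x, optimizer, imgs, pids,
                      weight_t=1.0, weight_x=1.0):
    """One optimisation step with a triplet criterion returning (loss, correct).

    Hypothetical helper: in an Engine subclass this body would live inside the
    overridden training method. Assumes model(imgs) returns (logits, features).
    """
    logits, features = model(imgs)

    # TripletLoss_WRT returns a tuple, so unpack it instead of using it directly
    loss_t, correct = criterion_t(features, pids)
    loss_x = criterion_x(logits, pids)
    loss = weight_t * loss_t + weight_x * loss_x

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Plain Python numbers for logging; acc_t is the fraction of anchors
    # counted as correct by the triplet criterion
    return {
        "loss_t": loss_t.item(),
        "loss_x": loss_x.item(),
        "acc_t": correct / pids.size(0),
    }


if __name__ == "__main__":
    torch.manual_seed(0)

    class ToyModel(nn.Module):
        """Stand-in for a ReID backbone: returns (logits, features)."""

        def __init__(self, in_dim=16, num_classes=4):
            super().__init__()
            self.embed = nn.Linear(in_dim, 8)
            self.classifier = nn.Linear(8, num_classes)

        def forward(self, x):
            feats = self.embed(x)
            return self.classifier(feats), feats

    def toy_triplet(features, pids):
        # Stand-in for TripletLoss_WRT: returns (loss, correct)
        return features.pow(2).mean(), 0

    model = ToyModel()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    imgs = torch.randn(8, 16)
    pids = torch.tensor([0, 0, 1, 1, 2, 2, 3, 3])
    print(wrt_training_step(model, toy_triplet, nn.CrossEntropyLoss(),
                            optimizer, imgs, pids))
```

Keeping the tuple handling in the engine, rather than changing the loss class, leaves the criterion identical to the AGW code while the rest of the inherited Engine (data loading, evaluation, checkpointing) is reused unchanged.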