Open liangbh6 opened 6 years ago
Hi, I tried to implement the batch-all triplet loss in this framework by modifying triplet.py, as follows: `from __future__ import absolute_import
import torch from torch import nn from torch.autograd import Variable
class Improved_TripletLoss(nn.Module):
    """Batch-all triplet loss over every valid (anchor, positive, negative) triple.

    For each anchor ``a``, every other sample ``p`` with the same label is a
    positive and every sample ``k`` with a different label is a negative. The
    hinge ``max(0, d(a, p) - d(a, k) + margin)`` is evaluated for all such
    triples and then averaged.

    Args:
        margin: Hinge margin, forwarded to ``nn.MarginRankingLoss``.
        num_instances: Samples per identity in a batch. Kept for interface
            compatibility only: the masked implementation below does not need
            the batch to be balanced.
        nonzero_average: If True, average only over triples whose hinge is
            positive (the "BA!=0" variant). If False (default), average over
            all valid triples. The plain mean lets the many already-satisfied,
            zero-loss triples dilute the gradient.
    """

    def __init__(self, margin=0, num_instances=4, nonzero_average=False):
        # The pasted code read ``def init`` / ``.init()``. Written that way,
        # nn.Module never runs it and no attribute below would be set.
        super(Improved_TripletLoss, self).__init__()
        self.margin = margin
        self.num_instances = num_instances
        self.nonzero_average = nonzero_average
        self.ranking_loss = nn.MarginRankingLoss(margin=margin)

    def forward(self, inputs, targets):
        """Compute the batch-all triplet loss for one mini-batch.

        Args:
            inputs: Embedding matrix with one row per sample (must be 2-D,
                ``(n, d)``, for the distance computation below).
            targets: One identity label per row, length ``n``.

        Returns:
            ``(loss, prec)``: the scalar loss, and the fraction of valid triples
            whose negative is farther from the anchor than the positive.
            Both are zero when the batch contains no valid triple.
        """
        n = inputs.size(0)

        # Pairwise Euclidean distances via ||x_i||^2 + ||x_j||^2 - 2 <x_i, x_j>.
        sq_norms = inputs.pow(2).sum(dim=1, keepdim=True).expand(n, n)
        dist = sq_norms + sq_norms.t()
        dist = dist.addmm(inputs, inputs.t(), beta=1, alpha=-2)
        # The clamp keeps sqrt finite (and its gradient finite) on the diagonal.
        dist = dist.clamp(min=1e-12).sqrt()

        same = targets.view(n, 1).eq(targets.view(1, n))
        not_self = ~torch.eye(n, dtype=torch.bool, device=same.device)
        is_pos = same & not_self  # (anchor, positive), excluding the anchor itself
        is_neg = ~same            # (anchor, negative)
        # valid[a, p, k]: p is a positive of anchor a and k is a negative of a.
        valid = is_pos.unsqueeze(2) & is_neg.unsqueeze(1)

        if not bool(valid.any()):
            # No positives or no negatives in this batch: there is nothing to
            # rank, so return a zero loss that still belongs to the graph.
            return inputs.sum() * 0.0, dist.new_zeros(())

        # Each valid triple appears exactly once here. The original loop
        # produced every triple twice and then divided by 2, but because the
        # loss is a mean the duplication already cancelled; that extra "/ 2"
        # only halved the loss and its gradient, so it is dropped.
        dist_ap = dist.unsqueeze(2).expand(n, n, n)[valid]
        dist_an = dist.unsqueeze(1).expand(n, n, n)[valid]

        if self.nonzero_average:
            hinge = (dist_ap - dist_an + self.margin).clamp(min=0)
            active = hinge > 0
            # hinge.sum() is 0 when nothing is active but keeps the graph intact.
            loss = hinge[active].mean() if bool(active.any()) else hinge.sum()
        else:
            y = torch.ones_like(dist_an)
            loss = self.ranking_loss(dist_an, dist_ap, y)

        prec = (dist_an > dist_ap).float().mean()
        return loss, prec
But I only get a rank-1 accuracy of about 60% on the Market-1501 dataset, which is worse than with the softmax loss. I can't figure out why. Has anyone else tried the same thing?
Hi, I tried to implement the batch-all triplet loss in this framework by modifying triplet.py, as follows: `from __future__ import absolute_import
import torch from torch import nn from torch.autograd import Variable
class Improved_TripletLoss(nn.Module):
    """Constructor of the batch-all triplet loss (this excerpt has no ``forward``).

    Args:
        margin: Hinge margin, forwarded to ``nn.MarginRankingLoss``.
        num_instances: Number of samples per identity expected in a batch.
    """

    def __init__(self, margin=0, num_instances=4):
        # The pasted code read ``def init`` / ``.init()``, presumably with the
        # double underscores lost in Markdown rendering. Written that way,
        # nn.Module never runs it and no attribute below would be set.
        super(Improved_TripletLoss, self).__init__()
        self.margin = margin
        self.num_instances = num_instances
        self.ranking_loss = nn.MarginRankingLoss(margin=margin)
But I only get a rank-1 accuracy of about 60% on the Market-1501 dataset, which is worse than with the softmax loss. I can't figure out why. Has anyone else tried the same thing?