Nanne / pytorch-NetVlad

Pytorch implementation of NetVlad including training on Pittsburgh.

sqrt() of the margin in Triplet loss #57

Closed: dinarkino closed this issue 3 years ago

dinarkino commented 3 years ago

Thank you for the work! Could you please clarify the use of sqrt() on the margin in the triplet loss? Why do you do that, and do we need the sqrt there?

# original paper/code doesn't sqrt() the distances, we do, so sqrt() the margin, I think :D
criterion = nn.TripletMarginLoss(margin=opt.margin**0.5, 
           p=2, reduction='sum').to(device)
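For reference, here is a rough pure-Python sketch of what the criterion above computes. It assumes one negative per triplet and ignores the small `eps` that PyTorch adds inside its pairwise distance. The loss is the hinge on plain (non-squared) L2 distances, summed over triplets as with `reduction='sum'`. The helper names `l2` and `triplet_margin_sum` are hypothetical and not part of pytorch-NetVlad or PyTorch.

```python
import math

def l2(x, y):
    """Plain (non-squared) Euclidean distance between two equal-length vectors."""
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(x, y)))

def triplet_margin_sum(anchors, positives, negatives, margin):
    """Sum over triplets of max(d(a, p) - d(a, n) + margin, 0) with p=2 distance.

    Intended to mirror nn.TripletMarginLoss(margin=..., p=2, reduction='sum'),
    except for the eps term PyTorch adds inside its distance computation.
    """
    total = 0.0
    for a, p, n in zip(anchors, positives, negatives):
        total += max(l2(a, p) - l2(a, n) + margin, 0.0)
    return total

# Example: the first triplet is already separated by more than the margin
# (contributes 0); the second has positive and negative swapped.
loss = triplet_margin_sum(
    anchors=[[0.0, 0.0], [0.0, 0.0]],
    positives=[[0.3, 0.4], [0.6, 0.8]],
    negatives=[[0.6, 0.8], [0.3, 0.4]],
    margin=0.1 ** 0.5,  # sqrt() of the margin, as in the snippet above
)
print(loss)
```

Only the second triplet contributes here: its loss is 1.0 - 0.5 + sqrt(0.1).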
Nanne commented 3 years ago

I guess the comment is a bit vague. I think this has to do with the original code using the squared L2 distance, whereas this code base uses the plain L2 distance. So instead of using the same margin, I use the square root of that margin so that it lines up with the change in distance function.
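A small worked comparison may help show what "lines up" means here. With squared distances, the hinge max(d_p² - d_n² + m, 0) is zero once d_n ≥ sqrt(d_p² + m). With plain distances and margin sqrt(m), it is zero once d_n ≥ d_p + sqrt(m). Squaring the second condition gives d_n² ≥ d_p² + 2·d_p·sqrt(m) + m. The two boundaries therefore coincide only when d_p = 0, and the sqrt-margin version is stricter otherwise. The sketch below prints both boundaries. The margin value 0.1 is only an assumed example, and the function names are hypothetical.

```python
import math

def min_negative_distance_squared(d_pos, margin):
    """Smallest d_neg giving zero loss for max(d_pos**2 - d_neg**2 + margin, 0)."""
    return math.sqrt(d_pos ** 2 + margin)

def min_negative_distance_sqrt_margin(d_pos, margin):
    """Smallest d_neg giving zero loss for max(d_pos - d_neg + sqrt(margin), 0)."""
    return d_pos + math.sqrt(margin)

m = 0.1  # example margin value, assumed for illustration
for d_pos in (0.0, 0.25, 0.5, 1.0):
    print(f"d_pos={d_pos:.2f}  "
          f"squared: {min_negative_distance_squared(d_pos, m):.4f}  "
          f"sqrt-margin: {min_negative_distance_sqrt_margin(d_pos, m):.4f}")
```

At d_pos = 0 both columns print 0.3162. As d_pos grows, the sqrt-margin boundary moves further out than the squared-distance one, so the rescaling is an approximation rather than an exact equivalence.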

It probably shouldn't be hardcoded to take the sqrt, but I don't think it's a major factor; it's worth experimenting with nonetheless.
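One way to avoid hardcoding the square root is to choose the margin transform from the command line. The sketch below does this with a hypothetical `--noSqrtMargin` flag, which does not exist in pytorch-NetVlad. The `--margin` default of 0.1 is also an assumption for this example.

```python
import argparse

parser = argparse.ArgumentParser(description="triplet margin option sketch")
parser.add_argument('--margin', type=float, default=0.1,
                    help='triplet margin as specified for squared L2 distances (example default)')
# Hypothetical flag: pass the margin through unchanged instead of taking its square root.
parser.add_argument('--noSqrtMargin', action='store_true',
                    help='do not sqrt() the margin before passing it to the loss')

def effective_margin(opt):
    """Margin to hand to the triplet loss, given the parsed options."""
    return opt.margin if opt.noSqrtMargin else opt.margin ** 0.5

# Parse an explicit argument list so the example does not depend on sys.argv.
opt = parser.parse_args(['--margin', '0.1'])
print(effective_margin(opt))
```

The criterion line would then pass `margin=effective_margin(opt)` instead of `margin=opt.margin**0.5`, so both settings can be compared in separate runs.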

dinarkino commented 3 years ago

OK, I see. Thank you for the answer!