earhian / Humpback-Whale-Identification-1st-

https://www.kaggle.com/c/humpback-whale-identification

triplet_loss problem #9

Open XuShoweR opened 5 years ago

XuShoweR commented 5 years ago

I want to use your triplet_loss function to train a classification model on ImageNet, and I get an error at line 70 of triplet_loss.py:

dist_ap, relative_p_inds = torch.max((dist_mat * is_pos.float()).contiguous().view(N, -1), 1,keepdim=True)

If my labels are [1, 1, 2, 3], then N = 4 and I get:

RuntimeError: shape '[4, -1]' is invalid for input of size 6

If my labels are [1, 1, 2, 2], it works. Is my understanding correct? Do I need to generate labels in this form ([1, 1, 2, 2])?
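An input of size 6 matches the number of same-label (anchor, positive) pairs for labels [1, 1, 2, 3] when each sample also counts as its own positive: 2 × 2 pairs for label 1, plus one each for labels 2 and 3. The quoted expression `dist_mat * is_pos.float()` would keep all N × N = 16 entries, so the failing code is most likely selecting positives with boolean indexing (`dist_mat[is_pos]`), as the AlignedReID loss does; that is an assumption. Reshaping such a selection to (N, -1) only works when every anchor has the same number of positives, which [1, 1, 2, 2] satisfies (8 = 4 × 2). Below is a minimal sketch of a masked alternative that avoids the reshape, assuming a precomputed N × N distance matrix; the function name `batch_hard_mining` is hypothetical and this is not the repository's code.

```python
import torch


def batch_hard_mining(dist_mat: torch.Tensor, labels: torch.Tensor):
    """Hypothetical batch-hard mining that tolerates unequal samples per label.

    Rather than reshaping dist_mat[is_pos] to (N, -1), which needs the same
    number of positives for every anchor, mask the N x N matrix and reduce
    along each row.
    """
    # is_pos[i, j] is True when samples i and j share a label (diagonal included).
    is_pos = labels.unsqueeze(0).eq(labels.unsqueeze(1))
    is_neg = ~is_pos
    # Hardest positive per anchor: largest same-label distance.
    dist_ap, _ = dist_mat.masked_fill(is_neg, float("-inf")).max(dim=1)
    # Hardest negative per anchor: smallest different-label distance
    # (+inf if the batch contains no other label).
    dist_an, _ = dist_mat.masked_fill(is_pos, float("inf")).min(dim=1)
    return dist_ap, dist_an


if __name__ == "__main__":
    labels = torch.tensor([1, 1, 2, 3])
    feats = torch.randn(4, 8)
    dist_mat = torch.cdist(feats, feats)  # (4, 4) Euclidean distances

    is_pos = labels.unsqueeze(0).eq(labels.unsqueeze(1))
    print(dist_mat[is_pos].numel())  # 6: not divisible by N = 4, hence the view error

    dist_ap, dist_an = batch_hard_mining(dist_mat, labels)
    print(dist_ap.shape, dist_an.shape)
```

Note that anchors whose label occurs only once in the batch (labels 2 and 3 here) get their self-distance as the hardest positive, which is typically 0 and gives an uninformative triplet, so batches with several samples per label are still preferable.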

earhian commented 5 years ago

Could you print dist_mat.shape and is_pos.shape? By the way, you can refer to https://github.com/huanghoujing/AlignedReID-Re-Production-Pytorch/blob/master/aligned_reid/model/loss.py.
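The batch-hard mining in the linked AlignedReID loss reshapes the positive distances the same way, so it also expects every label in a batch to appear the same number of times. A common way to build such batches is P × K sampling: pick P labels, then K samples of each. The pure-Python sketch below is an assumption, not code from either repository; `pk_batches`, `p`, `k`, and `seed` are hypothetical names.

```python
import random
from collections import defaultdict
from typing import Dict, Iterator, List, Sequence


def pk_batches(labels: Sequence[int], p: int, k: int, seed: int = 0) -> Iterator[List[int]]:
    """Hypothetical P x K sampler: yield index batches of p labels x k samples.

    Labels with fewer than k samples are drawn with replacement, so every
    label in a batch contributes exactly k indices. Use a different seed
    each epoch to get different batches.
    """
    rng = random.Random(seed)
    by_label: Dict[int, List[int]] = defaultdict(list)
    for idx, lab in enumerate(labels):
        by_label[lab].append(idx)

    label_ids = list(by_label)
    rng.shuffle(label_ids)
    # Walk the shuffled labels in chunks of p; an incomplete final chunk is dropped.
    for start in range(0, len(label_ids) - p + 1, p):
        batch: List[int] = []
        for lab in label_ids[start:start + p]:
            pool = by_label[lab]
            if len(pool) >= k:
                batch.extend(rng.sample(pool, k))  # k distinct samples
            else:
                batch.extend(rng.choices(pool, k=k))  # with replacement
        yield batch


if __name__ == "__main__":
    labels = [1, 1, 1, 2, 2, 3, 3, 3, 4]
    for batch in pk_batches(labels, p=2, k=2):
        print([labels[i] for i in batch])
```

With k = 2, each batch's labels come out grouped like [a, a, b, b], the layout that works in the question. Since a PyTorch DataLoader accepts any iterable of index lists as its batch_sampler, one option is to pass a fresh pk_batches(...) generator each epoch.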