Thank you for your excellent work. I have a question about this snippet:
```python
if args.rerank_loss == 'ce':
    target = torch.zeros(rerank_out_pos.shape[0] * 2, dtype=torch.long).cuda()  # +num_pairs
    target[:rerank_out_pos.shape[0]] = 1
    rerank_loss = CE(torch.cat([rerank_out_pos, rerank_out_neg], dim=0),
                     target)
    loss_triplet += rerank_loss  # , rerank_out_mix
```
Why is the cross-entropy loss computed on the positive samples repeatedly?
The reranking loss is computed over pairs, not individual samples. Each (positive, negative) pair produces one reranking score, which corresponds to the 0/1 target in the cross-entropy loss.
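A minimal sketch of that pairing, assuming `rerank_out_pos` and `rerank_out_neg` each hold one 2-class logit vector per pair (the shapes and `num_pairs` here are illustrative, not taken from the repository; `.cuda()` is omitted so it runs on CPU):

```python
import torch
import torch.nn as nn

num_pairs = 4  # hypothetical number of (positive, negative) pairs in the batch

# One 2-class logit vector per pair, not per sample.
rerank_out_pos = torch.randn(num_pairs, 2)  # scores for positive pairs
rerank_out_neg = torch.randn(num_pairs, 2)  # scores for negative pairs

CE = nn.CrossEntropyLoss()

# One target per pair: label 1 for the positive half, 0 for the negative half.
target = torch.zeros(num_pairs * 2, dtype=torch.long)
target[:num_pairs] = 1

# A single cross-entropy over all 2 * num_pairs pair scores.
rerank_loss = CE(torch.cat([rerank_out_pos, rerank_out_neg], dim=0), target)
```

So the positives are not scored "over and over": each positive pair contributes exactly one row of logits and one label-1 entry in `target`.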