bytedance / R2Former

Official repository for R2Former: Unified Retrieval and Reranking Transformer for Place Recognition
Apache License 2.0

About how the cross-entropy loss is calculated #2

Closed: LKELN closed this issue 1 year ago.

LKELN commented 1 year ago

Thank you for your excellent work. I have a question about the following code:

    if args.rerank_loss == 'ce':
        target = torch.zeros(rerank_out_pos.shape[0] * 2, dtype=torch.long).cuda()  # +num_pairs
        target[:rerank_out_pos.shape[0]] = 1
        rerank_loss = CE(torch.cat([rerank_out_pos, rerank_out_neg], dim=0), target)
        loss_triplet += rerank_loss  # , rerank_out_mix

Why do you calculate the cross-entropy loss for the positive samples over and over again?

Jeff-Zilence commented 1 year ago

The reranking loss is computed per pair, not per sample. Each pair, positive or negative, gets one reranking score, and its target in the cross-entropy loss is 1 for a positive pair and 0 for a negative pair.
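To make the pair-based view concrete, here is a small, dependency-free sketch of what the snippet in the question computes. It rests on several assumptions that the thread does not confirm: `CE` is PyTorch's `nn.CrossEntropyLoss` with its default mean reduction, and each row of `rerank_out_pos` / `rerank_out_neg` holds two logits, `[not_match, match]`, for one (query, candidate) pair. The function name `pairwise_rerank_ce` is hypothetical. Each pair contributes exactly one cross-entropy term. Positive pairs are stacked first and labeled 1, and negative pairs are labeled 0, just as `target[:rerank_out_pos.shape[0]] = 1` does.

```python
import math
from typing import Sequence


def pairwise_rerank_ce(pos_logits: Sequence[Sequence[float]],
                       neg_logits: Sequence[Sequence[float]]) -> float:
    """Mean 2-class cross-entropy over all (query, candidate) pairs.

    Hypothetical helper. Assumes each row is [logit_not_match, logit_match]
    for one pair, like rerank_out_pos / rerank_out_neg with shape [num_pairs, 2].
    """
    # Stack positive pairs first, then negative pairs (like torch.cat(..., dim=0)).
    logits = list(pos_logits) + list(neg_logits)
    # Target 1 ("match") for each positive pair, 0 for each negative pair.
    targets = [1] * len(pos_logits) + [0] * len(neg_logits)

    total = 0.0
    for row, t in zip(logits, targets):
        # Numerically stable log-sum-exp over the two class logits.
        m = max(row)
        log_z = m + math.log(sum(math.exp(x - m) for x in row))
        # -log softmax(row)[t]: one loss term per pair.
        total += log_z - row[t]
    # Mean over all pairs (assumed default mean reduction of nn.CrossEntropyLoss).
    return total / len(targets)


# Two positive and two negative pairs with uninformative logits:
# every pair costs log(2), so the mean is log(2).
print(round(pairwise_rerank_ce([[0.0, 0.0], [0.0, 0.0]],
                               [[0.0, 0.0], [0.0, 0.0]]), 4))  # 0.6931
```

Under these assumptions, a positive candidate's score enters the loss once for every pair it belongs to. It is not re-counted as a separate sample.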

LKELN commented 1 year ago

Thank you for your reply!