tjddus9597 / Proxy-Anchor-CVPR2020

Official PyTorch Implementation of Proxy Anchor Loss for Deep Metric Learning, CVPR 2020
MIT License

Inference Get Class Prediction #20

Closed oggyfaker closed 3 years ago

oggyfaker commented 3 years ago

Hi author! Thanks for the nice repo! I am a newbie in metric learning, so I am quite confused about how to predict the ID (class) of an image in the CUB dataset. Reading evaluate.py, in the evaluate_cos function I don't understand why the embeddings (after L2 norm) are combined via F.linear(X,X) and then used together with the ground truth (target). It looks like you use the ground truth to evaluate the model, not the predictions from the model!

Could you add an inference demo file, or point out directly how to resolve the issue in this comment? Forgive my "dummy question" and I hope to see your answer soon!

kdwonn commented 3 years ago

Hello! Thank you for your interest in our work.
Though we don't have any demo file that shows the evaluation protocol, I think I can give you a brief explanation of the code.

  1. cos_sim = F.linear(X,X) : The internal implementation of torch.nn.functional.linear(a, b) is the same as a @ b.t(), where @ is the matrix-multiplication operator. Since the rows of X are L2-normalized, cos_sim is the tensor that contains every pairwise cosine similarity between the embeddings in X.

  2. cos_sim.topk(1 + K)[1][:,1:] : As we already know, the main objective of metric learning is to learn an embedding space that encodes samples from the same class as nearby embedding vectors. Since PA uses an L2-normalized space, we use cosine similarity to find the nearest neighbors in the dataset. cos_sim.topk(1 + K)[1] returns the indices of the top-(1 + K) elements, and we drop the first column through [:,1:] since the most similar element will always be the query itself.

  3. Y=T[cos_sim.topk(1 + K)[1][:,1:]] : So, at last, Y contains the labels of the K nearest neighbors in embedding space for each query. In calc_recall_at_k, the code checks whether a sample from the same class as the query exists among the top-K neighbors, which is the definition of the recall@K measure.

  4. The evaluation code has to use the ground truth: it checks the GT labels to determine whether the query and the retrieved embedding vectors come from the same class.
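The four steps above can be sketched end to end on toy data. This is not the repo's code; the embeddings X and labels T are made up for illustration (in evaluate.py they come from the trained model and the CUB loader):

```python
import torch
import torch.nn.functional as F

# Toy stand-ins for test-set embeddings and ground-truth labels.
X = F.normalize(torch.tensor([[1.0, 0.0],
                              [0.9, 0.1],
                              [0.0, 1.0],
                              [0.1, 0.9]]), dim=1)  # L2-normalized embeddings
T = torch.tensor([0, 0, 1, 1])                      # ground-truth class labels

K = 1
cos_sim = F.linear(X, X)                 # same as X @ X.t(): pairwise cosine similarity
nn_idx = cos_sim.topk(1 + K)[1][:, 1:]   # indices of K nearest neighbors, query excluded
Y = T[nn_idx]                            # labels of those neighbors

# recall@K: fraction of queries whose K neighbors include the query's own class
recall_at_k = (Y == T.unsqueeze(1)).any(dim=1).float().mean().item()
print(recall_at_k)  # 1.0 on this toy data: each query's neighbor shares its class
```

Note that the labels T are only compared against, never fed to the model; the model's "prediction" is the neighborhood structure of the embeddings.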

I hope this helps. Please leave an additional comment if you have other questions. :smile:
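As for actually getting a class prediction for a new image (the original question), a common approach outside this repo is k-NN classification against a labeled gallery: embed the query with the trained model, then vote over the labels of its most similar gallery embeddings. A minimal sketch, assuming you already have gallery embeddings and labels (the helper name predict_class is made up, not part of the repo):

```python
import torch
import torch.nn.functional as F

def predict_class(query_emb, gallery_emb, gallery_labels, k=1):
    """Hypothetical helper (not in the repo): k-NN class prediction.

    query_emb:      (D,) or (N, D) embedding(s) from the trained model
    gallery_emb:    (M, D) embeddings of images with known labels
    gallery_labels: (M,) labels for the gallery images
    """
    q = F.normalize(query_emb.reshape(-1, query_emb.shape[-1]), dim=1)
    g = F.normalize(gallery_emb, dim=1)
    cos_sim = F.linear(q, g)            # (N, M) cosine similarities
    nn_idx = cos_sim.topk(k, dim=1)[1]  # top-k gallery indices per query
    # majority vote over the k neighbors' labels
    return torch.mode(gallery_labels[nn_idx], dim=1)[0]

# Toy usage: query is closest to the first gallery entry, so it gets label 7.
gallery = torch.tensor([[1.0, 0.0], [0.0, 1.0]])
labels = torch.tensor([7, 3])
query = torch.tensor([0.9, 0.1])
print(predict_class(query, gallery, labels))  # tensor([7])
```

Unlike the evaluation code above, the query is not part of the gallery here, so there is no self-match to exclude.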