blue-blue272 / fewshot-CAN

Code of Cross Attention Network for Few-shot Classification (NeurIPS 2019).

An inquiry about the loss function in train.py lines 123-125 #15

Open JerryHao2001 opened 1 year ago

JerryHao2001 commented 1 year ago

Hello! While reading the paper associated with this code, I noticed that in the paper (page 5, Eqs. 5 and 6) L1 is the "nearest neighbor classification loss" and L2 is the "global classification loss", with the total loss L = λ L1 + L2. In train.py lines 123-125, however, the two losses seem to be mixed up:

123 loss1 = criterion(ytest, pids.view(-1))
124 loss2 = criterion(cls_scores, labels_test.view(-1))
125 loss = loss1 + 0.5 * loss2

I'm wondering why loss1, which I took to be the nearest neighbor classification loss, is computed from ytest and pids; that looks like the global classification loss to me. Thank you for the help!
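
To make the comparison concrete, here is a minimal sketch in plain Python. It is not the repository's code. It assumes (my reading only, not confirmed by the authors) that `ytest` holds scores over all training classes, so `loss1` would be the global loss, and that `cls_scores` holds scores over the classes of one episode, so `loss2` would be the nearest neighbor loss. The toy logits and the helper names `cross_entropy`, `paper_total` and `code_total` are hypothetical. Under that assumption, line 125 would equal the paper's L = λ L1 + L2 with λ = 0.5 and only the variable names swapped.

```python
import math


def cross_entropy(logits, target):
    """Cross-entropy for a single sample: -log(softmax(logits)[target])."""
    m = max(logits)  # subtract the max for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_sum_exp - logits[target]


# Hypothetical toy values, not taken from the repository.
global_logits = [2.0, 0.5, -1.0, 0.1]  # scores over all training classes
global_label = 0                       # global class id (the role I assume for `pids`)
episode_logits = [0.3, 1.2]            # scores over the classes in one episode
episode_label = 1                      # episode-local label (assumed role of `labels_test`)

loss_global = cross_entropy(global_logits, global_label)
loss_nn = cross_entropy(episode_logits, episode_label)


def paper_total(l1_nn, l2_global, lam):
    """Total loss as written in the paper: L = lambda * L1 + L2."""
    return lam * l1_nn + l2_global


def code_total(loss1, loss2):
    """Total loss as written on line 125 of train.py."""
    return loss1 + 0.5 * loss2


# Assumption: loss1 plays the role of the global loss, loss2 the nearest neighbor loss.
paper = paper_total(loss_nn, loss_global, lam=0.5)
code = code_total(loss1=loss_global, loss2=loss_nn)
print(abs(paper - code) < 1e-12)
```

If the assumption about `ytest` and `cls_scores` is wrong, the two expressions weight the losses differently and the mismatch with the paper would be real.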