ActiveVisionLab / DFNet

DFNet: Enhance Absolute Pose Regression with Direct Feature Matching (ECCV 2022)
https://dfnet.active.vision
MIT License

Question about Loss_triplet #12

Closed shenyehui closed 9 months ago

shenyehui commented 9 months ago

Thank you for your great work. The Loss_triplet function I found in the code seems to differ from the one you described in the paper. Have you made any modifications to the loss in your code? Below is the Loss_triplet function from your code, and I look forward to your explanation!

```python
def triplet_loss_hard_negative_mining_plus(f1, f2, margin=1.):
    '''
    triplet loss with hard negative mining, four cases.
    inspired by http://www.bmva.org/bmvc/2016/papers/paper119/paper119.pdf section 3.3
    :param f1: [lvl, B, C, H, W]
    :param f2: [lvl, B, C, H, W]
    :return: loss
    '''
    criterion = nn.TripletMarginLoss(margin=margin, reduction='mean')
    anchor = f1
    anchor_negative = torch.roll(f1, shifts=1, dims=1)
    positive = f2
    negative = torch.roll(f2, shifts=1, dims=1)

    # select in-triplet hard negative, reference: section 3.3
    mse = nn.MSELoss(reduction='mean')
    with torch.no_grad():
        case1 = mse(anchor, negative)
        case2 = mse(positive, anchor_negative)
        case3 = mse(anchor, anchor_negative)
        case4 = mse(positive, negative)
        distance_list = torch.stack([case1, case2, case3, case4])
        loss_case = torch.argmin(distance_list)

    # perform anchor swap if necessary
    if loss_case == 0:
        loss = criterion(anchor, positive, negative)
    elif loss_case == 1:
        loss = criterion(positive, anchor, anchor_negative)
    elif loss_case == 2:
        loss = criterion(anchor, positive, anchor_negative)
    elif loss_case == 3:
        loss = criterion(positive, anchor, negative)
    else:
        raise NotImplementedError
    return loss
```
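For context (a self-contained sketch with made-up values, not code from the repo): `nn.TripletMarginLoss` computes `mean(max(d(a, p) - d(a, n) + margin, 0))` over the batch, so the four cases above only change which tensors play the anchor, positive, and negative roles.

```python
import torch
import torch.nn as nn

criterion = nn.TripletMarginLoss(margin=1.0, reduction='mean')

a = torch.zeros(2, 8)        # anchor
p = torch.zeros(2, 8)        # positive identical to the anchor: d(a, p) ~ 0
n = torch.ones(2, 8) * 2.0   # negative far from the anchor

# d(a, n) = sqrt(8 * 2^2) ~ 5.66, well past margin + d(a, p),
# so the hinge is inactive and the loss is zero.
loss = criterion(a, p, n)
print(loss.item())  # 0.0
```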
chenusc11 commented 9 months ago

Hi, thank you for your interest in our work!

Unfortunately, I'm not sure I understood your question very well. Could you elaborate on your question a bit? What is the difference between this and what we described in the paper (Eq. 3 and 4)?

shenyehui commented 9 months ago

> Hi, thank you for your interest in our work!
>
> Unfortunately, I'm not sure I understood your question very well. Could you elaborate on your question a bit? What is the difference between this and what we described in the paper (Eq. 3 and 4)?

I'm sorry for not making my question clear. What I want to ask is: in the paper, the inputs to L_triplet are M^p_real, M^p_syn, M^(p^-)_real, and M^(p^-)_syn. But in the code, `loss_f = triplet_loss_hard_negative_mining_plus(features_rgb, features_target, margin=args.triplet_margin)` passes only M^p_real and M^p_syn. Why is that?

chenusc11 commented 9 months ago
  1. The negative samples are sampled by torch.roll() here. L410 and L412.

  2. Then, in L417 to 422, we find negative pairs that have negative distance, referring to Equation 4.

  3. Lastly, we select specific negative pairs with minimum negative distance into the triplet loss for that iteration.

Notice that the loss is designed to minimize the feature-metric distance between synthetic and real images with the same camera poses and maximize the distance between images with different camera poses.

https://github.com/ActiveVisionLab/DFNet/blob/1389760f770851a77e601af1312f19fe065bd185/script/feature/misc.py#L409C3-L409C3
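(To illustrate point 1 with a minimal self-contained sketch, assuming the `[lvl, B, C, H, W]` feature shape from the loss function: `torch.roll` along the batch dimension shifts each sample to the next index, so the features at batch index `i` are paired with those at index `i-1`, i.e. an image from a different camera pose within the same batch.)

```python
import torch

lvl, B, C, H, W = 1, 4, 2, 3, 3
# Give every batch element a constant feature equal to its batch index,
# so the effect of the roll is easy to read off.
f1 = torch.arange(B, dtype=torch.float32).view(1, B, 1, 1, 1).expand(lvl, B, C, H, W)

# Shift by one along the batch dimension: order [0, 1, 2, 3] -> [3, 0, 1, 2].
neg = torch.roll(f1, shifts=1, dims=1)

print(neg[0, :, 0, 0, 0])  # tensor([3., 0., 1., 2.])
```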

shenyehui commented 9 months ago
> 1. The negative samples are sampled by torch.roll() here. L410 and L412.
> 2. Then, in L417 to 422, we find negative pairs that have negative distance, referring to Equation 4.
> 3. Lastly, we select specific negative pairs with minimum negative distance into the triplet loss for that iteration.
>
> Notice that the loss is designed to minimize the feature-metric distance between synthetic and real images with the same camera poses and maximize the distance between images with different camera poses.
>
> https://github.com/ActiveVisionLab/DFNet/blob/1389760f770851a77e601af1312f19fe065bd185/script/feature/misc.py#L409C3-L409C3

Thank you for your answer. What I would like to confirm is: in the code, the negative samples produced by torch.roll() stand in for images with different camera poses, rather than the model actually being given separate images captured at different poses, correct?

chenusc11 commented 9 months ago

Hi, yes, that's correct.

shenyehui commented 9 months ago

> Hi, yes, that's correct.

Thank you for your patient explanations. I will close this question.