BoyuanJiang / matching-networks-pytorch

Matching Networks for one shot learning
228 stars, 67 forks

Cosine distance calculation problem #9

Open ouc-lq opened 3 years ago

ouc-lq commented 3 years ago

In the source code, the author calculates the cosine distance as follows.

    sum_support = torch.sum(torch.pow(support_image, 2), 1) 
    support_manitude = sum_support.clamp(eps, float("inf")).rsqrt() 
    dot_product = input_image.unsqueeze(1).bmm(support_image.unsqueeze(2)).squeeze()
    cosine_similarity = dot_product * support_manitude
    similarities.append(cosine_similarity)
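Since `input_manitude` is never computed in this snippet, the only normalization it actually defines is by the support magnitude. The following is a minimal pure-Python sketch of that support-only normalization for one query vector `x` and one support vector `s`. It is not the repository's code, and the helper names are hypothetical. It shows that the result equals `||x|| * cos(x, s)` rather than `cos(x, s)`:

```python
import math

EPS = 1e-10  # plays the role of eps in the snippet: keeps the squared norm away from zero


def inv_norm(v):
    """1 / ||v||, with the squared norm clamped to at least EPS (like clamp(eps, inf).rsqrt())."""
    return 1.0 / math.sqrt(max(sum(a * a for a in v), EPS))


def dot(u, v):
    """Plain dot product of two equal-length vectors."""
    return sum(a * b for a, b in zip(u, v))


def support_normalized_score(x, s):
    """Dot product scaled by the support magnitude only, as in the snippet above."""
    return dot(x, s) * inv_norm(s)


x = [3.0, 4.0]  # ||x|| = 5
s = [1.0, 0.0]  # ||s|| = 1, cos(x, s) = 0.6
print(support_normalized_score(x, s))  # 3.0, i.e. ||x|| * cos(x, s) = 5 * 0.6
```

Because `||x||` is the same positive factor for every support item of a given query, this scaling does not change which support item scores highest. It does, however, change the magnitude of the scores fed to any subsequent softmax.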

But in my opinion, the cosine distance should be calculated as follows:

    sum_support = torch.sum(torch.pow(support_image, 2), 1) 
    support_manitude = sum_support.clamp(eps, float("inf")).rsqrt() 
    sum_input = torch.sum(torch.pow(input_image, 2), 1)
    input_manitude = sum_input.clamp(eps, float("inf")).rsqrt()
    dot_product = input_image.unsqueeze(1).bmm(support_image.unsqueeze(2)).squeeze()
    cosine_similarity = dot_product * support_manitude * input_manitude
    similarities.append(cosine_similarity)
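For comparison, here is a minimal pure-Python sketch of the proposed version for a single query vector `x` and support vector `s`. It is not the repository's code, and the helper names are hypothetical. It scales the dot product by both inverse magnitudes, as the snippet above does, so the result is the standard cosine similarity `dot(x, s) / (||x|| * ||s||)`. That value lies in [-1, 1] and does not change when `x` is rescaled:

```python
import math

EPS = 1e-10  # plays the role of eps in the snippet: keeps the squared norm away from zero


def inv_norm(v):
    """1 / ||v||, with the squared norm clamped to at least EPS (like clamp(eps, inf).rsqrt())."""
    return 1.0 / math.sqrt(max(sum(a * a for a in v), EPS))


def cosine_similarity(x, s):
    """dot(x, s) scaled by both inverse magnitudes, as in the proposed snippet."""
    return sum(a * b for a, b in zip(x, s)) * inv_norm(s) * inv_norm(x)


s = [1.0, 0.0]
print(round(cosine_similarity([3.0, 4.0], s), 6))  # 0.6
print(round(cosine_similarity([6.0, 8.0], s), 6))  # 0.6: doubling x leaves the result unchanged
```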

Am I right? If not, where is my mistake?

ouc-lq commented 3 years ago

@BoyuanJiang