gitabcworld / MatchingNetworks

This repo provides PyTorch code that replicates the results of the Matching Networks for One Shot Learning paper on the Omniglot and MiniImageNet datasets.

Cosine distance calculation problem #8

Open ouc-lq opened 3 years ago

ouc-lq commented 3 years ago

In the source code, the author calculates the cosine similarity as follows:

        sum_support = torch.sum(torch.pow(support_image, 2), 1) 
        support_manitude = sum_support.clamp(eps, float("inf")).rsqrt() 
        dot_product = input_image.unsqueeze(1).bmm(support_image.unsqueeze(2)).squeeze()
        cosine_similarity = dot_product * support_manitude
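
When only the support embedding is normalized (the behaviour this question is about), each score equals the true cosine similarity multiplied by the query's norm ||input_image||. The minimal sketch below checks this, assuming embeddings of shape [batch, dim] and a support set of shape [n_support, batch, dim]; the helper name unnormalized_query_similarity, eps = 1e-10, and the random data are hypothetical. Because the extra factor is the same positive number for every support item of a given query, the best-matching support item does not change, but a softmax over the scores becomes sharper or flatter depending on the query's norm.

```python
import torch
import torch.nn.functional as F


def unnormalized_query_similarity(support_image, input_image, eps=1e-10):
    """Hypothetical helper: dot product scaled by 1/||support|| only (query NOT normalized)."""
    sum_support = torch.sum(torch.pow(support_image, 2), 1)
    support_magnitude = sum_support.clamp(eps, float("inf")).rsqrt()  # 1 / ||support||
    dot_product = input_image.unsqueeze(1).bmm(support_image.unsqueeze(2)).squeeze()
    return dot_product * support_magnitude


torch.manual_seed(0)
n_support, batch, dim = 5, 4, 64
input_image = torch.randn(batch, dim)
support_set = torch.randn(n_support, batch, dim)  # assumed [n_support, batch, dim] layout

# Scores for every support item, shape [n_support, batch]
scores = torch.stack([unnormalized_query_similarity(s, input_image) for s in support_set])
true_cos = torch.stack([F.cosine_similarity(input_image, s, dim=1) for s in support_set])

# Each column equals the true cosine times that query's L2 norm.
query_norm = input_image.norm(dim=1)
print(torch.allclose(scores, true_cos * query_norm, atol=1e-5))

# Scaling by a positive per-query constant keeps the per-query ranking of support items.
print(torch.equal(scores.argmax(dim=0), true_cos.argmax(dim=0)))
```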

But in my opinion, the cosine similarity should be calculated as follows, normalizing the input as well:

        sum_support = torch.sum(torch.pow(support_image, 2), 1) 
        support_manitude = sum_support.clamp(eps, float("inf")).rsqrt() 
        sum_input = torch.sum(torch.pow(input_image, 2), 1)
        input_manitude = sum_input.clamp(eps, float("inf")).rsqrt()
        dot_product = input_image.unsqueeze(1).bmm(support_image.unsqueeze(2)).squeeze()
        cosine_similarity = dot_product * support_manitude * input_manitude
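
The proposed fix can be wrapped into a small function and checked against PyTorch's built-in torch.nn.functional.cosine_similarity. This is a sketch under assumptions: the function name cosine_similarity_fixed and the [batch, dim] shapes are hypothetical, and eps = 1e-10 is an assumed value because eps is not defined in the snippet.

```python
import torch
import torch.nn.functional as F


def cosine_similarity_fixed(support_image, input_image, eps=1e-10):
    """Cosine similarity with BOTH vectors normalized, as proposed in the issue.

    support_image, input_image: tensors of shape [batch, dim] (assumed layout).
    Returns a tensor of shape [batch].
    """
    sum_support = torch.sum(torch.pow(support_image, 2), 1)
    support_magnitude = sum_support.clamp(eps, float("inf")).rsqrt()  # 1 / ||support||
    sum_input = torch.sum(torch.pow(input_image, 2), 1)
    input_magnitude = sum_input.clamp(eps, float("inf")).rsqrt()  # 1 / ||input||
    # Batched dot product: [batch, 1, dim] @ [batch, dim, 1] -> [batch, 1, 1] -> [batch]
    dot_product = input_image.unsqueeze(1).bmm(support_image.unsqueeze(2)).reshape(-1)
    return dot_product * support_magnitude * input_magnitude


torch.manual_seed(0)
support_image = torch.randn(3, 64)
input_image = torch.randn(3, 64)
ours = cosine_similarity_fixed(support_image, input_image)
reference = F.cosine_similarity(input_image, support_image, dim=1)
print(torch.allclose(ours, reference, atol=1e-5))
```

reshape(-1) is used instead of a bare squeeze() so that a batch of size 1 still yields a 1-D tensor of shape [1] rather than a 0-d scalar.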

Am I right? If not, what is my mistake?

ouc-lq commented 3 years ago

@gitabcworld @jacklanchantin

502dxceit commented 1 year ago

We could calculate the cosine similarity with the following, more succinct code: cosine_similarity = F.cosine_similarity(support_image, target_set), where F is torch.nn.functional. It divides by the norms of both arguments, computed along dim=1 by default.
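
As written, the built-in call uses the default dim=1, so with [batch, dim] inputs it compares one support tensor at a time. The sketch below is a vectorized variant over the whole support set, assuming (hypothetically) a support_set tensor of shape [n_support, batch, dim] and a query input_image of shape [batch, dim]; it cross-checks the result against an explicit normalize-then-dot-product form. The final softmax over the support axis corresponds to the attention kernel used in the Matching Networks paper.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
n_support, batch, dim = 5, 4, 64
support_set = torch.randn(n_support, batch, dim)  # assumed [n_support, batch, dim] layout
input_image = torch.randn(batch, dim)             # query embeddings

# One call for the whole support set: repeat the query along the support axis.
sims = F.cosine_similarity(input_image.unsqueeze(0).expand_as(support_set), support_set, dim=2)
# sims has shape [n_support, batch]; transpose if [batch, n_support] is needed downstream.

# Equivalent explicit form: L2-normalize both sides, then take dot products over dim.
sims_explicit = torch.einsum(
    "bd,nbd->nb", F.normalize(input_image, dim=1), F.normalize(support_set, dim=2)
)
print(torch.allclose(sims, sims_explicit, atol=1e-5))

# Softmax over the support items gives one attention distribution per query.
attention = sims.softmax(dim=0)
```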