ixaxaar / pytorch-dnc

Differentiable Neural Computers, Sparse Access Memory and Sparse Differentiable Neural Computers, for Pytorch
MIT License

bug in cosine distance? #56

Closed rfeinman closed 3 years ago

rfeinman commented 3 years ago

I believe there's a bug in the cosine-distance function in `dnc.util`. First, I think you are trying to compute cosine similarity, not distance (sim = 1 - dist). Second, I think the current function implements neither cosine similarity nor cosine distance. Here's a modified variant that returns the correct cosine similarity:

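```python
import torch

def bcos(a, b, normBy=2, eps=1e-8):
    """Batchwise cosine similarity

    Arguments:
        a: 3D tensor of shape [b,m,w]
        b: 3D tensor of shape [b,r,w]
        normBy: order of the norm (2 gives the usual cosine similarity)
        eps: lower bound on the norm product, to avoid dividing by zero
    Returns:
        cos: batchwise cosine similarity of shape [b,r,m]
    """
    dot = torch.bmm(a, b.transpose(1, 2))  # [b,m,w] @ [b,w,r] -> [b,m,r]
    a_norm = torch.norm(a, normBy, dim=2).unsqueeze(2)  # [b,m,1]
    b_norm = torch.norm(b, normBy, dim=2).unsqueeze(1)  # [b,1,r]
    # Broadcast to [b,m,r]; the clamp (an added safeguard) keeps all-zero rows from producing NaN
    cos = dot / (a_norm * b_norm).clamp(min=eps)  # [b,m,r]

    return cos.transpose(1, 2)  # [b,r,m]
```
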
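One way to sanity-check a batchwise cosine similarity is to compare it against PyTorch's built-in `torch.nn.functional.cosine_similarity`, broadcast over all (r, m) pairs. The sketch below assumes L2 norms (the `normBy=2` case) and uses a hypothetical helper, `bcos_normalized`, that normalizes each row to unit length first and then does a single `torch.bmm`, which yields the same [b,r,m] layout. It also shows the distance as `1 - sim`, in case a distance is what a caller actually wants.

```python
import torch
import torch.nn.functional as F

def bcos_normalized(a: torch.Tensor, b: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Hypothetical variant: a [b,m,w], b [b,r,w] -> cosine similarity [b,r,m]."""
    a_hat = F.normalize(a, p=2, dim=2, eps=eps)  # unit-length rows, [b,m,w]
    b_hat = F.normalize(b, p=2, dim=2, eps=eps)  # unit-length rows, [b,r,w]
    # [b,r,w] @ [b,w,m] -> [b,r,m]; the dot product of unit vectors is the cosine
    return torch.bmm(b_hat, a_hat.transpose(1, 2))

torch.manual_seed(0)
a = torch.randn(4, 5, 8)  # [b, m, w]
b = torch.randn(4, 3, 8)  # [b, r, w]

out = bcos_normalized(a, b)
# Reference: broadcast [b,r,1,w] against [b,1,m,w] and reduce over w -> [b,r,m]
ref = F.cosine_similarity(b.unsqueeze(2), a.unsqueeze(1), dim=-1)
dist = 1 - out  # cosine distance, the complement of the similarity

print(out.shape)                            # torch.Size([4, 3, 5])
print(torch.allclose(out, ref, atol=1e-6))  # True
```

Normalizing before the matrix multiply avoids materializing the separate norm tensors, but for the L2 case it computes the same quantity as dividing the dot products by the product of norms.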
ixaxaar commented 3 years ago

Correct, the intention was to compute similarity. I think you're right, the implementation in the repo is wrong 😭. Would you care to submit a PR?

rfeinman commented 3 years ago

Sure thing! It's submitted. See https://github.com/ixaxaar/pytorch-dnc/pull/57