swuxyj / DeepHash-pytorch

Implementation of Some Deep Hash Algorithms, Including DPSH, DSH, DHN, HashNet, DSDH, DTSH, DFH, GreedyHash, CSQ.
MIT License

pairwise and triplet data preparation #23

Closed: yosajka closed this 2 years ago

yosajka commented 2 years ago

Hi, thank you for your awesome work. I'm learning PyTorch, so it's a little hard for me to understand your code. How do you prepare pairwise or triplet data and feed them into the model in the training phase?

swuxyj commented 2 years ago

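Pairs and triplets are formed inside the loss with torch.unsqueeze and broadcasting. torch.unsqueeze inserts a dimension of size one at the given position: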

Example:

>>> x = torch.tensor([1, 2, 3, 4])
>>> torch.unsqueeze(x, 0)
tensor([[ 1,  2,  3,  4]])
>>> torch.unsqueeze(x, 1)
tensor([[ 1],
        [ 2],
        [ 3],
        [ 4]])
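A small sketch of how this produces every pair at once: unsqueezing one set of codes to (B, 1, bit) and another to (1, N, bit) lets broadcasting subtract each batch code from each stored code. The shapes and random ±1 codes are made up for illustration, and the explicit double loop is only there to check the broadcast result.

```python
import torch

torch.manual_seed(0)
# Hypothetical data: 3 batch codes and 5 stored codes, 4 bits each, values in {-1, +1}.
u = torch.randn(3, 4).sign()
U = torch.randn(5, 4).sign()

# (3, 1, 4) - (1, 5, 4) broadcasts to (3, 5, 4): every batch code minus every stored code.
diff = u.unsqueeze(1) - U.unsqueeze(0)
dist = diff.pow(2).sum(dim=2)  # (3, 5), dist[i, j] = ||u[i] - U[j]||^2

# Reference result from an explicit double loop over all pairs.
loop = torch.tensor([[((u[i] - U[j]) ** 2).sum().item() for j in range(5)] for i in range(3)])

print(dist.shape)  # torch.Size([3, 5])
```

For ±1 codes each differing bit contributes (±2)² = 4, so `dist` is four times the Hamming distance.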

pairwise

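# u: (batch, bit) codes of the current batch; self.U: (num_train, bit), assumed to hold the stored codes of all training samples
# (batch, 1, bit) - (1, num_train, bit) broadcasts to (batch, num_train, bit)
dist = (u.unsqueeze(1) - self.U.unsqueeze(0)).pow(2).sum(dim=2)  # (batch, num_train) squared distances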
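To show how such a distance matrix can drive training without a separate pair dataset, here is a minimal sketch assuming a DSH-style contrastive loss. The class name `PairwiseContrastiveLoss`, the memory buffers `U`/`Y`, the index tensor `ind` (assumed to be returned by the dataset alongside each image and label), and the margin `2 * bit` are assumptions for illustration, not the repository's exact code.

```python
import torch
import torch.nn.functional as F


class PairwiseContrastiveLoss(torch.nn.Module):
    """Hypothetical DSH-style loss: pairs are formed between the current batch
    and a memory holding the latest code and label of every training sample."""

    def __init__(self, num_train, n_class, bit, margin=None):
        super().__init__()
        self.margin = 2 * bit if margin is None else margin  # assumed margin
        # Memory banks, overwritten with each batch's outputs (not trained directly).
        self.register_buffer("U", torch.zeros(num_train, bit))
        self.register_buffer("Y", torch.zeros(num_train, n_class))

    def forward(self, u, y, ind):
        # u: (B, bit) network outputs, y: (B, n_class) multi-hot labels,
        # ind: (B,) dataset indices of the samples in this batch.
        self.U[ind] = u.detach()
        self.Y[ind] = y.float()

        dist = (u.unsqueeze(1) - self.U.unsqueeze(0)).pow(2).sum(dim=2)  # (B, N)
        similar = (y.float() @ self.Y.t() > 0).float()  # 1 where any label is shared

        # Pull similar pairs together, push dissimilar pairs beyond the margin.
        loss = similar / 2 * dist + (1 - similar) / 2 * (self.margin - dist).clamp(min=0)
        return loss.mean()


# Usage with made-up data: 10 training samples, 3 classes, 8-bit codes.
torch.manual_seed(0)
criterion = PairwiseContrastiveLoss(num_train=10, n_class=3, bit=8)
u = torch.randn(4, 8, requires_grad=True)            # stands in for the network output
y = F.one_hot(torch.tensor([0, 1, 0, 2]), 3)          # labels of the batch
ind = torch.tensor([0, 3, 5, 7])                      # indices of the batch in the dataset
loss = criterion(u, y, ind)
loss.backward()
```

Because the buffers start at zero, slots for samples that have not been visited yet carry no label and are treated as dissimilar until their batch comes around.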

triplet

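# Assumed inputs (not shown in this excerpt):
# inner_product: (batch, num_train) scores between batch codes u and stored codes U
# s: (batch, num_train) boolean matrix, True where the two samples share a label
count = 0  # number of rows that contribute triplets
loss1 = 0
for row in range(s.shape[0]):
    # if has positive pairs and negative pairs
    if s[row].sum() != 0 and (~s[row]).sum() != 0:
        count += 1
        theta_positive = inner_product[row][s[row] == 1]
        theta_negative = inner_product[row][s[row] == 0]
        # (P, 1) - (1, Q) -> (P, Q): every positive paired with every negative
        triple = (theta_positive.unsqueeze(1) - theta_negative.unsqueeze(0) - config["alpha"]).clamp(min=-100, max=50)
        loss1 += -(triple - torch.log(1 + torch.exp(triple))).mean()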
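A self-contained, hypothetical version of the triplet loop with its inputs made explicit. Here `inner_product` is assumed to be `u @ U.t()` and `s` a boolean matrix marking label-sharing pairs; the function name `triplet_loss`, the averaging over `count`, and `alpha=5.0` are placeholders, not values taken from the repository. The term `-(x - log(1 + exp(x)))` is rewritten as `-F.logsigmoid(x)`, which is the same quantity computed in a numerically stable way, so the clamp is no longer needed.

```python
import torch
import torch.nn.functional as F


def triplet_loss(u, y, U, Y, alpha=5.0):
    """Hypothetical sketch of the triplet loop.

    u: (B, bit) batch codes, y: (B, C) multi-hot labels,
    U: (N, bit) stored codes, Y: (N, C) stored labels.
    Each (positive, negative) pair in a row forms a triplet with anchor u[row].
    """
    inner_product = u @ U.t()             # (B, N) similarity scores
    s = (y.float() @ Y.float().t()) > 0   # (B, N) True where a label is shared
    loss, count = 0.0, 0
    for row in range(s.shape[0]):
        pos, neg = s[row], ~s[row]
        if pos.any() and neg.any():       # need at least one positive and one negative
            count += 1
            theta_positive = inner_product[row][pos]  # (P,)
            theta_negative = inner_product[row][neg]  # (Q,)
            # (P, 1) - (1, Q) -> (P, Q): one entry per triplet
            triple = theta_positive.unsqueeze(1) - theta_negative.unsqueeze(0) - alpha
            # -logsigmoid(x) equals -(x - log(1 + exp(x)))
            loss = loss - F.logsigmoid(triple).mean()
    # With no valid row, return a zero that still supports backward().
    return loss / count if count > 0 else u.sum() * 0.0


# Usage with made-up data: 4 batch codes, 8 stored codes, 3 classes.
torch.manual_seed(0)
u = torch.randn(4, 6, requires_grad=True)
U = torch.randn(8, 6)
y = F.one_hot(torch.tensor([0, 1, 0, 2]), 3)
Y = F.one_hot(torch.tensor([0, 1, 2, 0, 1, 2, 0, 1]), 3)
loss = triplet_loss(u, y, U, Y)
loss.backward()
```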
yosajka commented 2 years ago

Thanks for your quick reply. It helped me a lot. Cheers!