yuyangw / MolCLR

Implementation of MolCLR: "Molecular Contrastive Learning of Representations via Graph Neural Networks" in PyG.
MIT License

l_pos and r_pos empty in nt_xent loss? #12

Closed: nprasadmm closed this issue 2 years ago

nprasadmm commented 2 years ago

When I run MolCLR, the error seems to come from the following lines in the forward function of the NT-Xent loss:

    l_pos = torch.diag(similarity_matrix, self.batch_size)
    r_pos = torch.diag(similarity_matrix, -self.batch_size)
    positives = torch.cat([l_pos, r_pos]).view(2 * self.batch_size, 1)

    RuntimeError: shape '[1024, 1]' is invalid for input of size 0

The similarity matrix appears to be calculated correctly, but l_pos and r_pos are empty tensors when I print them. I would appreciate any guidance.
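For anyone hitting the same error: on a 2-D input, `torch.diag(A, k)` returns the k-th off-diagonal, which has `n - |k|` entries for an `n × n` matrix and is empty once `|k| >= n`. The NT-Xent similarity matrix is built from both augmented views, so it is `2B × 2B`, where `B` is the number of molecules actually in the batch. The offsets `±self.batch_size` land on the positive pairs only when `self.batch_size == B`. The pure-Python sketch below mimics that indexing. The helpers `offset_diagonal` and `positive_pairs` are hypothetical stand-ins, written to follow `torch.diag`'s behavior for 2-D inputs rather than taken from MolCLR. It shows how a configured batch size that is too large for the real batch empties `l_pos` and `r_pos`.

```python
def offset_diagonal(matrix, k):
    """Return the k-th diagonal of a square matrix (hypothetical stand-in for
    torch.diag(matrix, k) on 2-D input). k > 0 is above the main diagonal,
    k < 0 below it; the result has n - |k| entries, or none when |k| >= n."""
    n = len(matrix)
    if k >= 0:
        return [matrix[i][i + k] for i in range(max(n - k, 0))]
    return [matrix[i - k][i] for i in range(max(n + k, 0))]


def positive_pairs(similarity_matrix, batch_size):
    # Mirrors the three lines from the issue: gather the diagonals at
    # offsets +batch_size and -batch_size and expect 2 * batch_size values.
    l_pos = offset_diagonal(similarity_matrix, batch_size)
    r_pos = offset_diagonal(similarity_matrix, -batch_size)
    positives = l_pos + r_pos
    if len(positives) != 2 * batch_size:
        # In PyTorch this is the point where .view(2 * batch_size, 1) fails.
        raise RuntimeError(
            f"shape '[{2 * batch_size}, 1]' is invalid for input of size {len(positives)}"
        )
    return positives


actual_batch = 4              # molecules actually in the batch
n = 2 * actual_batch          # two augmented views per molecule
# Label each cell with its (row, col) index so the selected pairs are visible.
sim = [[(i, j) for j in range(n)] for i in range(n)]

print(positive_pairs(sim, 4))
# [(0, 4), (1, 5), (2, 6), (3, 7), (4, 0), (5, 1), (6, 2), (7, 3)]

try:
    positive_pairs(sim, 8)    # configured batch_size larger than the real batch
except RuntimeError as err:
    print(err)
# shape '[16, 1]' is invalid for input of size 0
```

With a matching batch size, each view `i` is paired with view `i + B` and vice versa. With a configured value of at least `2B`, both diagonals are empty, which is the "input of size 0" in the reported error. A value between `B` and `2B` gives non-empty diagonals that are still the wrong length.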

nprasadmm commented 2 years ago

Update: Never mind, I had the wrong batch size in the config.
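Because the mismatch only shows up deep inside the loss, one option is to check it up front. The sketch below is a hypothetical guard, not part of MolCLR. It compares the batch size the loss will use with the number of rows in each view's embeddings, taking plain shape tuples (for example `tuple(zis.shape)`, since `torch.Size` is a tuple). The shapes in the usage lines are illustrative. A data loader that keeps a smaller final batch would trip the same check, for example a PyTorch `DataLoader` created without `drop_last=True`.

```python
def check_nt_xent_batch(zis_shape, zjs_shape, batch_size):
    """Hypothetical guard: raise a readable error if the loss's batch_size
    does not match the number of molecules in the two views' embeddings."""
    if zis_shape != zjs_shape:
        raise ValueError(f"view shapes differ: {zis_shape} vs {zjs_shape}")
    actual = zis_shape[0]
    if actual != batch_size:
        # The +/- batch_size diagonals of a 2*actual x 2*actual matrix are
        # empty when batch_size >= 2 * actual, otherwise just the wrong length.
        outcome = "empty" if batch_size >= 2 * actual else "the wrong length"
        raise ValueError(
            f"loss batch_size={batch_size} but the batch holds {actual} molecules; "
            "the positive-pair diagonals of the "
            f"{2 * actual}x{2 * actual} similarity matrix will be {outcome}"
        )


# Illustrative shapes: 256 molecules per view, 128-dimensional embeddings.
check_nt_xent_batch((256, 128), (256, 128), 256)   # passes silently

try:
    check_nt_xent_batch((256, 128), (256, 128), 512)
except ValueError as err:
    print(err)
# loss batch_size=512 but the batch holds 256 molecules; the positive-pair diagonals of the 512x512 similarity matrix will be empty
```

An alternative design is to have the loss read the batch size from the embeddings themselves, rather than from a fixed config value. That removes this class of error but is a change to the loss, not how the thread resolves it.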