PetarV- / DGI

Deep Graph Infomax (https://arxiv.org/abs/1809.10341)
MIT License

Error in AvgReadout #8

Open MarcCote opened 4 years ago

MarcCote commented 4 years ago

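Hi, I think there's an error in AvgReadout when a mask is passed. The mask summation in the denominator should be performed along dimension 1 only (the same dimension the numerator sums over), not over the whole mask tensor. Otherwise every graph in the batch is divided by the total number of valid nodes in the entire batch instead of by its own count. https://github.com/PetarV-/DGI/blob/0afce4e36b5edbe1e735536d15b748d0381e4083/layers/readout.py#L15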

The line currently reads:

    return torch.sum(seq * msk, 1) / torch.sum(msk)

but it should be:

    return torch.sum(seq * msk, 1) / torch.sum(msk, 1)  # Note the dimension here.
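To make the difference concrete, here is a small pure-Python sketch (no PyTorch needed) that mimics both versions of the denominator. It assumes `seq` has shape batch x nodes x features and `msk` has shape batch x nodes with 0/1 entries, which is what the numerator's broadcasting implies once `msk` is unsqueezed to batch x nodes x 1. The function names `masked_mean_buggy` and `masked_mean_fixed` are hypothetical, chosen for illustration only.

```python
# Hypothetical pure-Python stand-ins for the masked branch of AvgReadout.
# seq: batch x nodes x features (nested lists); msk: batch x nodes (0/1 floats).

def masked_sum(graph, mask):
    """Per-feature sum over the valid nodes of one graph (the numerator)."""
    n_feat = len(graph[0])
    return [sum(node[f] * m for node, m in zip(graph, mask)) for f in range(n_feat)]


def masked_mean_buggy(seq, msk):
    """Mirrors torch.sum(seq * msk, 1) / torch.sum(msk): one global denominator."""
    total = sum(m for row in msk for m in row)  # valid nodes across the WHOLE batch
    return [[s / total for s in masked_sum(g, m)] for g, m in zip(seq, msk)]


def masked_mean_fixed(seq, msk):
    """Mirrors torch.sum(seq * msk, 1) / torch.sum(msk, 1): per-graph denominator."""
    return [[s / sum(m) for s in masked_sum(g, m)] for g, m in zip(seq, msk)]


# Two graphs padded to 3 nodes; graph 0 has only 2 real nodes.
seq = [
    [[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]],  # last row is padding
    [[5.0, 6.0], [7.0, 8.0], [9.0, 10.0]],
]
msk = [[1.0, 1.0, 0.0], [1.0, 1.0, 1.0]]

print(masked_mean_fixed(seq, msk))  # [[2.0, 3.0], [7.0, 8.0]]
print(masked_mean_buggy(seq, msk))  # [[0.8, 1.2], [4.2, 4.8]]
```

The fixed version returns the true mean of each graph's valid nodes, while the original divides both graphs by 5 (the total number of valid nodes in the batch). With a batch containing a single graph the two versions agree, since the per-graph and whole-batch counts are then the same.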