benedekrozemberczki / pytorch_geometric_temporal

PyTorch Geometric Temporal: Spatiotemporal Signal Processing with Neural Machine Learning Models (CIKM 2021)
MIT License

Transpose weights bug in EvolveGCN #205

Closed euxhenh closed 1 year ago

euxhenh commented 1 year ago

I believe there is a bug in the following line.

https://github.com/benedekrozemberczki/pytorch_geometric_temporal/blob/e43e95d726174574eeae32ad8c4670631fd0a58d/torch_geometric_temporal/nn/recurrent/evolvegcnh.py#L98

The evolved W that comes out of the GRU unit has the same shape as X, which is (nodes, features). In the GConv layer we pass W and compute X @ W, but for the shapes to match we should compute X @ W.T. Here W happens to be square (because of TopKPooling), which is probably why this was missed. I therefore believe we need to pass W.squeeze(dim=0).T instead. The same applies to the O-version of EvolveGCN.
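To make the shape argument concrete, here is a minimal sketch in plain PyTorch. It is not the library's code: the sizes, the random tensors, and the names `W2d`, `W_sq`, and `X_ones` are made up for illustration, and W is deliberately non-square so that the mismatch becomes visible.

```python
import torch

# Hypothetical sizes, chosen so that W is NOT square (k != num_features).
num_nodes, num_features, k = 5, 3, 4

X = torch.randn(num_nodes, num_features)  # node features: (nodes, features)
W = torch.randn(1, k, num_features)       # stand-in for the GRU output, with a leading batch dim

W2d = W.squeeze(dim=0)                    # drop the batch dim: (k, features)

# X @ W2d is (5, 3) @ (4, 3): the inner dimensions disagree, so this raises.
try:
    X @ W2d
    plain_ok = True
except RuntimeError:
    plain_ok = False

# X @ W2d.T is (5, 3) @ (3, 4): the shapes line up.
out = X @ W2d.T

# With a square W both products run, but they are generally different,
# which is how a missing transpose can go unnoticed.
W_sq = torch.tensor([[1.0, 2.0, 0.0],
                     [0.0, 1.0, 0.0],
                     [0.0, 0.0, 1.0]])
X_ones = torch.ones(num_nodes, num_features)
same = torch.allclose(X_ones @ W_sq, X_ones @ W_sq.T)

print(plain_ok, tuple(out.shape), same)
```

In the square case no error is raised either way, so only a comparison of the outputs (or a test with a non-square weight) would reveal which orientation is being used.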

Could someone please verify this?

Thank you.