pyg-team / pytorch_geometric

Graph Neural Network Library for PyTorch
https://pyg.org
MIT License

Gradient for bipartite graphs behaves oddly for GAT(v2)Conv #7668

Closed (mova closed this issue 1 year ago)

mova commented 1 year ago

🐛 Describe the bug

import torch
from torch_geometric.nn import conv

in_channels = 4
n_heads = 4

# Reference layer: SAGEConv does not add self-loops.
sage = conv.SAGEConv(
    in_channels=in_channels * n_heads,
    out_channels=in_channels,
)
# Attention layers: both use add_self_loops=True by default.
gat2 = conv.GATv2Conv(
    in_channels=in_channels * n_heads,
    out_channels=in_channels,
    heads=n_heads,
    concat=True,
)
gat = conv.GATConv(
    in_channels=in_channels * n_heads,
    out_channels=in_channels,
    heads=n_heads,
    concat=True,
)

# Bipartite input: 2 source nodes (tracked for gradients), 3 destination nodes.
x = torch.rand(2, in_channels * n_heads, requires_grad=True)
x_aggrs = torch.rand(3, in_channels * n_heads)

for mpl in [sage, gat, gat2]:
    print(mpl)
    for ix in range(len(x)):
        for iaggr in range(len(x_aggrs)):
            # A single edge from source node ix to destination node iaggr.
            ei_o2c = torch.tensor([[ix], [iaggr]], dtype=torch.long, device=x.device)
            xcent = mpl(x=(x, x_aggrs), edge_index=ei_o2c)
            # Backpropagate only from the output row of destination iaggr.
            xcent[iaggr].sum().backward()
            # Expected: only source node ix receives a gradient.
            print(
                f"Connection {ix} -> {iaggr} yields"
                f" {x.grad.sum(1).nonzero().squeeze()}"
            )
            x.grad.zero_()  # reset before the next connection

Result:

SAGEConv(16, 4, aggr=mean)
Connection 0 -> 0 yields 0
Connection 0 -> 1 yields 0
Connection 0 -> 2 yields 0
Connection 1 -> 0 yields 1
Connection 1 -> 1 yields 1
Connection 1 -> 2 yields 1
GATConv(16, 4, heads=4)
Connection 0 -> 0 yields 0
Connection 0 -> 1 yields tensor([0, 1])
Connection 0 -> 2 yields 0
Connection 1 -> 0 yields tensor([0, 1])
Connection 1 -> 1 yields 1
Connection 1 -> 2 yields 1
GATv2Conv(16, 4, heads=4)
Connection 0 -> 0 yields 0
Connection 0 -> 1 yields tensor([0, 1])
Connection 0 -> 2 yields 0
Connection 1 -> 0 yields tensor([0, 1])
Connection 1 -> 1 yields 1
Connection 1 -> 2 yields 1
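The GATConv/GATv2Conv rows above are consistent with self-loops being added across the two node sets: source node i gets paired with destination node i for every i < min(2, 3), so destination 0 also sees source 0 and destination 1 also sees source 1. The pure-Python model below is a sketch of that reading, not PyG code; gradient_sources_model is a hypothetical helper, and the "remove existing i -> i edges, then add i -> i for i < min(n_src, n_dst)" step is an assumption about what the layers do internally.

```python
def gradient_sources_model(n_src, n_dst, src, dst, add_self_loops):
    """Hypothetical model: predict which source rows feed output row `dst`
    when the only user-supplied edge is src -> dst."""
    edges = [(src, dst)]
    if add_self_loops:
        # Assumed behaviour: drop existing i -> i edges, then add an i -> i edge
        # for every i < min(n_src, n_dst), pairing source i with destination i.
        edges = [(s, d) for s, d in edges if s != d]
        edges += [(i, i) for i in range(min(n_src, n_dst))]
    # Sources with an edge into `dst` are the ones that can receive a gradient.
    return sorted({s for s, d in edges if d == dst})


for src in range(2):
    for dst in range(3):
        with_loops = gradient_sources_model(2, 3, src, dst, add_self_loops=True)
        without_loops = gradient_sources_model(2, 3, src, dst, add_self_loops=False)
        print(f"Connection {src} -> {dst}: with loops {with_loops}, without {without_loops}")
```

With add_self_loops=True the model reproduces the GATConv/GATv2Conv rows of the Result listing; with add_self_loops=False it reproduces the SAGEConv rows (each connection yields only its own source node).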

Environment

mova commented 1 year ago

I just found out that this is the same issue as https://github.com/pyg-team/pytorch_geometric/issues/7581. If I pass add_self_loops=False, the problem goes away.

Anyway, this code might be useful for testing bipartite graphs.

rusty1s commented 1 year ago

Super! Does this mean this issue is resolved for you?

mova commented 1 year ago

I thought I'd keep it open while #7581 is open so that people can find it. Feel free to close it.

rusty1s commented 1 year ago

The reference should still work even if it is closed. I prefer to keep only one issue open to avoid duplicates :)