zhao-tong / GAug

AAAI'21: Data Augmentation for Graph Neural Networks
MIT License

Possible Bug with GAugO-GAT: Gradients cannot backprop through indexing #6

Open LeslieHoloway opened 2 years ago

LeslieHoloway commented 2 years ago

As described in your paper, the Gumbel-Softmax trick is applied to enable back-propagation: the gradients must pass through the sampled adjacency matrix to reach the edge predictor. However, your GATLayer implementation stops the gradient flow back to the edge predictor, while GCNLayer is fine. If gradients cannot flow back through the adjacency matrix, the Gumbel-Softmax trick is pointless, because the edge predictor receives no gradient. I am not sure whether I understand this correctly.

The key difference is how node features are aggregated: by matrix multiplication or by indexing.

In line 703 of GAug.py, you use indices to pick the corresponding node features. The indices returned by nonzero() are integers with no autograd history, so no gradient can flow through them back to adj.
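The claim about nonzero() is easy to check. In this minimal sketch (a small random adjacency matrix, chosen only for illustration), the returned indices are an int64 tensor with no grad_fn, so they carry no autograd history back to adj.

```python
import torch

torch.manual_seed(0)
# small random 0/1 adjacency matrix that requires gradients (illustrative only)
adj = torch.randint(2, (4, 4), dtype=torch.float, requires_grad=True)

# nonzero() returns the coordinates of the nonzero entries as integers
nz = adj.nonzero().T

print(nz.dtype)          # torch.int64
print(nz.requires_grad)  # False
print(nz.grad_fn)        # None: the indices are cut off from adj's graph
```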

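import torch
import torch.nn as nn
import torch.nn.functional as F

# excerpt of GATLayer from GAug.py (__init__ and the remainder of forward omitted)
class GATLayer(nn.Module):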
    def forward(self, adj, h):
        if self.dropout:
            h = self.dropout(h)
        x = h @ self.W # torch.Size([2708, 128])
        # calculate attentions, both el and er are n_nodes by n_heads
        el = self.attn_l(x)
        er = self.attn_r(x) # torch.Size([2708, 8])
        if isinstance(adj, torch.sparse.FloatTensor):
            nz_indices = adj._indices()
        else:
            nz_indices = adj.nonzero().T
        # adj only supplies integer indices here, so this line sends no gradient to adj
        attn = el[nz_indices[0]] + er[nz_indices[1]] # torch.Size([13264, 8])
        attn = F.leaky_relu(attn, negative_slope=0.2).squeeze()
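One possible workaround, sketched here under assumptions and not taken from GAug: keep the sparse indexing for the attention logits, but also gather the adjacency values at those indices with adj[src, dst] (advanced indexing on adj itself is differentiable) and use them as multiplicative weights inside the edge softmax. The helper weighted_gat_aggregate is hypothetical, single-head, assumes a dense adj, and aggregates at the row index nz_indices[0] to match the adj @ X convention. For a 0/1 adj it reduces to the usual masked-softmax GAT. Entries of adj that are exactly zero are never gathered, so they still receive zero gradient.

```python
import torch
import torch.nn.functional as F


def weighted_gat_aggregate(adj, x, el, er):
    """Hypothetical single-head GAT aggregation (not from GAug.py).

    adj: dense (N, N) adjacency; x: (N, F) projected features;
    el, er: (N,) per-node attention scores.
    Node i aggregates over the j with adj[i, j] != 0, and each edge is
    weighted by adj[i, j] so gradients reach adj at those positions.
    """
    src, dst = adj.nonzero().T                 # integer indices: no gradient
    w = adj[src, dst]                          # edge values: differentiable
    e = F.leaky_relu(el[src] + er[dst], negative_slope=0.2)
    e = e - e.max()                            # global shift for stability
    # weighted edge softmax per source node:
    # alpha_e = w_e * exp(e_e) / sum over edges e' with the same src of w_e' * exp(e_e')
    num = w * torch.exp(e)
    denom = torch.zeros(adj.size(0), dtype=num.dtype, device=num.device)
    denom = denom.index_add(0, src, num)
    alpha = num / denom[src]
    # out[i] = sum_j alpha_ij * x[j]
    return torch.zeros_like(x).index_add(0, src, alpha.unsqueeze(1) * x[dst])


# usage with a small fixed adjacency matrix (illustrative values)
torch.manual_seed(0)
adj = torch.tensor([[0., 1., 1., 0.],
                    [1., 0., 1., 1.],
                    [1., 1., 0., 0.],
                    [0., 1., 0., 0.]], requires_grad=True)
x = torch.randn(4, 5)
el, er = torch.randn(4), torch.randn(4)

out = weighted_gat_aggregate(adj, x, el, er)
out.sum().backward()
print(adj.grad)  # zero wherever adj == 0
```

Using out-of-place index_add keeps the scatter steps inside autograd without in-place modification of tensors that other operations depend on.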

Here is a toy example.

import torch

# generate some features
X = torch.randn((3, 7), requires_grad=True)
# generate adj matrix
adj = torch.randint(2, (3,3), dtype=torch.float, requires_grad=True)

# use nonzero to get idx of source node and target node
nnz = adj.nonzero().T
# get source node and target node feature via indexing
source, target = X[nnz[0]], X[nnz[1]]

# compute an example loss function
loss = (source + target).mean()
loss.backward()

print(source, target)
print(adj.grad)

The output of the last line is None.
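A minimal fix for the toy example above, assuming it is acceptable to scale each message by its edge weight: gather the values of adj at the same indices and multiply them in. Unlike the indices, those gathered values are part of the autograd graph. A fixed adjacency matrix is used here instead of torch.randint so the example always has edges.

```python
import torch

torch.manual_seed(0)
X = torch.randn((3, 7), requires_grad=True)
# fixed 0/1 adjacency matrix (illustrative) so there are always edges
adj = torch.tensor([[0., 1., 1.],
                    [1., 0., 0.],
                    [0., 1., 0.]], requires_grad=True)

nnz = adj.nonzero().T
# gather the *values* of adj at the same indices: these carry gradients
w = adj[nnz[0], nnz[1]].unsqueeze(1)  # shape (n_edges, 1)

# scale each edge's source/target features by its adjacency weight
source, target = w * X[nnz[0]], w * X[nnz[1]]
loss = (source + target).mean()
loss.backward()

print(adj.grad)  # no longer None
```

Only the entries selected by nonzero() receive a gradient; entries that are zero get a zero gradient rather than None.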

Gradients do flow through the adjacency matrix when it enters a matrix multiplication, as in the GCNLayer implementation.

Here is the same toy example using matrix multiplication.

import torch

# generate some features
X = torch.randn((3, 7), requires_grad=True)
# generate adj matrix
adj = torch.randint(2, (3,3), dtype=torch.float, requires_grad=True)

# use matrix multiplication to get node feature
source = adj @ X

# compute an example loss function
loss = (source).mean()
loss.backward()

print(source)
print(adj.grad)

This time, the gradient of adj is computed correctly.
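As a worked check: with loss = (adj @ X).mean() over an n×d output, loss = (1/(n·d)) Σ_i Σ_k Σ_j adj[i, j]·X[j, k], so d loss / d adj[i, j] = (1/(n·d)) Σ_k X[j, k]. Every entry, including those where adj is 0, receives this gradient, which is what the matrix-multiplication path provides and the indexing path does not. The sketch below, using the same shapes as the toy example, compares autograd against that formula.

```python
import torch

torch.manual_seed(0)
n, d = 3, 7
X = torch.randn((n, d), requires_grad=True)
adj = torch.randint(2, (n, n), dtype=torch.float, requires_grad=True)

loss = (adj @ X).mean()
loss.backward()

# analytical gradient: d loss / d adj[i, j] = sum_k X[j, k] / (n * d),
# the same for every row i and independent of whether adj[i, j] is 0 or 1
expected = (X.detach().sum(dim=1) / (n * d)).expand(n, n)
print(torch.allclose(adj.grad, expected, atol=1e-6))  # True
```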

I wonder how to aggregate source and target node features in GAT while still letting gradients flow through the sampled adjacency matrix.
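One option for this, an assumption rather than the authors' fix, is a dense GAT in which attention is computed for every node pair and the sampled adjacency enters as a multiplicative weight: alpha[i, j] = adj[i, j]·exp(e[i, j]) / Σ_k adj[i, k]·exp(e[i, k]), then out = alpha @ x. Because adj is used as a tensor of values rather than as a source of indices, every entry, zeros included, receives a gradient, as in GCNLayer. The function dense_weighted_gat is hypothetical and single-head; it assumes every row of adj has a positive entry (e.g. self-loops) and costs O(N²) memory per head, about 7.3M entries per head for 2708 nodes.

```python
import torch
import torch.nn.functional as F


def dense_weighted_gat(adj, x, el, er):
    """Hypothetical single-head dense GAT aggregation (not from GAug).

    adj: (N, N) sampled adjacency with non-negative values; every row is
    assumed to have at least one positive entry, e.g. a self-loop.
    x: (N, F) projected node features; el, er: (N,) attention scores.
    """
    # attention logits for every pair (i, j): e[i, j] = el[i] + er[j]
    e = F.leaky_relu(el.unsqueeze(1) + er.unsqueeze(0), negative_slope=0.2)
    # subtract the row-wise max for numerical stability (softmax is shift-invariant)
    e = e - e.max(dim=1, keepdim=True).values
    # adj acts as a multiplicative weight, so d out / d adj exists for all entries
    num = adj * torch.exp(e)
    alpha = num / num.sum(dim=1, keepdim=True)
    return alpha @ x


# usage: binary adjacency with self-loops so no row is empty (assumption)
torch.manual_seed(0)
adj = torch.tensor([[1., 1., 0., 0.],
                    [1., 1., 1., 0.],
                    [0., 1., 1., 1.],
                    [0., 0., 1., 1.]], requires_grad=True)
x = torch.randn(4, 5)
el, er = torch.randn(4), torch.randn(4)

out = dense_weighted_gat(adj, x, el, er)
out.sum().backward()
print(adj.grad[0, 2])  # nonzero even though adj[0, 2] == 0
```

For a 0/1 adj this gives the same output as a masked-softmax GAT, so the forward pass is unchanged while the sampled edges (and non-edges) now receive gradients.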

zhao-tong commented 2 years ago

Hi, thanks for pointing this out! This looks very interesting. I need to run some tests to double-check the problem and see whether I can fix it. I'll get back to this issue after that.