pyg-team / pytorch_geometric

Graph Neural Network Library for PyTorch
https://pyg.org
MIT License

Best way to learn adjacency matrix for a graph? #1361

Closed: christopher-beckham closed this issue 1 year ago

christopher-beckham commented 4 years ago

❓ Questions & Help

Hi,

Apologies if this has already been posted (though I spent a good half an hour trying to find a question like this). I am trying to figure out the best way to learn a parameterisation of a graph (i.e., have a neural net predict, from some input, the nodes, their features, and the adjacency matrix).

I see that many of the graph conv layers take a 2D tensor of edge indices (edge_index), which we would not be able to backprop through. It seems one would either have to (a) define a fully-connected graph and instead infer the edge weights (where a weight of 0 between nodes (i, j) would effectively simulate the two nodes not being connected), or (b), if possible, directly pass in the adjacency matrix as one dense (n, n) matrix (though I assume this can only be binary, so that may also be problematic).
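A quick way to see the constraint described above, as a plain-PyTorch sketch (the tensors are made up for illustration): an integer edge_index cannot require gradients at all, whereas a dense float adjacency used in a matrix product receives a gradient for every entry.

```python
import torch

# Integer index tensors (such as edge_index) cannot carry gradients at all.
edge_index = torch.tensor([[0, 1], [1, 0]], dtype=torch.long)
try:
    edge_index.requires_grad_(True)
    index_can_require_grad = True
except RuntimeError:
    # PyTorch only allows floating-point (and complex) tensors to require grad.
    index_can_require_grad = False

# A dense float adjacency matrix, by contrast, is an ordinary differentiable tensor.
adj = torch.rand(3, 3, requires_grad=True)
x = torch.randn(3, 4)
out = adj @ x  # row i = sum of neighbour features x[j] weighted by adj[i, j]
out.sum().backward()

print(index_can_require_grad)  # False
print(adj.grad.shape)          # torch.Size([3, 3])
```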

Any thoughts? Thanks in advance.

rusty1s commented 4 years ago

The general consensus for a Graph-AE is to train against the dense adjacency matrix. However, only the output needs to be dense; the input graph can stay sparse. We have an example of this in examples/autoencoder.py.
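A minimal plain-PyTorch sketch of that split, assuming a toy 6-node ring graph and a one-layer sum-aggregation encoder (both hypothetical, and not a reproduction of examples/autoencoder.py): the encoder only ever touches the sparse edge_index, while the loss compares a dense N x N inner-product decoder output against a dense target. In PyG, torch_geometric.utils.to_dense_adj can build such a dense target from an edge_index.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
num_nodes, in_dim, hid_dim = 6, 5, 4

# Sparse *input*: a 6-node ring stored as a (2, E) edge_index with both directions.
src = torch.arange(num_nodes)
dst = (src + 1) % num_nodes
edge_index = torch.cat([torch.stack([src, dst]), torch.stack([dst, src])], dim=1)
x = torch.randn(num_nodes, in_dim)

# Dense *target*: a fixed N x N matrix with a fixed node ordering.
target = torch.zeros(num_nodes, num_nodes)
target[edge_index[0], edge_index[1]] = 1.0

lin = torch.nn.Linear(in_dim, hid_dim)
opt = torch.optim.Adam(lin.parameters(), lr=0.05)

def encode(x, edge_index):
    # One sum-aggregation step over the sparse edges, plus a self term.
    h = lin(x)
    agg = h.new_zeros(h.shape).index_add(0, edge_index[1], h[edge_index[0]])
    return h + agg

losses = []
for step in range(200):
    z = encode(x, edge_index)
    logits = z @ z.t()  # dense inner-product decoder: one score per node pair
    loss = F.binary_cross_entropy_with_logits(logits, target)
    opt.zero_grad()
    loss.backward()
    opt.step()
    losses.append(loss.item())

print(logits.shape)  # torch.Size([6, 6])
```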

Note that, as you correctly mentioned, it is not possible to train against a sparse adjacency matrix. This stems mostly from the fact that you need a fixed output dimension with a fixed ordering, and that requirement cannot be fulfilled by sparse matrices.

However, there is some literature on this topic, e.g., Graph-RNN, which generates graphs in an auto-regressive fashion.

christopher-beckham commented 4 years ago

Hi,

Thanks for your response!

In my case, I want to use the inferred outputs (both the node features and the adjacency matrix) downstream and have everything be backproppable, e.g.:

input -> [mlp] -> {X, E} -> [GNNs] -> output

where E is the adjacency matrix and X the node features. I assume, however, that E needs to be sparse in order to work with the GNNs later in the network.
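A minimal sketch of this pipeline that keeps E dense throughout, so gradients reach the MLP that produced it. GraphGenerator and dense_gcn are hypothetical names, and the GCN-style normalisation D^-1/2 (E + I) D^-1/2 is one assumed choice; PyG's dense layers (e.g. DenseGCNConv) follow the same idea.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
N, F_NODE, F_OUT, IN_DIM = 5, 8, 2, 10

class GraphGenerator(nn.Module):
    """Hypothetical MLP mapping an input vector to node features X and a dense soft adjacency E."""

    def __init__(self):
        super().__init__()
        self.to_x = nn.Linear(IN_DIM, N * F_NODE)
        self.to_e = nn.Linear(IN_DIM, N * N)

    def forward(self, inp):
        x = self.to_x(inp).view(N, F_NODE)
        e = torch.sigmoid(self.to_e(inp).view(N, N))  # soft edge weights in (0, 1)
        e = (e + e.t()) / 2  # symmetrise
        return x, e

def dense_gcn(x, e, weight):
    # GCN-style propagation on a dense weighted adjacency: D^-1/2 (E + I) D^-1/2 X W
    a = e + torch.eye(e.size(0))
    d_inv_sqrt = a.sum(-1).pow(-0.5)  # every row sum is >= 1 thanks to the self-loop
    a_norm = d_inv_sqrt.unsqueeze(-1) * a * d_inv_sqrt.unsqueeze(0)
    return a_norm @ x @ weight

gen = GraphGenerator()
weight = nn.Parameter(torch.randn(F_NODE, F_OUT))
inp = torch.randn(IN_DIM)

x, e = gen(inp)
out = dense_gcn(x, e, weight)
out.pow(2).mean().backward()  # any downstream loss works; gradients flow back into E

print(out.shape)  # torch.Size([5, 2])
```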

In the case of the autoencoder, its output (a dense adjacency matrix) just happens to also be the end of the network, which is convenient. In my case, the most plausible option still seems to be to fix the adjacency matrix so that the graph is fully connected, and have the network infer edge weights instead. Let me know if you agree with this line of thinking.
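The fully-connected alternative can be sketched as follows, assuming an edge MLP that scores every ordered node pair (edge_mlp is a hypothetical name). The edge_index of the complete graph is fixed, so only the float edge_weight needs gradients; PyG layers that accept an edge_weight argument, such as GCNConv, can consume it the same way. Note that this costs N * (N - 1) edges in memory and compute.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
N, FEAT = 4, 3

# Fixed edge_index of the complete directed graph without self-loops: N * (N - 1) edges.
row, col = torch.meshgrid(torch.arange(N), torch.arange(N), indexing="ij")
mask = row != col
edge_index = torch.stack([row[mask], col[mask]])  # shape (2, N * (N - 1))

x = torch.randn(N, FEAT)
edge_mlp = nn.Sequential(nn.Linear(2 * FEAT, 8), nn.ReLU(), nn.Linear(8, 1))

# Predict one weight per (source, target) pair; a weight near 0 acts like a missing edge.
pair_feats = torch.cat([x[edge_index[0]], x[edge_index[1]]], dim=-1)
edge_weight = torch.sigmoid(edge_mlp(pair_feats)).squeeze(-1)

# Weighted sum of messages from source nodes into target nodes.
messages = edge_weight.unsqueeze(-1) * x[edge_index[0]]
out = torch.zeros(N, FEAT).index_add(0, edge_index[1], messages)

out.sum().backward()  # gradients reach edge_mlp through edge_weight
print(edge_index.shape)  # torch.Size([2, 12])
```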

Thanks again!

rusty1s commented 4 years ago

Note that we also provide GNNs that can operate on dense input. For example, this is done in the DiffPool model. An alternative way would be to sparsify your dense adjacency matrix based on a user-defined threshold (similar to a ReLU activation):

edge_index = (adj > 0.5).nonzero().t()
edge_weight = adj[edge_index[0], edge_index[1]]

If you use both edge_index and edge_weight in your follow-up GNN, your graph generation is fully trainable (except for the values you remove).
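A small check of that claim, as a plain-PyTorch sketch (the random adj and the index_add-based aggregation stand in for a real predicted adjacency and GNN layer): after backpropagating, the entries that survived the threshold receive gradients, while the removed entries get exactly zero.

```python
import torch

torch.manual_seed(0)
N, FEAT = 5, 3
adj = torch.rand(N, N, requires_grad=True)  # stand-in for a predicted dense adjacency
x = torch.randn(N, FEAT)

# Sparsify with the same threshold as above.
edge_index = (adj > 0.5).nonzero().t()
edge_weight = adj[edge_index[0], edge_index[1]]

# Minimal weighted aggregation: out[i] = sum over kept j of adj[i, j] * x[j].
messages = edge_weight.unsqueeze(-1) * x[edge_index[1]]
out = torch.zeros(N, FEAT).index_add(0, edge_index[0], messages)
out.sum().backward()

kept = adj.detach() > 0.5
print(bool((adj.grad[~kept] == 0).all()))  # True: removed entries get no gradient
```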

christopher-beckham commented 4 years ago

Thanks! I will be sure to try it out

rusty1s commented 4 years ago

The output of nonzero() breaks the computation graph, but the adjacency tensor itself still requires grad, and so do the values you gather from it using the indices returned by nonzero().
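For instance, with a hand-picked 2 x 2 adj (illustrative values only), the indices returned by nonzero() are a plain int64 tensor, but the gathered edge_weight is still connected to adj.

```python
import torch

adj = torch.tensor([[0.2, 0.9], [0.7, 0.1]], requires_grad=True)
edge_index = (adj > 0.5).nonzero().t()             # [[0, 1], [1, 0]]
edge_weight = adj[edge_index[0], edge_index[1]]    # [0.9, 0.7]

print(edge_index.dtype, edge_index.requires_grad)  # torch.int64 False
print(edge_weight.requires_grad)                   # True

edge_weight.sum().backward()
print(adj.grad)  # 1.0 at the kept positions (0, 1) and (1, 0), 0.0 elsewhere
```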

LinHeLurking commented 2 years ago

I'm not sure I understand the question correctly, but I think you do not have to use a dense adjacency matrix as input. Node features themselves are enough to predict edge connectivity (or weights). I ran a small experiment recently, and it turns out that pairwise concatenation of node features works well for "edge prediction".

After 2 or 3 epochs of training, the network can learn the exact adjacency matrix.


Here is the full code.

import os.path as osp
import random

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch_geometric.nn as gnn
from torch.utils.tensorboard import SummaryWriter
from torch_geometric.data import Batch as PygBatch
from torch_geometric.data import Data as PygData

node_num = 20
feat_dim = 8 # The feature dimension cannot be too small.
device = "cuda" if torch.cuda.is_available() else "cpu"

class TestModel(nn.Module):
    def __init__(self, node_num: int, feat_dim: int) -> None:
        super().__init__()
        self._node_num = node_num
        self._feat_dim = feat_dim

        self._x_em = nn.Embedding(num_embeddings=node_num, embedding_dim=feat_dim)

        self._gc_list = nn.ModuleList(
            [
                gnn.GCNConv(in_channels=feat_dim, out_channels=feat_dim),
                gnn.GCNConv(in_channels=feat_dim, out_channels=feat_dim),
                gnn.GCNConv(in_channels=feat_dim, out_channels=feat_dim),
            ]
        )
        self._last_gc = gnn.GCNConv(in_channels=feat_dim, out_channels=feat_dim)

        # The MLP acts as a "combiner" of each node pair; a too-shallow MLP gives poor output.
        self._mlp = nn.Sequential(
            nn.Linear(in_features=2 * feat_dim, out_features=2 * feat_dim),
            nn.LeakyReLU(),
            nn.Linear(in_features=2 * feat_dim, out_features=feat_dim),
            nn.LeakyReLU(),
            nn.Linear(in_features=feat_dim, out_features=feat_dim),
            nn.LeakyReLU(),
            nn.Linear(in_features=feat_dim, out_features=1),
        )

    def forward(self, x, edge_index):
        n = self._node_num
        f = self._feat_dim

        x = self._x_em(x)

        # The `+x` part is important. It helps the network capture features during different convolution stages.
        for conv in self._gc_list:
            x = F.leaky_relu(conv(x, edge_index)) + x
        x = self._last_gc(x, edge_index) + x
        x = x.view(-1, n, f)  # [B, N, F]

        # Pairwise concatenation
        idx_pairs = torch.cartesian_prod(
            torch.arange(x.shape[-2], device=x.device),
            torch.arange(x.shape[-2], device=x.device),
        )
        x = x[:, idx_pairs]  # [B, N * N, 2, F]
        x = x.view(-1, n, n, 2 * f)
        x = self._mlp(x)  # [B, N, N, 1]
        x = x.view(-1, n, n)
        x = (x + x.transpose(-1, -2)) / 2
        return x

loss_fn = nn.SmoothL1Loss()
net = TestModel(node_num=node_num, feat_dim=feat_dim).to(device=device)

optimizer = optim.RAdam(net.parameters(), lr=0.005)

log_dir = osp.dirname(osp.abspath(__file__))
log_dir = osp.join(log_dir, "torch_runs")
log_dir = osp.join(log_dir, "adj_learner")
summary_writer = SummaryWriter(log_dir=log_dir)

x = torch.arange(node_num)
dataset = []

for _ in range(5000):
    edge_index = []
    adj_mat = torch.zeros(node_num, node_num, dtype=torch.float)
    for _ in range(10):
        u, v = -1, -1
        while u == v:
            u = random.choice(range(node_num))
            v = random.choice(range(node_num))
        edge_index.append((u, v))
        edge_index.append((v, u))
        adj_mat[u][v] = 1.0
        adj_mat[v][u] = 1.0
    edge_index = torch.tensor(edge_index, dtype=torch.long).T.contiguous()
    dataset.append((PygData(x=x, edge_index=edge_index), adj_mat))

test_data = random.sample(dataset, k=2)

def show_status(epoch_id: int = None):
    for tag, (data, adj_mat) in enumerate(test_data):
        batch = PygBatch.from_data_list([data]).to(device)
        with torch.no_grad():
            out = net(batch.x, batch.edge_index)

        out = out.cpu().squeeze()

        fig, (ax1, ax2) = plt.subplots(ncols=2)
        im1 = ax1.matshow(adj_mat, interpolation=None)
        im2 = ax2.matshow(out, interpolation=None)
        ax1.set_title("Adjacency Matrix")
        ax2.set_title("Fitted Matrix")
        fig.colorbar(im1, ax=ax1)
        fig.colorbar(im2, ax=ax2)
        summary_writer.add_figure(f"Example status {tag}", fig, epoch_id)

running_loss = 0.0
obs_period = 200
iter_per_epoch = 2000
for epoch_id in range(100):
    print(f"Epoch: {epoch_id+1}")
    net.train()
    batch_size = 32
    for it in range(iter_per_epoch):
        data_list, adj_mat_batch = zip(*random.sample(dataset, k=batch_size))
        adj_mat_batch = torch.stack(adj_mat_batch).to(device)
        batch = PygBatch.from_data_list(data_list).to(device)

        out = net(batch.x, batch.edge_index)
        optimizer.zero_grad()
        loss = loss_fn(out, adj_mat_batch)
        loss.backward()
        cur_loss = loss.item()
        optimizer.step()
        running_loss += cur_loss

        summary_writer.add_scalar(
            "Training loss", cur_loss, epoch_id * iter_per_epoch + it
        )
        if (it + 1) % obs_period == 0:
            running_loss /= obs_period
            print(f"    [{it+1:4}] running loss: {running_loss:0.4f}")
            running_loss = 0.0
    net.train(False)
    show_status(epoch_id)
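
The pairwise-concatenation step in forward can also be written with broadcasting instead of cartesian_prod indexing. A small sketch (random tensors, arbitrary sizes) checking that both produce the same [B, N, N, 2F] tensor:

```python
import torch

torch.manual_seed(0)
B, N, F = 2, 4, 3
x = torch.randn(B, N, F)

# Version from the post: gather all (i, j) index pairs, then reshape.
idx_pairs = torch.cartesian_prod(torch.arange(N), torch.arange(N))
pairs_gather = x[:, idx_pairs].reshape(B, N, N, 2 * F)

# Broadcasting alternative: pair[b, i, j] = concat(x[b, i], x[b, j]).
x_i = x.unsqueeze(2).expand(B, N, N, F)
x_j = x.unsqueeze(1).expand(B, N, N, F)
pairs_bcast = torch.cat([x_i, x_j], dim=-1)

print(torch.equal(pairs_gather, pairs_bcast))  # True
```
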
dkhonker commented 3 months ago

Your node features are determined solely by the node ordinal (through nn.Embedding), and the adjacency matrix is randomly generated. Why would this show that “Node features themselves are enough to predict edge connectivity”? The result seems to indicate only that the neural network can fit the adjacency matrix; the node features appear to be completely useless, because the nn.Embedding is trainable.
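One way to make this concrete, assuming the post's data-generation loop (seed chosen arbitrarily): the target adj_mat is built from the very same (u, v) pairs that are passed to the GCN layers as edge_index, while x = torch.arange(node_num) is identical for every graph. Densifying the input edge_index reproduces the target exactly, so the network only has to reconstruct the graph it was given.

```python
import random

import torch

random.seed(0)
node_num = 20

# Rebuild one sample exactly the way the experiment generates its dataset.
edge_index = []
adj_mat = torch.zeros(node_num, node_num, dtype=torch.float)
for _ in range(10):
    u, v = -1, -1
    while u == v:
        u = random.choice(range(node_num))
        v = random.choice(range(node_num))
    edge_index.append((u, v))
    edge_index.append((v, u))
    adj_mat[u][v] = 1.0
    adj_mat[v][u] = 1.0
edge_index = torch.tensor(edge_index, dtype=torch.long).T.contiguous()

# Densify the *input* edge_index: it is exactly the training target.
dense_from_input = torch.zeros(node_num, node_num)
dense_from_input[edge_index[0], edge_index[1]] = 1.0
print(torch.equal(dense_from_input, adj_mat))  # True
```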