pyg-team / pytorch_geometric

Graph Neural Network Library for PyTorch
https://pyg.org
MIT License
20.93k stars 3.61k forks source link

How can I use the class `MetaLayer` to implement an attention mechanism? #3223

Open dmeoli opened 2 years ago

dmeoli commented 2 years ago

I'm trying to extend a project that implements a GNN following Battaglia et al.'s graph network formulation via the `MetaLayer` class. I would like to add attention mechanisms as defined here (link in the original post). How can I use the `MetaLayer` class to implement this TF code? Or should I use a `GATConv` layer instead?

rusty1s commented 2 years ago

IMO, the easiest way to implement this is to use the `MessagePassing` interface rather than `MetaLayer`, and return both the transformed node features and the transformed edge features. Here is a minimal example that computes attention from both node features and edge features: https://github.com/pyg-team/pytorch_geometric/discussions/3209#discussioncomment-1373286

If you still want to use `MetaLayer`, you need to update the `NodeModel` so that it computes attention coefficients and normalizes them over each node's neighborhood, e.g.:

from torch_geometric.utils import softmax

class NodeModel(torch.nn.Module):
    """Multi-head, attention-weighted node update for use inside ``MetaLayer``.

    Each edge ``(row, col)`` of ``edge_index`` has the concatenated features
    ``[x[row], edge_attr, x[col]]``. These are mapped to:

    (a) one attention logit per head;
    (b) a message with ``num_heads * out_channels`` features.

    The logits are softmax-normalized over all edges that share the same
    ``col`` node. Each head's message is scaled by its coefficient, and the
    scaled messages are summed into node ``col``.

    Args:
        in_channels: width of the concatenated edge input (``2 * F_x + F_e``).
            If ``None``, the width is inferred on the first forward call
            through ``torch.nn.LazyLinear``.
        out_channels: number of output features per head.
        num_heads: number of attention heads.
    """

    def __init__(self, in_channels=None, out_channels=64, num_heads=1):
        super().__init__()
        self.out_channels = out_channels
        self.num_heads = num_heads

        def linear(out_features):
            # Lazy input width keeps ``NodeModel()`` constructible without
            # knowing the node/edge feature sizes up front.
            if in_channels is None:
                return torch.nn.LazyLinear(out_features)
            return torch.nn.Linear(in_channels, out_features)

        # Maps each edge to one attention logit per head.
        self.att_mlp = linear(num_heads)
        # Maps each edge to a message of ``out_channels`` features per head.
        self.transform_mlp = linear(num_heads * out_channels)

    def forward(self, x, edge_index, edge_attr, u, batch):
        """Return updated node features of shape ``[N, num_heads * out_channels]``.

        ``u`` and ``batch`` are accepted only to match the node-model call
        signature used by ``MetaLayer``; they are not used here.
        """
        row, col = edge_index
        out = torch.cat([x[row], edge_attr, x[col]], dim=1)
        att = self.att_mlp(out)  # [E, num_heads]
        # Normalize across local neighborhoods (all edges entering the same
        # ``col`` node). Passing num_nodes keeps the index range explicit.
        att = softmax(att, col, num_nodes=x.size(0))
        out = self.transform_mlp(out).view(-1, self.num_heads, self.out_channels)
        out = out * att.view(-1, self.num_heads, 1)
        out = out.view(-1, self.num_heads * self.out_channels)
        # Sum the weighted messages into their target node. index_add_ is
        # plain torch, so no torch_scatter import is needed. Nodes without an
        # incoming edge stay zero.
        agg = out.new_zeros(x.size(0), out.size(1))
        return agg.index_add_(0, col, out)
dmeoli commented 2 years ago

Actually, my code is:

class ModifiedMetaLayer(MetaLayer):
    """``MetaLayer`` variant whose sub-models receive explicit graph-id vectors.

    ``v_indices`` (graph id per node) is passed to the node model.
    ``e_indices`` (graph id per edge) is passed to the edge model.
    Both are passed to the global model.
    """

    def forward(
            self, x, edge_index, edge_attr=None, u=None, v_indices=None, e_indices=None
    ):
        """Update edges, then nodes, then globals.

        Any sub-model that is ``None`` is skipped, and the corresponding input
        is passed through unchanged.

        Returns:
            Tuple ``(x, edge_attr, u)`` after the updates.
        """
        senders, receivers = edge_index

        # 1) Edge update: sees both endpoint features of every edge.
        if self.edge_model is not None:
            edge_attr = self.edge_model(
                x[senders], x[receivers], edge_attr, u, e_indices
            )

        # 2) Node update: sees the (possibly updated) edge features.
        x = x if self.node_model is None else self.node_model(
            x, edge_index, edge_attr, u, v_indices
        )

        # 3) Global update: sees the updated nodes and edges.
        u = u if self.global_model is None else self.global_model(
            x, edge_attr, u, v_indices, e_indices
        )

        return x, edge_attr, u
def get_mlp(
        in_size,
        out_size,
        n_hidden,
        hidden_size,
        activation=ReLU,
        activate_last=True,
        layer_norm=True
):
    """Build a ``Seq`` MLP of ``n_hidden`` (Lin, activation) pairs plus an output Lin.

    Args:
        in_size: number of input features.
        out_size: number of output features.
        n_hidden: number of hidden ``Lin`` + activation pairs. With 0, the
            core is a single ``Lin(in_size, out_size)``.
        hidden_size: width of every hidden layer.
        activation: module class; a fresh instance is created at each use.
        activate_last: append an activation after the output ``Lin``.
        layer_norm: append ``LayerNorm(out_size)`` at the very end. This is
            only honoured when ``activate_last`` is True.

    Returns:
        A ``Seq`` holding the layers in order.
    """
    widths = [in_size] + [hidden_size] * n_hidden
    layers = []
    for fan_in, fan_out in zip(widths[:-1], widths[1:]):
        layers.extend((Lin(fan_in, fan_out), activation()))
    layers.append(Lin(widths[-1], out_size))

    if activate_last:
        layers.append(activation())
        # LayerNorm is nested under activate_last, so layer_norm has no
        # effect when the output is left un-activated.
        if layer_norm:
            layers.append(LayerNorm(out_size))

    return Seq(*layers)
class GraphNet(torch.nn.Module):
    """One graph-network block (edge -> node -> global update) built on ``ModifiedMetaLayer``.

    Args:
        in_dims: ``(v_in, e_in, u_in)``, the input feature sizes of nodes,
            edges and globals.
        out_dims: ``(v_out, e_out, u_out)``, the output feature sizes.
        independent: if True, each attribute is transformed by its own MLP
            without looking at any other attribute (used for encoders and
            decoders).
        e2v_agg: ``"sum"`` or ``"mean"``; how updated edge features are pooled
            onto nodes in the non-independent node model.
        n_hidden, hidden_size, activation, layer_norm: forwarded to
            ``get_mlp`` for every MLP in the block.

    Raises:
        ValueError: if ``e2v_agg`` is not ``"sum"`` or ``"mean"``.
    """

    def __init__(
            self,
            in_dims,
            out_dims,
            independent=False,
            e2v_agg="sum",
            n_hidden=1,
            hidden_size=64,
            activation=ReLU,
            layer_norm=True
    ):
        # Initialise nn.Module first, so that ``self.op`` below is registered
        # as a submodule (parameters, .to(device), state_dict) and the
        # instance is callable as ``block(x, ...)``.
        super().__init__()

        if e2v_agg not in ["sum", "mean"]:
            raise ValueError("Unknown aggregation function.")

        v_in = in_dims[0]  # n_node_features_in
        e_in = in_dims[1]  # n_edge_features_in
        u_in = in_dims[2]  # n_global_features_in

        v_out = out_dims[0]  # n_node_features_out
        e_out = out_dims[1]  # n_edge_features_out
        u_out = out_dims[2]  # n_global_features_out

        class EdgeModel(torch.nn.Module):

            def __init__(self):
                super(EdgeModel, self).__init__()

                if independent:
                    self.edge_mlp = get_mlp(
                        e_in,  # n_edge_features_in
                        e_out,  # n_edge_features_out
                        n_hidden,
                        hidden_size,
                        activation=activation,
                        layer_norm=layer_norm
                    )
                else:
                    self.edge_mlp = get_mlp(
                        e_in + 2 * v_in + u_in,  # n_edge_features_in + 2 * n_node_features_in + n_global_features_in
                        e_out,
                        n_hidden,
                        hidden_size,
                        activation=activation,
                        layer_norm=layer_norm
                    )

            def forward(self, src, target, edge_attr, u=None, e_indices=None):
                # src, target: [E, F_x], where E is the number of edges.
                # edge_attr: [E, F_e]
                # u: [B, F_u], where B is the number of graphs.
                # e_indices: [E] with max entry B - 1.
                if independent:
                    return self.edge_mlp(edge_attr)

                out = torch.cat([src, target, edge_attr, u[e_indices]], 1)
                return self.edge_mlp(out)

        class NodeModel(torch.nn.Module):

            def __init__(self):
                super(NodeModel, self).__init__()

                if independent:
                    self.node_mlp = get_mlp(
                        v_in,  # n_node_features_in
                        v_out,  # n_node_features_out
                        n_hidden,
                        hidden_size,
                        activation=activation,
                        layer_norm=layer_norm
                    )
                else:
                    self.node_mlp = get_mlp(
                        v_in + e_out + u_in,  # n_node_features_in + n_edge_features_out + n_global_features_in
                        v_out,  # n_node_features_out
                        n_hidden,
                        hidden_size,
                        activation=activation,
                        layer_norm=layer_norm
                    )

            def forward(self, x, edge_index, edge_attr, u=None, v_indices=None):
                # x: [N, F_x], where N is the number of nodes.
                # edge_index: [2, E] with max entry N - 1.
                # edge_attr: [E, F_e]
                # u: [B, F_u], where B is the number of graphs.
                # v_indices: [N] with max entry B - 1.
                if independent:
                    return self.node_mlp(x)

                row, col = edge_index
                # NOTE(review): edges are pooled onto ``row`` (the first row of
                # edge_index). PyG's source->target convention would pool onto
                # ``col`` (receivers). Confirm this is intended.
                if e2v_agg == "sum":
                    out = scatter_add(edge_attr, row, dim=0, dim_size=x.size(0))
                else:  # "mean", validated in GraphNet.__init__
                    out = scatter_mean(edge_attr, row, dim=0, dim_size=x.size(0))
                out = torch.cat([x, out, u[v_indices]], dim=1)
                return self.node_mlp(out)

        class GlobalModel(torch.nn.Module):

            def __init__(self):
                super(GlobalModel, self).__init__()

                if independent:
                    self.global_mlp = get_mlp(
                        u_in,  # n_global_features_in
                        u_out,  # n_global_features_out
                        n_hidden,
                        hidden_size,
                        activation=activation,
                        layer_norm=layer_norm
                    )
                else:
                    self.global_mlp = get_mlp(
                        u_in + v_out + e_out,  # n_global_features_in + n_node_features_out + n_edge_features_out
                        u_out,  # n_global_features_out
                        n_hidden,
                        hidden_size,
                        activation=activation,
                        layer_norm=layer_norm
                    )

            def forward(self, x, edge_attr, u, v_indices, e_indices):
                # x: [N, F_x], where N is the number of nodes.
                # edge_attr: [E, F_e]
                # u: [B, F_u], where B is the number of graphs.
                # v_indices: [N] with max entry B - 1.
                # e_indices: [E] with max entry B - 1.
                if independent:
                    return self.global_mlp(u)

                # Pin the pooled size to B. Without dim_size, the result is
                # only max(index) + 1 rows long, and torch.cat fails whenever
                # the last graphs in the batch have no nodes or no edges.
                n_graphs = u.size(0)
                out = torch.cat([u,
                                 scatter_mean(x, v_indices, dim=0, dim_size=n_graphs),
                                 scatter_mean(edge_attr, e_indices, dim=0, dim_size=n_graphs)], dim=1)
                return self.global_mlp(out)

        self.op = ModifiedMetaLayer(EdgeModel(), NodeModel(), GlobalModel())

    def forward(self, x, edge_index, edge_attr=None, u=None, v_indices=None, e_indices=None):
        """Apply the block and return ``(x, edge_attr, u)`` after the update."""
        return self.op(x, edge_index, edge_attr, u, v_indices, e_indices)

Then, the GraphNet class is used to build an EncoderCoreDecoder architecture as specified in Battaglia et al., i.e.:

class EncoderCoreDecoder(torch.nn.Module):
    """
    Full encode-process-decode model.
    - An optional "Encoder" graph net (built when ``encoder_out_dims`` is
      given). It independently encodes the edge, node, and global attributes
      (does not compute relations etc.).
    - A "Core" graph net, which performs ``core_steps`` rounds of processing
      (message-passing) steps. The input to the Core is the concatenation of
      the Encoder's output and the previous output of the Core (labeled
      "Hidden(t)" below, where "t" is the processing step). Hidden(0) is all
      zeros.
    - An optional "Decoder" graph net (built when ``dec_out_dims`` is given).
      It independently decodes the edge, node, and global attributes (does
      not compute relations etc.). It is applied once, to the output of the
      final Core step.
    - Optional per-attribute ``Lin`` output layers mapping to ``out_dims``;
      an entry of ``None`` skips that layer.

                        Hidden(t)   Hidden(t+1)
                           |            ^
              *---------*  |  *------*  |  *---------*
              |         |  |  |      |  |  |         |
    Input --->| Encoder |  *->| Core |--*->| Decoder |---> Output
              |         |---->|      |     |         |
              *---------*     *------*     *---------*
    """

    def __init__(
            self,
            in_dims,
            core_out_dims,
            out_dims,
            core_steps=1,
            encoder_out_dims=None,
            dec_out_dims=None,
            e2v_agg="sum",
            n_hidden=1,
            hidden_size=64,
            activation=ReLU,
            independent_block_layers=1,
            layer_norm=True
    ):
        # Initialise nn.Module first, so that the encoder/core/decoder and the
        # output layers assigned below are registered as submodules and the
        # model is callable as ``model(x, ...)``.
        super().__init__()

        # all dims are tuples with (v, e, u) feature sizes
        self.steps = core_steps
        # if dec_out_dims is None, there will not be a decoder
        self.in_dims = in_dims
        self.core_out_dims = core_out_dims
        self.dec_out_dims = dec_out_dims

        self.layer_norm = layer_norm

        self.encoder = None
        if encoder_out_dims is not None:
            self.encoder = GraphNet(
                in_dims,
                encoder_out_dims,
                independent=True,
                n_hidden=independent_block_layers,
                hidden_size=hidden_size,
                activation=activation,
                layer_norm=self.layer_norm
            )

        core_in_dims = in_dims if self.encoder is None else encoder_out_dims

        # The Core consumes [encoded input, previous Core output] per attribute.
        self.core = GraphNet(
            (
                core_in_dims[0] + core_out_dims[0],
                core_in_dims[1] + core_out_dims[1],
                core_in_dims[2] + core_out_dims[2]
            ),
            core_out_dims,
            e2v_agg=e2v_agg,
            n_hidden=n_hidden,
            hidden_size=hidden_size,
            activation=activation,
            layer_norm=self.layer_norm
        )

        self.decoder = None
        if dec_out_dims is not None:
            self.decoder = GraphNet(
                core_out_dims,
                dec_out_dims,
                independent=True,
                n_hidden=independent_block_layers,
                hidden_size=hidden_size,
                activation=activation,
                layer_norm=self.layer_norm
            )

        pre_out_dims = core_out_dims if self.decoder is None else dec_out_dims

        self.vertex_out_transform = (
            Lin(pre_out_dims[0], out_dims[0]) if out_dims[0] is not None else None
        )
        self.edge_out_transform = (
            Lin(pre_out_dims[1], out_dims[1]) if out_dims[1] is not None else None
        )
        self.global_out_transform = (
            Lin(pre_out_dims[2], out_dims[2]) if out_dims[2] is not None else None
        )

    def get_init_state(self, n_v, n_e, n_u, device):
        """Return all-zero (node, edge, global) Core states used as Hidden(0)."""
        return (
            torch.zeros((n_v, self.core_out_dims[0]), device=device),
            torch.zeros((n_e, self.core_out_dims[1]), device=device),
            torch.zeros((n_u, self.core_out_dims[2]), device=device)
        )

    def forward(self, x, edge_index, edge_attr, u, v_indices=None, e_indices=None):
        """Run encoder -> core (``core_steps`` times) -> decoder -> output layers.

        Args:
            x: node features ``[N, F_v]``.
            edge_index: ``[2, E]`` edge list.
            edge_attr: edge features ``[E, F_e]``.
            u: global features ``[B, F_u]``.
            v_indices: ``[N]`` graph id of each node. Defaults to all zeros
                (a single graph).
            e_indices: ``[E]`` graph id of each edge. Defaults to the graph id
                of each edge's source node (``v_indices[edge_index[0]]``).
                This assumes edges never cross graphs within a batch.

        Returns:
            Tuple ``(v_out, e_out, u_out)``.
        """
        # Default each index vector independently. If only one were supplied,
        # the other would stay None, and ``u[None]`` would break the
        # concatenation inside the graph nets.
        if v_indices is None:
            v_indices = torch.zeros(x.shape[0], dtype=torch.long, device=x.device)
        if e_indices is None:
            e_indices = v_indices[edge_index[0]]

        if self.encoder is not None:
            x, edge_attr, u = self.encoder(
                x, edge_index, edge_attr, u, v_indices, e_indices
            )

        latent0 = (x, edge_attr, u)
        latent = self.get_init_state(
            x.shape[0], edge_attr.shape[0], u.shape[0], x.device
        )
        for _ in range(self.steps):
            latent = self.core(
                torch.cat([latent0[0], latent[0]], dim=1),
                edge_index,
                torch.cat([latent0[1], latent[1]], dim=1),
                torch.cat([latent0[2], latent[2]], dim=1),
                v_indices,
                e_indices
            )

        if self.decoder is not None:
            latent = self.decoder(
                latent[0], edge_index, latent[1], latent[2], v_indices, e_indices
            )

        v_out = (
            latent[0]
            if self.vertex_out_transform is None
            else self.vertex_out_transform(latent[0])
        )
        e_out = (
            latent[1]
            if self.edge_out_transform is None
            else self.edge_out_transform(latent[1])
        )
        u_out = (
            latent[2]
            if self.global_out_transform is None
            else self.global_out_transform(latent[2])
        )
        return v_out, e_out, u_out

I have both node and edge features, so (according to this implementation) I need to change the NodeModel and EdgeModel classes to compute attention coefficients.

But how can I rewrite this class using the `MessagePassing` interface instead of the `MetaLayer` class, so that I can add extra layers (e.g., `GATConv`)?

Thx

rusty1s commented 2 years ago

The `MessagePassing` class performs message passing to compute new node features from neighboring ones. You can easily extend it to update edge representations as well:

class GraphNet(MessagePassing):
    def __init__(self, ...)
        super().__init__(self, aggr='mean')
        self.node_MLP = ...
        self.edge_MLP = ...

    def forward(self, x, edge_index, edge_attr):
         row, col = edge_index
         edge_attr = self.edge_MLP(torch.cat([x[row], x[col], edge_attr], dim=-1)
         x = self.propagate(edge_index, x=x, edge_attr=edge_attr)
         return x, edge_attr

    def message(self, x_i, x_j, edge_attr):
        edge_attr = self.node_MLP(torch.cat([x_i, x_j, edge_attr], dim=-1)