Open dmeoli opened 2 years ago
IMO, the easiest way to implement this is to use the `MessagePassing` interface rather than the `MetaLayer`, and return both the transformed node features and the transformed edge features. Here is a minimal example that computes attention based on both node features and edge features: https://github.com/pyg-team/pytorch_geometric/discussions/3209#discussioncomment-1373286
If you still want to use the `MetaLayer`, then you need to update the `NodeModel` so that it computes attention coefficients and uses them to weight the messages before aggregation, e.g.:
from torch_geometric.utils import softmax
class NodeModel(torch.nn.Module):
    """MetaLayer node model with multi-head attention over incoming edges.

    For every edge ``(row -> col)`` the features ``[x[row], edge_attr, x[col]]``
    are mapped to:

    * one attention logit per head (``att_mlp``), and
    * ``num_heads * num_features`` message channels (``transform_mlp``).

    The logits are normalised with a softmax over all edges that share the
    same target node ``col``. Each head's message channels are scaled by that
    head's coefficient, and the weighted messages are summed into their
    target node.

    Args:
        num_heads: number of attention heads.
        num_features: output channels per head. The returned node features
            have ``num_heads * num_features`` columns.
    """

    def __init__(self, num_heads=1, num_features=64):
        super().__init__()
        self.num_heads = num_heads
        self.num_features = num_features
        # LazyLinear infers its input width on the first forward call, so the
        # node/edge feature sizes need not be known when the model is built.
        self.att_mlp = torch.nn.LazyLinear(num_heads)
        self.transform_mlp = torch.nn.LazyLinear(num_heads * num_features)

    def forward(self, x, edge_index, edge_attr, u, batch):
        """Return updated node features of shape ``[N, num_heads * num_features]``.

        ``u`` and ``batch`` are accepted for MetaLayer compatibility but are
        not used here.
        """
        row, col = edge_index
        out = torch.cat([x[row], edge_attr, x[col]], dim=1)
        att = self.att_mlp(out)  # one logit per edge and head
        # Normalise across each target node's local neighbourhood.
        att = softmax(att, col, num_nodes=x.size(0))
        out = self.transform_mlp(out)
        out = out.view(-1, self.num_heads, self.num_features) * att.view(
            -1, self.num_heads, 1
        )
        out = out.view(-1, self.num_heads * self.num_features)
        # Sum the weighted messages into their target nodes (nodes without
        # incoming edges keep all-zero features).
        aggregated = out.new_zeros(x.size(0), out.size(1))
        return aggregated.index_add_(0, col, out)
Actually, my code is:
class ModifiedMetaLayer(MetaLayer):
    """MetaLayer variant whose sub-models receive explicit graph-index vectors.

    Unlike the stock ``forward``, the edge model is given the per-edge graph
    ids (``e_indices``), the node model the per-node graph ids
    (``v_indices``), and the global model both.
    """

    def forward(
        self, x, edge_index, edge_attr=None, u=None, v_indices=None, e_indices=None
    ):
        """Apply the edge, node and global models in that order.

        Each model that is ``None`` is skipped and its input passes through
        unchanged.

        Returns:
            ``(x, edge_attr, u)`` after the updates.
        """
        src_idx, dst_idx = edge_index

        edge_model = self.edge_model
        if edge_model is not None:
            edge_attr = edge_model(x[src_idx], x[dst_idx], edge_attr, u, e_indices)

        node_model = self.node_model
        if node_model is not None:
            x = node_model(x, edge_index, edge_attr, u, v_indices)

        global_model = self.global_model
        if global_model is not None:
            u = global_model(x, edge_attr, u, v_indices, e_indices)

        return x, edge_attr, u
def get_mlp(
    in_size,
    out_size,
    n_hidden,
    hidden_size,
    activation=ReLU,
    activate_last=True,
    layer_norm=True,
):
    """Build a feed-forward MLP wrapped in ``Seq``.

    The network is ``n_hidden`` blocks of ``Lin(., hidden_size)`` followed by
    ``activation()``, then a final ``Lin(., out_size)``. After the final
    linear layer, ``activation()`` is appended when ``activate_last`` is true,
    and then ``LayerNorm(out_size)`` when ``layer_norm`` is true.

    Args:
        in_size: input feature width.
        out_size: output feature width.
        n_hidden: number of hidden linear+activation blocks (0 gives a single
            linear layer).
        hidden_size: width of every hidden layer.
        activation: zero-argument callable producing an activation module.
        activate_last: whether to apply ``activation`` after the last layer.
        layer_norm: whether to finish with a ``LayerNorm`` over ``out_size``.

    Returns:
        A ``Seq`` of the layers described above.
    """
    # Input widths of the hidden layers; consecutive pairs are (fan_in, fan_out).
    widths = [in_size] + [hidden_size] * n_hidden
    layers = []
    for fan_in, fan_out in zip(widths, widths[1:]):
        layers += [Lin(fan_in, fan_out), activation()]
    layers.append(Lin(widths[-1], out_size))
    if activate_last:
        layers.append(activation())
    if layer_norm:
        layers.append(LayerNorm(out_size))
    return Seq(*layers)
class GraphNet(torch.nn.Module):
    """One Battaglia-style graph-network block: edge, then node, then global update.

    The three updates are implemented as nested models and run by a
    ``ModifiedMetaLayer``. Each model owns one MLP built with ``get_mlp``.

    Args:
        in_dims: ``(v_in, e_in, u_in)`` input feature sizes for nodes, edges
            and globals.
        out_dims: ``(v_out, e_out, u_out)`` output feature sizes.
        independent: if true, each MLP sees only its own attribute (no
            relational inputs), as used for encoders/decoders.
        e2v_agg: how updated edge features are aggregated per node, either
            ``"sum"`` or ``"mean"``.
        n_hidden, hidden_size, activation, layer_norm: forwarded to ``get_mlp``.

    Raises:
        ValueError: if ``e2v_agg`` is neither ``"sum"`` nor ``"mean"``.
    """

    def __init__(
        self,
        in_dims,
        out_dims,
        independent=False,
        e2v_agg="sum",
        n_hidden=1,
        hidden_size=64,
        activation=ReLU,
        layer_norm=True,
    ):
        if e2v_agg not in ["sum", "mean"]:
            raise ValueError("Unknown aggregation function.")
        # Must run before sub-modules are assigned. Without it ``self.op`` is
        # neither registered as a parameter owner nor is this object callable.
        super().__init__()

        v_in, e_in, u_in = in_dims[0], in_dims[1], in_dims[2]
        v_out, e_out, u_out = out_dims[0], out_dims[1], out_dims[2]

        def _mlp(n_in, n_out):
            # Shared factory so all three models are configured identically.
            return get_mlp(
                n_in,
                n_out,
                n_hidden,
                hidden_size,
                activation=activation,
                layer_norm=layer_norm,
            )

        class EdgeModel(torch.nn.Module):
            """Update each edge from its endpoints, itself and its graph's globals."""

            def __init__(self):
                super().__init__()
                n_in = e_in if independent else e_in + 2 * v_in + u_in
                self.edge_mlp = _mlp(n_in, e_out)

            def forward(self, src, target, edge_attr, u=None, e_indices=None):
                # src, target: [E, F_x], where E is the number of edges.
                # edge_attr: [E, F_e]
                # u: [B, F_u], where B is the number of graphs.
                # e_indices: [E] with max entry B - 1.
                if independent:
                    return self.edge_mlp(edge_attr)
                out = torch.cat([src, target, edge_attr, u[e_indices]], 1)
                return self.edge_mlp(out)

        class NodeModel(torch.nn.Module):
            """Update each node from itself, its aggregated edges and its graph's globals."""

            def __init__(self):
                super().__init__()
                n_in = v_in if independent else v_in + e_out + u_in
                self.node_mlp = _mlp(n_in, v_out)

            def forward(self, x, edge_index, edge_attr, u=None, v_indices=None):
                # x: [N, F_x], where N is the number of nodes.
                # edge_index: [2, E] with max entry N - 1.
                # edge_attr: [E, F_e]
                # u: [B, F_u], where B is the number of graphs.
                # v_indices: [N] with max entry B - 1.
                if independent:
                    return self.node_mlp(x)
                row, col = edge_index
                # Edges are grouped by ``row`` (the edge's source, i.e. the node
                # passed as ``src`` to EdgeModel), so each node aggregates its
                # outgoing edges.
                # NOTE(review): DeepMind's graph_nets NodeBlock aggregates
                # *received* edges by default; for directed graphs confirm which
                # direction is intended.
                aggregate = scatter_add if e2v_agg == "sum" else scatter_mean
                out = aggregate(edge_attr, row, dim=0, dim_size=x.size(0))
                out = torch.cat([x, out, u[v_indices]], dim=1)
                return self.node_mlp(out)

        class GlobalModel(torch.nn.Module):
            """Update each graph's globals from itself and its mean node/edge features."""

            def __init__(self):
                super().__init__()
                n_in = u_in if independent else u_in + v_out + e_out
                self.global_mlp = _mlp(n_in, u_out)

            def forward(self, x, edge_attr, u, v_indices, e_indices):
                # x: [N, F_x], where N is the number of nodes.
                # edge_attr: [E, F_e]
                # u: [B, F_u], where B is the number of graphs.
                # v_indices: [N], e_indices: [E], both with max entry B - 1.
                if independent:
                    return self.global_mlp(u)
                # dim_size pins the row count to B. Otherwise, graphs at the end
                # of the batch with no nodes/edges would shrink the scatter
                # output and break the concatenation.
                n_graphs = u.size(0)
                out = torch.cat(
                    [
                        u,
                        scatter_mean(x, v_indices, dim=0, dim_size=n_graphs),
                        scatter_mean(edge_attr, e_indices, dim=0, dim_size=n_graphs),
                    ],
                    dim=1,
                )
                return self.global_mlp(out)

        self.op = ModifiedMetaLayer(EdgeModel(), NodeModel(), GlobalModel())

    def forward(self, x, edge_index, edge_attr=None, u=None, v_indices=None, e_indices=None):
        """Run the edge, node and global updates; return ``(x, edge_attr, u)``."""
        return self.op(x, edge_index, edge_attr, u, v_indices, e_indices)
The `GraphNet` class is then used to build an `EncoderCoreDecoder` architecture as specified in Battaglia et al., i.e.:
class EncoderCoreDecoder(torch.nn.Module):
    """
    Full encode-process-decode model (Battaglia et al.).

    - An optional "Encoder" graph net (built when ``encoder_out_dims`` is
      given), which independently encodes the edge, node, and global
      attributes (does not compute relations etc.).
    - A "Core" graph net, which performs ``core_steps`` rounds of processing
      (message-passing). The input to the Core is the concatenation of the
      Encoder's output (or the raw input when there is no Encoder) and the
      previous output of the Core (labeled "Hidden(t)" below, where "t" is the
      processing step). Hidden(0) is all zeros.
    - An optional "Decoder" graph net (built when ``dec_out_dims`` is given),
      which independently decodes the edge, node, and global attributes
      (does not compute relations etc.) of the final Core output only.
    - Optional per-attribute linear output layers. An entry of ``out_dims``
      that is ``None`` leaves that attribute untransformed.

                        Hidden(t)   Hidden(t+1)
                           |            ^
              *---------*  |  *------*  |  *---------*
              |         |  |  |      |  |  |         |
    Input --->| Encoder |  *->| Core |--*->| Decoder |---> Output
              |         |---->|      |     |         |
              *---------*     *------*     *---------*

    All ``*_dims`` arguments are ``(v, e, u)`` tuples of node, edge and global
    feature sizes.
    """

    def __init__(
        self,
        in_dims,
        core_out_dims,
        out_dims,
        core_steps=1,
        encoder_out_dims=None,
        dec_out_dims=None,
        e2v_agg="sum",
        n_hidden=1,
        hidden_size=64,
        activation=ReLU,
        independent_block_layers=1,
        layer_norm=True,
    ):
        # Needed so encoder/core/decoder and the output layers are registered
        # as sub-modules (parameters, .to(device), train/eval) and so the
        # model is callable.
        super().__init__()
        self.steps = core_steps
        self.in_dims = in_dims
        self.core_out_dims = core_out_dims
        # If dec_out_dims is None, there is no decoder.
        self.dec_out_dims = dec_out_dims
        self.layer_norm = layer_norm

        self.encoder = None
        if encoder_out_dims is not None:
            self.encoder = GraphNet(
                in_dims,
                encoder_out_dims,
                independent=True,
                n_hidden=independent_block_layers,
                hidden_size=hidden_size,
                activation=activation,
                layer_norm=self.layer_norm,
            )

        # The core consumes [encoded input, previous core output] per attribute.
        core_in_dims = in_dims if self.encoder is None else encoder_out_dims
        self.core = GraphNet(
            (
                core_in_dims[0] + core_out_dims[0],
                core_in_dims[1] + core_out_dims[1],
                core_in_dims[2] + core_out_dims[2],
            ),
            core_out_dims,
            e2v_agg=e2v_agg,
            n_hidden=n_hidden,
            hidden_size=hidden_size,
            activation=activation,
            layer_norm=self.layer_norm,
        )

        self.decoder = None
        if dec_out_dims is not None:
            self.decoder = GraphNet(
                core_out_dims,
                dec_out_dims,
                independent=True,
                n_hidden=independent_block_layers,
                hidden_size=hidden_size,
                activation=activation,
                layer_norm=self.layer_norm,
            )

        pre_out_dims = core_out_dims if self.decoder is None else dec_out_dims
        self.vertex_out_transform = (
            Lin(pre_out_dims[0], out_dims[0]) if out_dims[0] is not None else None
        )
        self.edge_out_transform = (
            Lin(pre_out_dims[1], out_dims[1]) if out_dims[1] is not None else None
        )
        self.global_out_transform = (
            Lin(pre_out_dims[2], out_dims[2]) if out_dims[2] is not None else None
        )

    def get_init_state(self, n_v, n_e, n_u, device):
        """Return the all-zero initial core state ``(nodes, edges, globals)``."""
        return (
            torch.zeros((n_v, self.core_out_dims[0]), device=device),
            torch.zeros((n_e, self.core_out_dims[1]), device=device),
            torch.zeros((n_u, self.core_out_dims[2]), device=device),
        )

    def forward(self, x, edge_index, edge_attr, u, v_indices=None, e_indices=None):
        """Encode, run ``core_steps`` core iterations, decode, and project.

        Args:
            x: node features ``[N, v]``.
            edge_index: ``[2, E]`` edge list.
            edge_attr: edge features ``[E, e]``.
            u: global features, one row per graph.
            v_indices: graph id of every node. Defaults to all zeros, i.e. a
                single graph.
            e_indices: graph id of every edge. Defaults to the graph id of
                each edge's source node.

        Returns:
            ``(v_out, e_out, u_out)``.
        """
        if v_indices is None:
            # No batch vector: treat all nodes as one graph.
            # NOTE: if e_indices describes several graphs, v_indices must be
            # passed explicitly as well.
            v_indices = torch.zeros(x.shape[0], dtype=torch.long, device=x.device)
        if e_indices is None:
            # An edge belongs to the graph of its source node. In the
            # single-graph case this is all zeros, as before.
            e_indices = v_indices[edge_index[0]]

        if self.encoder is not None:
            x, edge_attr, u = self.encoder(
                x, edge_index, edge_attr, u, v_indices, e_indices
            )

        latent0 = (x, edge_attr, u)
        latent = self.get_init_state(
            x.shape[0], edge_attr.shape[0], u.shape[0], x.device
        )
        for _ in range(self.steps):
            latent = self.core(
                torch.cat([latent0[0], latent[0]], dim=1),
                edge_index,
                torch.cat([latent0[1], latent[1]], dim=1),
                torch.cat([latent0[2], latent[2]], dim=1),
                v_indices,
                e_indices,
            )

        if self.decoder is not None:
            latent = self.decoder(
                latent[0], edge_index, latent[1], latent[2], v_indices, e_indices
            )

        v_out = (
            latent[0]
            if self.vertex_out_transform is None
            else self.vertex_out_transform(latent[0])
        )
        e_out = (
            latent[1]
            if self.edge_out_transform is None
            else self.edge_out_transform(latent[1])
        )
        u_out = (
            latent[2]
            if self.global_out_transform is None
            else self.global_out_transform(latent[2])
        )
        return v_out, e_out, u_out
I have both node and edge features, so (according to this implementation) I need to change the `NodeModel` and `EdgeModel` classes to compute attention coefficients.
But how can I rewrite this class using the `MessagePassing` interface rather than the `MetaLayer` class, so that I can add extra layers (e.g., `GATConv`)?
Thanks
The `MessagePassing` class helps you perform message passing, i.e. compute new node features from neighboring ones. You can easily extend it to update edge representations as well:
class GraphNet(MessagePassing):
    """Message-passing layer that updates both edge and node features.

    Edge features are updated first from ``[x[row], x[col], edge_attr]``.
    Node features are then obtained by aggregating (``aggr``, mean by default)
    one message per edge, computed as ``node_MLP([x_i, x_j, edge_attr])`` from
    the *updated* edge features.

    Args:
        node_mlp: module mapping ``2 * F_x + F_e'`` inputs to the new node
            width, where ``F_e'`` is ``edge_mlp``'s output width.
        edge_mlp: module mapping ``2 * F_x + F_e`` inputs to ``F_e'``.
        aggr: neighbourhood aggregation scheme passed to ``MessagePassing``.
    """

    def __init__(self, node_mlp, edge_mlp, aggr="mean"):
        # ``super()`` already binds ``self``. Passing it again would land in
        # the ``aggr`` slot and collide with the ``aggr`` keyword.
        super().__init__(aggr=aggr)
        self.node_MLP = node_mlp
        self.edge_MLP = edge_mlp

    def forward(self, x, edge_index, edge_attr):
        """Return ``(new_node_features, new_edge_features)``."""
        row, col = edge_index
        edge_attr = self.edge_MLP(torch.cat([x[row], x[col], edge_attr], dim=-1))
        x = self.propagate(edge_index, x=x, edge_attr=edge_attr)
        return x, edge_attr

    def message(self, x_i, x_j, edge_attr):
        # With PyG's default "source_to_target" flow, x_i holds the target
        # node's features and x_j the source node's. The returned tensor is
        # the per-edge message that ``propagate`` aggregates.
        return self.node_MLP(torch.cat([x_i, x_j, edge_attr], dim=-1))
I'm trying to extend a project that implements a GNN following Battaglia et al.'s definition through the `MetaLayer` class. I would like to include some attention mechanisms as defined here. How can I use the `MetaLayer` class to implement this TF code? Or should I use a `GATConv` layer instead?