Issue closed by SidneyLann 2 years ago.
Hey! Have you been able to test replacing SimpleConvolution with a GAT layer here?
@SidneyLann, yeah, this seems to train:
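```python
import tensorflow as tf
import tensorflow_gnn as tfgnn

# NOTE: `model` is not defined in this snippet; it appears to be a helper
# module from the surrounding example code that provides NodeClassificationModel.


def OgbnArxivModel(
    num_classes: int,
    num_neighbourhoods: int,
    type_spec: tfgnn.GraphTensorSpec,
) -> tf.keras.Model:
    inputs = tf.keras.layers.Input(type_spec=type_spec)
    gnn = tfgnn.keras.ConvGNNBuilder(
        # Convolution used on every edge set: multi-head GATv2 attention.
        lambda edge_set_name: tfgnn.keras.layers.GATv2Convolution(
            num_heads=4,
            per_head_channels=16,
            receiver_tag=tfgnn.TARGET,
        ),
        # Next node state: concatenate the old state with the incoming
        # messages, then project with a Dense layer.
        lambda node_set_name: tfgnn.keras.layers.NextStateFromConcat(
            tf.keras.layers.Dense(32)
        ),
    )
    # One round of message passing per neighbourhood hop.
    hidden = inputs
    for _ in range(num_neighbourhoods):
        hidden = gnn.Convolve()(hidden)
    # Read out the state of the first "paper" node (typically the seed node
    # of each sampled subgraph).
    hidden = tfgnn.keras.layers.ReadoutFirstNode(node_set_name="paper")(hidden)
    output = tf.keras.layers.Dense(num_classes, activation="softmax")(hidden)
    return model.NodeClassificationModel(
        inputs, output, target_node="paper", label_name="label"
    )
```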
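For anyone wondering what `GATv2Convolution(num_heads=4, per_head_channels=16)` roughly computes on an edge set, here is a minimal NumPy sketch of one GATv2-style attention step. It is not TF-GNN's implementation: the names `gatv2_layer`, `w_query`, `w_key` and `attn` are hypothetical, keys are reused as values, and there is no bias, dropout or output activation. Each head scores an edge as `attn · LeakyReLU(W_q h_receiver + W_k h_sender)` (applying the LeakyReLU before the dot product is what distinguishes GATv2 from the original GAT), softmax-normalizes the scores over each receiver's incoming edges, and sums the weighted sender values. In this sketch the four 16-channel heads are concatenated into a 64-wide result per node.

```python
import numpy as np


def leaky_relu(x, slope=0.2):
    return np.where(x > 0, x, slope * x)


def gatv2_layer(h, senders, receivers, w_query, w_key, attn):
    """Hypothetical multi-head GATv2-style convolution over a directed edge list.

    h:          [num_nodes, in_dim] node states.
    senders:    [num_edges] int array; messages flow sender -> receiver.
    receivers:  [num_edges] int array.
    w_query:    [num_heads, in_dim, channels] projection for receivers.
    w_key:      [num_heads, in_dim, channels] projection for senders
                (also used as the value projection, an assumption).
    attn:       [num_heads, channels] attention vectors.
    Returns [num_nodes, num_heads * channels]; nodes without incoming
    edges get zeros.
    """
    num_nodes = h.shape[0]
    num_heads, _, channels = w_query.shape
    out = np.zeros((num_nodes, num_heads, channels))
    for k in range(num_heads):
        q = h @ w_query[k]  # [num_nodes, channels]
        v = h @ w_key[k]    # [num_nodes, channels]
        # GATv2 scoring: nonlinearity first, then dot product with attn.
        scores = leaky_relu(q[receivers] + v[senders]) @ attn[k]  # [num_edges]
        # Softmax over the incoming edges of each receiver.
        for r in np.unique(receivers):
            mask = receivers == r
            e = np.exp(scores[mask] - scores[mask].max())
            alpha = e / e.sum()
            out[r, k] = alpha @ v[senders[mask]]
    # Concatenate heads.
    return out.reshape(num_nodes, num_heads * channels)


rng = np.random.default_rng(0)
num_nodes, in_dim, num_heads, channels = 5, 8, 4, 16
h = rng.normal(size=(num_nodes, in_dim))
senders = np.array([0, 1, 2, 3, 4, 0])
receivers = np.array([1, 2, 0, 0, 0, 4])
w_query = rng.normal(size=(num_heads, in_dim, channels))
w_key = rng.normal(size=(num_heads, in_dim, channels))
attn = rng.normal(size=(num_heads, channels))

out = gatv2_layer(h, senders, receivers, w_query, w_key, attn)
print(out.shape)  # (5, 64)
```

Node 3 has no incoming edges, so its row stays zero, and node 1 has a single incoming edge, so its attention weight is 1 and it simply receives node 0's projected value.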
Thanks.
What's the purpose of GATv2GraphUpdate? How do I include it in the model?
I've only really used ConvGNNBuilder, but it looks like GraphUpdate layers such as GATv2GraphUpdate can also be used on their own. Here's a test using GATv2GraphUpdate that seems to do that. Let me know if you get that to work.
done
Could you please also add a training model for GAT?