danielegrattarola / spektral

Graph Neural Networks with Keras and Tensorflow 2.
MIT License

Create custom Graph Extension Type #420

Open dmadisetti opened 1 year ago

dmadisetti commented 1 year ago

A "Graph" Type is even defined in the example:

This lets the user carry around graph information as a single TensorFlow object, and even allows for ops at the object level. I think this would also allow for a batch of graphs, which would be great.

This would be a little nicer than carrying around separate arrays for the adjacency, the features, etc.
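For readers unfamiliar with extension types, here is a minimal sketch of the idea. It assumes a hypothetical Graph class with dense x and a fields; these names are chosen here and do not come from the referenced example or from Spektral. The fields travel together as one value, methods act at the object level, and a tf.function can take the whole graph as a single argument.

```python
import tensorflow as tf


class Graph(tf.experimental.ExtensionType):
    """Hypothetical single-graph container (not a Spektral or TF class)."""
    x: tf.Tensor  # node features, shape (n_nodes, n_features)
    a: tf.Tensor  # dense adjacency, shape (n_nodes, n_nodes)

    def n_nodes(self):
        return tf.shape(self.x)[0]

    def with_features(self, new_x):
        # Extension types are immutable, so "updating" returns a new value.
        return Graph(x=new_x, a=self.a)


@tf.function
def degree(graph):
    # The whole Graph crosses the tf.function boundary as one argument.
    return tf.reduce_sum(graph.a, axis=1)


g = Graph(x=tf.ones([3, 2]), a=tf.constant([[0., 1., 0.],
                                             [1., 0., 1.],
                                             [0., 1., 0.]]))
g2 = g.with_features(tf.zeros([3, 4]))
```

Here degree(g) returns the node degrees [1., 2., 1.] without the caller having to unpack the adjacency separately.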

dmadisetti commented 1 year ago

Nvm, for batching there's an explicit BatchableTypeSpec.


But the custom extension type still makes sense.

There is tf.experimental.BatchableExtensionType.

Seems pretty straightforward to use? I'm using it for graphs generated by my pipeline.
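Here is a minimal sketch of how batching can look with tf.experimental.BatchableExtensionType. It reuses the two ragged fields of the TensorGraph class from this thread, and the toy data is made up. tf.data can slice a batched value into individual graphs and batch them back, with each graph's node count kept in the ragged dimensions.

```python
import tensorflow as tf


class TensorGraph(tf.experimental.BatchableExtensionType):
    """Same two ragged fields as the TensorGraph in this thread."""
    features: tf.RaggedTensor   # per graph: (n_nodes, n_features)
    adjacency: tf.RaggedTensor  # per graph: (n_nodes, n_nodes)


# Two toy graphs with 2 and 4 nodes (made-up data for illustration).
feats = [tf.ones([2, 3]), tf.ones([4, 3])]
adjs = [tf.eye(2), tf.eye(4)]

# A batch is a TensorGraph whose fields carry an extra outer (graph) dimension.
batch = TensorGraph(features=tf.ragged.stack(feats),
                    adjacency=tf.ragged.stack(adjs))

# Being batchable lets tf.data slice the batch into single graphs...
ds = tf.data.Dataset.from_tensor_slices(batch)
node_counts = [int(g.features.nrows()) for g in ds]

# ...and batch single graphs back together, keeping per-graph sizes ragged.
rebatched = next(iter(ds.batch(2)))
```

Ragged fields are what make this work for graphs of different sizes. Dense tf.Tensor fields could only be stacked if every graph had the same number of nodes.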

danielegrattarola commented 1 year ago


Do you have a self-contained example of using this type with Spektral? Or would it require rewriting the layers?


dmadisetti commented 1 year ago

It would likely just be another data mode. This is what I have, but integration into the library would probably look a little different:

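import tensorflow as tf
from spektral.layers import GCNConv


class TensorGraph(tf.experimental.BatchableExtensionType):
    """A collection of nodes with associated feature vectors."""
    features: tf.RaggedTensor   # node features, one row per node
    adjacency: tf.RaggedTensor  # adjacency matrix, one row per node

    # TODO: Validation functions etc...


class WrappedGCN(tf.keras.layers.Layer):
    """Applies a Spektral GCNConv to every graph in a (batched) TensorGraph."""

    def __init__(self, features, *args, **kwargs):
        super().__init__()
        self.features = features
        self.layer = GCNConv(features, *args, **kwargs)

    def hook(self, graph):
        # Densify one graph and add a batch axis of size 1, because GCNConv's
        # batch mode expects x: (1, n, F) and a: (1, n, n).
        features = tf.expand_dims(graph.features.to_tensor(), 0)
        adj = tf.expand_dims(tf.cast(graph.adjacency.to_tensor(), tf.float32), 0)
        # Note: GCNConv expects a normalized adjacency (see GCNConv.preprocess);
        # the raw adjacency is passed here, as in the original snippet.
        out = self.layer([features, adj])
        # Drop only the batch axis (a bare tf.squeeze would also drop n == 1).
        return tf.RaggedTensor.from_tensor(out[0])

    def __call__(self, graph):
        if isinstance(graph, TensorGraph):
            # Run the GCN on each graph; outputs stay ragged across graphs.
            features = tf.map_fn(
                self.hook, graph,
                fn_output_signature=tf.RaggedTensorSpec(
                    shape=(None, self.features), dtype=tf.float32))
            return TensorGraph(features=features, adjacency=graph.adjacency)
        return self.layer(graph)


# `adjs` and `node_features` (hypothetical name) are lists of per-graph arrays
# built elsewhere; the original snippet is cut off at this point.
x0 = TensorGraph(adjacency=tf.ragged.stack(adjs),
                 features=tf.ragged.stack(node_features))

x1 = WrappedGCN(6)(x0)
x2 = WrappedGCN(6)(x1)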

Could probably use sparse tensors instead of ragged ones, but I'm using ragged here because I need the dense adjacency matrices.
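Ragged storage does not rule out sparse adjacencies downstream. Assuming this would become a new data mode, one option is to convert a ragged batch into Spektral's disjoint mode: stacked node features x, a block-diagonal sparse adjacency a, and a graph index i. Layers that support disjoint mode could then be reused rather than rewritten. ragged_to_disjoint is a hypothetical helper written for illustration. It is not a Spektral function and has not been checked against Spektral's own loaders.

```python
import tensorflow as tf


def ragged_to_disjoint(features, adjacency):
    """Hypothetical helper (not part of Spektral).

    features:  RaggedTensor (batch, None, None), one (n_b, F) block per graph
    adjacency: RaggedTensor (batch, None, None), one dense (n_b, n_b) block per graph
    Returns x (N, F), block-diagonal SparseTensor a (N, N), graph index i (N,).
    """
    x = features.values.to_tensor()    # node features of all graphs, row-wise
    i = adjacency.value_rowids()       # graph id of every node
    offsets = adjacency.row_starts()   # first global node index of each graph

    rows = adjacency.values            # one ragged row per node
    row_of_entry = rows.value_rowids() # global node index of each entry
    n_entries = tf.size(rows.values, out_type=tf.int64)
    # Position of each entry inside its row = local column index.
    local_col = tf.range(n_entries) - tf.gather(rows.row_starts(), row_of_entry)
    # Shift local columns by the owning graph's node offset.
    global_col = local_col + tf.gather(offsets, tf.gather(i, row_of_entry))

    # Keep only non-zero entries of the block-diagonal matrix.
    mask = tf.not_equal(rows.values, 0)
    indices = tf.boolean_mask(tf.stack([row_of_entry, global_col], axis=1), mask)
    values = tf.boolean_mask(rows.values, mask)
    n = tf.shape(i, out_type=tf.int64)[0]
    a = tf.sparse.reorder(
        tf.SparseTensor(indices, values, dense_shape=tf.stack([n, n])))
    return x, a, i


# Usage with two toy graphs (2 and 3 nodes, made-up data).
feats = [tf.ones([2, 3]), tf.ones([3, 3])]
adjs = [tf.constant([[0., 1.], [1., 0.]]),
        tf.constant([[0., 1., 0.], [1., 0., 1.], [0., 1., 0.]])]
x, a, i = ragged_to_disjoint(tf.ragged.stack(feats), tf.ragged.stack(adjs))
```

Here i is [0, 0, 1, 1, 1], and a holds the two adjacency matrices as diagonal blocks of a 5 x 5 sparse matrix. For GCNConv, this adjacency would still need the usual normalization before being passed in.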