danielegrattarola / spektral

Graph Neural Networks with Keras and Tensorflow 2.
https://graphneural.network
MIT License

Create custom Graph Extension Type #420

Open dmadisetti opened 1 year ago

dmadisetti commented 1 year ago

A "Graph" type is even defined in the example:

This lets the user carry graph information around as a single TensorFlow object, and even allows ops at the object level. I think this would also allow for a batch of graphs, which would be great.

This would be a little nicer than carrying around separate arrays for the adjacency matrix, node features, etc.
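A minimal sketch of what such a type could look like for a single graph, assuming a TensorFlow version that provides tf.experimental.ExtensionType. The class name Graph, its fields x and a, and the methods degrees and propagate are hypothetical illustrations, not part of Spektral or of the example referenced above.

```python
import tensorflow as tf


class Graph(tf.experimental.ExtensionType):
    """Hypothetical single-graph container (not part of Spektral)."""
    x: tf.Tensor  # node features, [n_nodes, n_features]
    a: tf.Tensor  # dense adjacency, [n_nodes, n_nodes]

    def __validate__(self):
        # Static shape checks, run whenever a Graph is constructed.
        self.x.shape.assert_has_rank(2)
        self.a.shape.assert_has_rank(2)
        self.a.shape[:1].assert_is_compatible_with(self.a.shape[1:])
        self.a.shape[:1].assert_is_compatible_with(self.x.shape[:1])

    def degrees(self):
        # Object-level op: row sums of the adjacency matrix.
        return tf.reduce_sum(self.a, axis=-1)

    def propagate(self):
        # One round of neighbour aggregation (A @ X). Extension types are
        # immutable, so this returns a new Graph instead of mutating self.
        return Graph(x=tf.matmul(self.a, self.x), a=self.a)


@tf.function
def total_degree(graph):
    # Extension types can be passed straight into a tf.function.
    return tf.reduce_sum(graph.degrees())


g = Graph(
    x=tf.constant([[1.0], [2.0], [3.0]]),
    a=tf.constant([[0.0, 1.0, 0.0],
                   [1.0, 0.0, 1.0],
                   [0.0, 1.0, 0.0]]),
)
print(g.degrees().numpy().tolist())      # [1.0, 2.0, 1.0]
print(g.propagate().x.numpy().tolist())  # [[2.0], [4.0], [2.0]]
print(float(total_degree(g)))            # 4.0
```

Features and adjacency travel together as one value, and a mismatched pair (e.g. 2 nodes of features with a 3x3 adjacency) is rejected at construction time by __validate__.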

dmadisetti commented 1 year ago

Never mind, for batching there's an explicit BatchableTypeSpec:

https://github.com/tensorflow/tensorflow/blob/d5b57ca93e506df258271ea00fc29cf98383a374/tensorflow/python/framework/type_spec.py#L738-L751

But the custom extension type still makes sense.

There is also tf.experimental.BatchableExtensionType.

It seems pretty straightforward to use; I'm using it for the graphs generated by my pipeline.
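A small sketch of how a BatchableExtensionType can hold a batch of differently sized graphs and be mapped over one graph at a time with tf.map_fn, the same pattern the wrapper code further down relies on. The class GraphBatch, its fields, and the count_edges function are hypothetical; the adjacencies are assumed to be symmetric 0/1 matrices.

```python
import tensorflow as tf


class GraphBatch(tf.experimental.BatchableExtensionType):
    """Hypothetical batch of graphs stored as ragged tensors."""
    x: tf.RaggedTensor  # ragged batch of [nodes, features] matrices
    a: tf.RaggedTensor  # ragged batch of [nodes, nodes] adjacencies


# Two graphs of different sizes: a single edge and a 3-node path.
x_list = [tf.ones([2, 4]), tf.ones([3, 4])]
a_list = [
    tf.constant([[0.0, 1.0], [1.0, 0.0]]),
    tf.constant([[0.0, 1.0, 0.0], [1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]),
]
batch = GraphBatch(x=tf.ragged.stack(x_list), a=tf.ragged.stack(a_list))


def count_edges(graph):
    # map_fn passes in one graph at a time: its fields have lost the batch
    # dimension but are still ragged, so densify before reducing.
    # Each undirected edge appears twice in a symmetric adjacency.
    return tf.reduce_sum(graph.a.to_tensor()) / 2.0


# fn_output_signature describes what a single call of count_edges returns.
edges = tf.map_fn(count_edges, batch, fn_output_signature=tf.float32)
print(edges.numpy().tolist())  # [1.0, 2.0]
```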

danielegrattarola commented 1 year ago

Hey,

Do you have a self-contained example of using this type with Spektral? Or would it require rewriting the layers?

Cheers

dmadisetti commented 1 year ago

It would likely just be another data mode. This is what I have, but integration into the library would probably look a little different:

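```python
import tensorflow as tf
from spektral.layers import GCNConv


class WrappedGCN(tf.keras.layers.Layer):
    """Applies a Spektral GCNConv to every graph in a TensorGraph batch."""

    def __init__(self, channels, *args, **kwargs):
        super().__init__()
        self.channels = channels
        self.layer = GCNConv(channels, *args, **kwargs)

    def hook(self, graph):
        # `graph` is a single (unbatched) TensorGraph: densify its fields and
        # add a batch axis of size 1 so GCNConv runs in batch mode.
        # Note: GCNConv expects a pre-normalized adjacency (see
        # GCNConv.preprocess); here the matrix is passed through as-is.
        features = tf.expand_dims(
            tf.cast(graph.features.to_tensor(), tf.float32), axis=0)
        adj = tf.expand_dims(
            tf.cast(graph.adjacency.to_tensor(), tf.float32), axis=0)
        out = self.layer([features, adj])
        # Drop only the dummy batch axis (a bare tf.squeeze would also drop
        # size-1 node or feature axes), then go back to ragged form.
        return tf.RaggedTensor.from_tensor(tf.squeeze(out, axis=0))

    def __call__(self, graph):
        if isinstance(graph, TensorGraph):
            # Run `hook` once per graph in the batch.
            features = tf.map_fn(
                self.hook, graph,
                fn_output_signature=tf.RaggedTensorSpec(
                    shape=(None, self.channels), dtype=tf.float32))
            return TensorGraph(features=features, adjacency=graph.adjacency)
        # Any other input goes straight to GCNConv.
        return self.layer(graph)


# adjs / features: Python lists of per-graph adjacency and feature matrices.
x0 = TensorGraph(adjacency=tf.ragged.stack(adjs),
                 features=tf.ragged.stack(features))

x1 = WrappedGCN(6)(x0)
x2 = WrappedGCN(6)(x1)
```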

where

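```python
import tensorflow as tf


class TensorGraph(tf.experimental.BatchableExtensionType):
    """A batch of graphs: ragged node features plus ragged adjacency matrices."""
    features: tf.RaggedTensor   # [batch, (nodes), features]
    adjacency: tf.RaggedTensor  # [batch, (nodes), (nodes)]

    # TODO: validation (e.g. a __validate__ method), etc.
```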

I could probably use sparse tensors instead of ragged ones, but I'm using ragged here because I need the dense adjacency matrices.
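For the sparse route, one option is to flatten the ragged batch into something like Spektral's disjoint mode: all node features stacked into one matrix, a single block-diagonal sparse adjacency, and a vector mapping each node to its graph. The sketch below assumes ragged inputs built with tf.ragged.stack as in the example above; the helper name ragged_to_disjoint is hypothetical and not a Spektral function.

```python
import tensorflow as tf


def ragged_to_disjoint(features, adjacency):
    """Hypothetical helper: ragged graph batch -> disjoint-mode tensors.

    features:  RaggedTensor of shape [batch, (nodes), F]
    adjacency: RaggedTensor of shape [batch, (nodes), (nodes)]
    Returns x [N, F], a block-diagonal SparseTensor a [N, N], and i [N],
    where N is the total node count and i[k] is the graph of node k.
    """
    n_nodes = tf.cast(features.row_lengths(), tf.int64)  # nodes per graph
    offsets = tf.cumsum(n_nodes, exclusive=True)  # first global id per graph

    # Stack every graph's node features on top of each other.
    x = features.merge_dims(0, 1)
    if isinstance(x, tf.RaggedTensor):
        x = x.to_tensor()
    i = features.value_rowids()  # graph index of every node

    # Non-zero entries of each zero-padded adjacency, as (graph, row, col)...
    dense_a = adjacency.to_tensor()
    idx = tf.where(tf.not_equal(dense_a, 0))
    # ...shifted by the graph's node offset to get global (row, col) ids.
    shift = tf.gather(offsets, idx[:, 0])
    indices = tf.stack([idx[:, 1] + shift, idx[:, 2] + shift], axis=1)
    n_total = tf.reduce_sum(n_nodes)
    a = tf.sparse.SparseTensor(
        indices=indices,
        values=tf.gather_nd(dense_a, idx),
        dense_shape=tf.stack([n_total, n_total]),
    )
    return x, tf.sparse.reorder(a), i


x_list = [tf.constant([[1.0, 1.0], [2.0, 2.0]]),
          tf.constant([[3.0, 3.0], [4.0, 4.0], [5.0, 5.0]])]
a_list = [tf.constant([[0.0, 1.0], [1.0, 0.0]]),
          tf.constant([[0.0, 1.0, 0.0], [1.0, 0.0, 1.0], [0.0, 1.0, 0.0]])]
x, a, i = ragged_to_disjoint(tf.ragged.stack(x_list), tf.ragged.stack(a_list))
print(i.numpy().tolist())  # [0, 0, 1, 1, 1]
```

The adjacency is still densified once per batch here (to_tensor pads every graph to the largest one), so this trades memory during conversion for a compact sparse result; with very large graphs it would be better to build the sparse indices per graph directly.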