dmadisetti opened this issue 1 year ago
Nvm, for batching there's an explicit `BatchableTypeSpec`, but a custom extension type still makes sense.
There is `tf.experimental.BatchableExtensionType`.
It seems pretty straightforward to use; I'm using it for graphs generated by my pipeline.
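Conceptually, a batch of graphs with different node counts fits into one object using the same trick `tf.RaggedTensor` relies on: concatenate all rows into one flat array and record where each graph starts (its `row_splits`). Below is a minimal NumPy sketch of that encoding, only to illustrate the idea; `to_ragged` and `graph_features` are hypothetical helpers, not TensorFlow APIs.

```python
import numpy as np


def to_ragged(feature_list):
    """Pack per-graph node-feature matrices into (flat_values, row_splits).

    Mirrors the row_splits encoding of tf.RaggedTensor: graph i owns rows
    flat_values[row_splits[i]:row_splits[i + 1]].
    """
    lengths = [f.shape[0] for f in feature_list]
    row_splits = np.concatenate([[0], np.cumsum(lengths)]).astype(np.int64)
    flat_values = np.concatenate(feature_list, axis=0)
    return flat_values, row_splits


def graph_features(flat_values, row_splits, i):
    """Recover the node-feature matrix of graph i from the packed batch."""
    return flat_values[row_splits[i]:row_splits[i + 1]]


# Two graphs with 3 and 2 nodes, 4 features per node.
g0 = np.ones((3, 4))
g1 = np.zeros((2, 4))
values, splits = to_ragged([g0, g1])
print(values.shape, splits)  # (5, 4) [0 3 5]
```

No padding is stored: the batch costs exactly as many rows as there are nodes in total.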
Hey,
do you have a self-contained example of using this type with Spektral? Or does it require rewriting the layers?
Cheers
It would likely just be another data mode. This is what I have, though integration into the library would probably look a little different:
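```python
import tensorflow as tf
from spektral.layers import GCNConv


class WrappedGCN(tf.keras.layers.Layer):
    """Applies a Spektral GCNConv to each graph in a batched TensorGraph."""

    def __init__(self, features, *args, **kwargs):
        super().__init__()
        self.features = features  # number of output channels
        self.layer = GCNConv(features, *args, **kwargs)

    def hook(self, graph):
        # Densify one graph and add a batch axis of size 1 (batch mode).
        features = tf.expand_dims(graph.features.to_tensor(), 0)
        adj = tf.expand_dims(tf.cast(graph.adjacency.to_tensor(), tf.float32), 0)
        # Drop only the batch axis, so graphs with a single node (or a single
        # output channel) keep rank 2.
        out = tf.squeeze(self.layer([features, adj]), axis=0)
        return tf.RaggedTensor.from_tensor(out)

    def __call__(self, graph):
        if isinstance(graph, TensorGraph):
            # Run the GCN graph by graph; each output has shape (n_i, features).
            features = tf.map_fn(
                self.hook, graph,
                fn_output_signature=tf.RaggedTensorSpec(
                    shape=(None, self.features), dtype=tf.float32))
            return TensorGraph(features=features, adjacency=graph.adjacency)
        # Fall back to the plain layer for regular [x, a] inputs.
        return self.layer(graph)


# adjs / features: per-graph lists of dense adjacency and feature matrices.
x0 = TensorGraph(adjacency=tf.ragged.stack(adjs),
                 features=tf.ragged.stack(features))
x1 = WrappedGCN(6)(x0)
x2 = WrappedGCN(6)(x1)
```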
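where `TensorGraph` is:
```python
class TensorGraph(tf.experimental.BatchableExtensionType):
    """A collection of nodes with associated feature vectors."""
    features: tf.RaggedTensor   # per-graph node features, one row per node
    adjacency: tf.RaggedTensor  # per-graph dense adjacency matrix
    # TODO: Validation functions etc...
```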
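For the validation `TODO`, `tf.ExtensionType` subclasses can define a `__validate__` method, which TensorFlow calls after the fields have been set. The checks themselves are simple shape-consistency rules. Here is a minimal sketch of what they might be, written against NumPy arrays for a single unbatched graph; `validate_graph` is a hypothetical helper, and inside `__validate__` the same rules would be expressed with TensorFlow shape assertions instead.

```python
import numpy as np


def validate_graph(features, adjacency):
    """Hypothetical consistency checks for a single (unbatched) graph."""
    if features.ndim != 2:
        raise ValueError(f"features must be 2-D, got shape {features.shape}")
    if adjacency.ndim != 2 or adjacency.shape[0] != adjacency.shape[1]:
        raise ValueError(f"adjacency must be square, got shape {adjacency.shape}")
    if adjacency.shape[0] != features.shape[0]:
        raise ValueError(
            f"adjacency has {adjacency.shape[0]} nodes but features has "
            f"{features.shape[0]} rows")


validate_graph(np.ones((3, 4)), np.eye(3))  # consistent: passes silently
try:
    validate_graph(np.ones((3, 4)), np.eye(2))
except ValueError as err:
    print(err)  # adjacency has 2 nodes but features has 3 rows
```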
You could probably use `tf.SparseTensor` instead of `tf.RaggedTensor`, but I'm using ragged tensors here because I need the dense adjacency matrices.
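With a sparse (COO) representation, each graph would store only its edge list, and the dense adjacency could be materialized when a layer needs it (for a `tf.SparseTensor`, `tf.sparse.to_dense` does this). A NumPy sketch of that densification step, assuming edges are given as a `(2, E)` array of source/target indices; `edges_to_dense` is a hypothetical helper:

```python
import numpy as np


def edges_to_dense(edge_index, num_nodes):
    """Build a dense (num_nodes, num_nodes) adjacency from COO pairs.

    edge_index has shape (2, E): row 0 holds source nodes, row 1 targets.
    Repeated edges simply set the same entry to 1 again.
    """
    adj = np.zeros((num_nodes, num_nodes), dtype=np.float32)
    adj[edge_index[0], edge_index[1]] = 1.0
    return adj


edges = np.array([[0, 1, 2], [1, 2, 0]])  # directed cycle 0 -> 1 -> 2 -> 0
print(edges_to_dense(edges, 3))
```

The trade-off is memory: the edge list grows with the number of edges, while the dense matrix grows with the square of the node count.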
A "Graph" type is even defined in the example:
This lets the user carry graph information around as a single TensorFlow object, and even allows ops at the object level. I think this would also allow for a batch of graphs, which would be great.
This would be a little nicer than carrying around separate arrays for the adjacency, features, etc.
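For comparison, one way a batch of graphs is already carried as a fixed set of arrays is Spektral's disjoint mode, which merges the batch into one large graph with a block-diagonal adjacency plus a vector recording which graph each node belongs to. Below is a NumPy sketch of that packing. The adjacency is dense here for readability, whereas Spektral typically keeps it sparse, and `to_disjoint_mode` is a hypothetical helper rather than Spektral's loader.

```python
import numpy as np


def to_disjoint_mode(features, adjs):
    """Merge graphs into one disconnected graph.

    Returns stacked node features (N_total, F), a block-diagonal adjacency
    (N_total, N_total), and an index vector mapping each node to its graph.
    """
    x = np.concatenate(features, axis=0)
    n_total = x.shape[0]
    a = np.zeros((n_total, n_total), dtype=np.float32)
    index = np.empty(n_total, dtype=np.int64)
    offset = 0
    for g, adj in enumerate(adjs):
        n = adj.shape[0]
        a[offset:offset + n, offset:offset + n] = adj  # graph g's block
        index[offset:offset + n] = g
        offset += n
    return x, a, index


features = [np.ones((2, 3)), np.ones((3, 3))]
adjs = [np.ones((2, 2)), np.eye(3)]
x, a, index = to_disjoint_mode(features, adjs)
print(x.shape, a.shape, index)  # (5, 3) (5, 5) [0 0 1 1 1]
```

Because the off-diagonal blocks are zero, message passing never crosses graph boundaries, but the user still has to keep `x`, `a`, and `index` in sync by hand, which is exactly the bookkeeping an extension type would bundle up.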