microsoft / tf2-gnn

TensorFlow 2 library implementing Graph Neural Networks
MIT License

How to add an RGAT layer to a custom Keras model? #53

Open Jorvan758 opened 2 years ago

Jorvan758 commented 2 years ago

I have been struggling for a while trying to do this, but I'm still more or less a noob, so I need your help.

Here is one of my attempts:

import tensorflow as tf
from keras import Model, layers
from tf2_gnn.layers.message_passing.rgat import RGAT
from tf2_gnn.layers.message_passing.message_passing import MessagePassingInput

inputLayer_X = layers.Input(shape=tf.TensorShape(dims=(None, 3)), name="Input_X")
inputLayer_A = layers.Input(shape=tuple(tf.TensorShape(dims=(None, 2)) for _ in range(3)), name="Input_A")
rgatLayer_1 = RGAT({'aggregation_function': 'sum', 'hidden_dim': 10,
                    'message_activation_before_aggregation': False,
                    'message_activation_function': 'relu', 'num_heads': 5},
                   name="RGAT_1")(MessagePassingInput(inputLayer_X, inputLayer_A))
modelo = Model([inputLayer_X, inputLayer_A], rgatLayer_1, name="The_model")

Which returns:


TypeError                                 Traceback (most recent call last)
<ipython-input> in <module>()
      1 inputLayer_X = layers.Input(shape=tf.TensorShape(dims=(None, 3)), name="Input_X")
----> 2 inputLayer_A = layers.Input(shape=tuple(tf.TensorShape(dims=(None, 2)) for _ in range(3)), name="Input_A")
      3 rgatLayer_1 = RGAT({'aggregation_function': 'sum', 'hidden_dim': 10,
      4                     'message_activation_before_aggregation': False,
      5                     'message_activation_function': 'relu', 'num_heads': 5}, name="RGAT_1")(MessagePassingInput(inputLayer_X, inputLayer_A))

1 frames
/usr/local/lib/python3.7/dist-packages/keras/utils/traceback_utils.py in error_handler(*args, **kwargs)
     65   except Exception as e:  # pylint: disable=broad-except
     66     filtered_tb = _process_traceback_frames(e.__traceback__)
---> 67     raise e.with_traceback(filtered_tb) from None
     68   finally:
     69     del filtered_tb

/usr/local/lib/python3.7/dist-packages/six.py in raise_from(value, from_value)

TypeError: Dimension value must be integer or None or have an __index__ method, got value 'TensorShape([None, 2])' with type '<class 'tensorflow.python.framework.tensor_shape.TensorShape'>'

And here's another:

import tensorflow as tf
from keras import Model, layers
from tf2_gnn.layers.message_passing.rgat import RGAT
from tf2_gnn.layers.message_passing.message_passing import MessagePassingInput

inputLayer_X = layers.Input(shape=tf.TensorShape(dims=(None, 3)), name="Input_X")
inputLayer_A1 = layers.Input(shape=tf.TensorShape(dims=(None, 2)), name="Input_A1")
inputLayer_A2 = layers.Input(shape=tf.TensorShape(dims=(None, 2)), name="Input_A2")
inputLayer_A3 = layers.Input(shape=tf.TensorShape(dims=(None, 2)), name="Input_A3")
rgatLayer_1 = RGAT({'aggregation_function': 'sum', 'hidden_dim': 10,
                    'message_activation_before_aggregation': False,
                    'message_activation_function': 'relu', 'num_heads': 5},
                   name="RGAT_1")(MessagePassingInput(inputLayer_X,
                                                      [inputLayer_A1, inputLayer_A2, inputLayer_A3]))
modelo = Model([inputLayer_X, inputLayer_A1, inputLayer_A2, inputLayer_A3], rgatLayer_1, name="The_model")

That yields:


TypeError                                 Traceback (most recent call last)
<ipython-input> in <module>()
      6                     'message_activation_before_aggregation': False,
      7                     'message_activation_function': 'relu', 'num_heads': 5}, name="RGAT_1")(MessagePassingInput(inputLayer_X,
----> 8                                                                                                                [inputLayer_A1, inputLayer_A2, inputLayer_A3]))
      9 modelo = Model([inputLayer_X, inputLayer_A1, inputLayer_A2, inputLayer_A3], rgatLayer_1, name="The_model")

1 frames
/usr/local/lib/python3.7/dist-packages/keras/utils/traceback_utils.py in error_handler(*args, **kwargs)
     65   except Exception as e:  # pylint: disable=broad-except
     66     filtered_tb = _process_traceback_frames(e.__traceback__)
---> 67     raise e.with_traceback(filtered_tb) from None
     68   finally:
     69     del filtered_tb

/usr/local/lib/python3.7/dist-packages/tensorflow/python/autograph/impl/api.py in wrapper(*args, **kwargs)
    697     except Exception as e:  # pylint:disable=broad-except
    698       if hasattr(e, 'ag_error_metadata'):
--> 699         raise e.ag_error_metadata.to_exception(e)
    700       else:
    701         raise

TypeError: Exception encountered when calling layer "RGAT_1" (type RGAT).

in user code:

    File "/usr/local/lib/python3.7/dist-packages/tf2_gnn/layers/message_passing/message_passing.py", line 116, in call  *
        messages_per_type = self._calculate_messages_per_type(
    File "/usr/local/lib/python3.7/dist-packages/tf2_gnn/layers/message_passing/message_passing.py", line 190, in _calculate_messages_per_type  *
        type_to_num_incoming_edges = calculate_type_to_num_incoming_edges(
    File "/usr/local/lib/python3.7/dist-packages/tf2_gnn/layers/message_passing/message_passing.py", line 256, in calculate_type_to_num_incoming_edges  *
        num_incoming_edges = tf.scatter_nd(

    TypeError: Value passed to parameter 'indices' has DataType float32 not in list of allowed values: int32, int64

Call arguments received:
  • inputs=MessagePassingInput(node_embeddings='tf.Tensor(shape=(None, None, 3), dtype=float32)', adjacency_lists=['tf.Tensor(shape=(None, None, 2), dtype=float32)', 'tf.Tensor(shape=(None, None, 2), dtype=float32)', 'tf.Tensor(shape=(None, None, 2), dtype=float32)'])
  • training=False

I'll keep trying to overcome it (and will update if I do), but if someone can shed some light on the matter, I would be very grateful 🙏

Jorvan758 commented 2 years ago

I've been studying a fair amount and I think I'm pretty close to solving it. Right now, I got this to run:

from tf2_gnn.layers.message_passing.rgat import RGAT
from tf2_gnn.layers.message_passing.message_passing import MessagePassingInput
from keras import Model, layers
import tensorflow as tf

# Node features: a single graph with a variable number of nodes, 7 features each.
inputLayer_X = layers.Input(shape=tf.TensorShape(dims=(None, 7)), name="Input_X")
# One adjacency list per edge type; each row holds the two node indices of one
# edge, so the dtype must be integer (tf.scatter_nd rejects float indices).
inputLayer_A1 = layers.Input(shape=tf.TensorShape(dims=2), name="Input_A1", dtype=tf.int32)
inputLayer_A2 = layers.Input(shape=tf.TensorShape(dims=2), name="Input_A2", dtype=tf.int32)
inputLayer_A3 = layers.Input(shape=tf.TensorShape(dims=2), name="Input_A3", dtype=tf.int32)
rgatLayer_1 = RGAT({'aggregation_function': 'sum', 'hidden_dim': 10,
                    'message_activation_before_aggregation': False,
                    'message_activation_function': 'relu', 'num_heads': 5},
                   name="RGAT_1")(MessagePassingInput(inputLayer_X,
                                                      [inputLayer_A1, inputLayer_A2, inputLayer_A3]))
modelo = Model([inputLayer_X, inputLayer_A1, inputLayer_A2, inputLayer_A3], rgatLayer_1, name="The_model")
modelo.summary()

Model: "The_model"


Layer (type) Output Shape Param # Connected to

================================================================================= Input_X (InputLayer) [(None, None, 7)] 0 []

Input_A1 (InputLayer) [(None, 2)] 0 []

Input_A2 (InputLayer) [(None, 2)] 0 []

Input_A3 (InputLayer) [(None, 2)] 0 []

RGAT_1 (RGAT) (None, 10) 270 ['Input_X[0][0]',
'Input_A1[0][0]',
'Input_A2[0][0]',
'Input_A3[0][0]']

================================================================================= Total params: 270 Trainable params: 270 Non-trainable params: 0


While it's usable, it's far from ideal, given that it won't work with multiple graphs at the same time (which is what I need). Of course, I tried expanding the input shape of the adjacency matrices, but the RGAT layer seems to be able to work with only one graph at a time. Because of that, I'm now searching for a workaround (one that at least processes multiple graphs sequentially). I'll update as soon as I find it. However, if anyone can help, I would appreciate it 👀

Jorvan758 commented 2 years ago

I think I'm almost there, but it's getting trickier. I have two relevant attempts. One is this:

from tf2_gnn.layers.message_passing.rgat import RGAT
from tf2_gnn.layers.message_passing.message_passing import MessagePassingInput
from keras import Model, layers
import tensorflow as tf

inputLayer_X = layers.Input(shape=tf.TensorShape(dims=(None, 7)),name="Input_X")
inputLayer_A1 = layers.Input(shape=tf.TensorShape(dims=(None, 2)),name="Input_A1", dtype=tf.int32)
inputLayer_A2 = layers.Input(shape=tf.TensorShape(dims=(None, 2)),name="Input_A2", dtype=tf.int32)
inputLayer_A3 = layers.Input(shape=tf.TensorShape(dims=(None, 2)),name="Input_A3", dtype=tf.int32)
rgatLayer_1 = RGAT({'aggregation_function': 'sum', 'hidden_dim': 10,
                    'message_activation_before_aggregation': False,
                    'message_activation_function': 'relu', 'num_heads': 5}, name="RGAT_1")
# Apply the shared RGAT layer to each graph of the batch, one graph at a time.
lambdaLayer_1 = layers.Lambda(lambda x: tf.map_fn(lambda y: rgatLayer_1(MessagePassingInput(y[0], [y[1], y[2], y[3]])),
                                                  (x[0], x[1], x[2], x[3]), dtype=tf.float32),
                              name="Lambda_1")((inputLayer_X, inputLayer_A1, inputLayer_A2, inputLayer_A3))
modelo = Model([inputLayer_X, inputLayer_A1, inputLayer_A2, inputLayer_A3], lambdaLayer_1, name="The_model")
modelo.summary()

Which returns:

ValueError: Exception encountered when calling layer "Lambda_1" (type Lambda).

The following Variables were created within a Lambda layer (Lambda_1) but are not tracked by said layer:
  <tf.Variable 'Lambda_1/map/while/RGAT_1/edge_type_0/kernel:0' shape=(7, 10) dtype=float32>
  <tf.Variable 'Lambda_1/map/while/RGAT_1/edge_type_0/Edge_attention_parameters_0:0' shape=(5, 4) dtype=float32>
  <tf.Variable 'Lambda_1/map/while/RGAT_1/edge_type_1/kernel:0' shape=(7, 10) dtype=float32>
  <tf.Variable 'Lambda_1/map/while/RGAT_1/edge_type_1/Edge_attention_parameters_1:0' shape=(5, 4) dtype=float32>
  <tf.Variable 'Lambda_1/map/while/RGAT_1/edge_type_2/kernel:0' shape=(7, 10) dtype=float32>
  <tf.Variable 'Lambda_1/map/while/RGAT_1/edge_type_2/Edge_attention_parameters_2:0' shape=(5, 4) dtype=float32>
The layer cannot safely ensure proper Variable reuse across multiple calls, and consquently this behavior is disallowed for safety. Lambda layers are not well suited to stateful computation; instead, writing a subclassed Layer is the recommend way to define layers with Variables.

Call arguments received:
  • inputs=('tf.Tensor(shape=(None, None, 7), dtype=float32)', 'tf.Tensor(shape=(None, None, 2), dtype=int32)', 'tf.Tensor(shape=(None, None, 2), dtype=int32)', 'tf.Tensor(shape=(None, None, 2), dtype=int32)')
  • mask=None
  • training=None
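As I read it, the error means the RGAT variables created inside tf.map_fn can't be tracked by a Lambda layer, and that a subclassed Layer is the way to go. Here is a rough, untested sketch of what I imagine that would look like (the RGATPerGraph wrapper is my own invention, not something from tf2_gnn):

import tensorflow as tf
from tf2_gnn.layers.message_passing.rgat import RGAT
from tf2_gnn.layers.message_passing.message_passing import MessagePassingInput

class RGATPerGraph(tf.keras.layers.Layer):
    """Hypothetical wrapper: applies one shared RGAT layer to each graph in a batch."""

    def __init__(self, rgat_params, **kwargs):
        super().__init__(**kwargs)
        # Creating the RGAT layer here (instead of inside a Lambda) lets Keras
        # track its variables as part of this layer.
        self._rgat = RGAT(rgat_params, name="RGAT_inner")

    def call(self, inputs):
        node_features, adj_1, adj_2, adj_3 = inputs
        # Run the shared RGAT layer over each graph of the batch in turn.
        return tf.map_fn(
            lambda graph: self._rgat(
                MessagePassingInput(graph[0], [graph[1], graph[2], graph[3]])
            ),
            (node_features, adj_1, adj_2, adj_3),
            fn_output_signature=tf.float32,
        )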

I tried to create a custom layer along those lines, but it's fairly difficult for me, so I'm searching for other options. The second attempt is this one:

from tf2_gnn.layers.message_passing.rgat import RGAT
from tf2_gnn.layers.message_passing.message_passing import MessagePassingInput
from keras import Model, layers
import tensorflow as tf

inputLayer_X = layers.Input(shape=tf.TensorShape(dims=(None, 7)),name="Input_X")
inputLayer_A1 = layers.Input(shape=tf.TensorShape(dims=(None, 2)),name="Input_A1", dtype=tf.int32)
inputLayer_A2 = layers.Input(shape=tf.TensorShape(dims=(None, 2)),name="Input_A2", dtype=tf.int32)
inputLayer_A3 = layers.Input(shape=tf.TensorShape(dims=(None, 2)),name="Input_A3", dtype=tf.int32)

# Here the Lambda only builds the MessagePassingInput tuples; the RGAT layer is applied afterwards.
lambdaLayer_1 = layers.Lambda(lambda x: tf.map_fn(lambda y: MessagePassingInput(y[0], [y[1], y[2], y[3]]),
                                                  (x[0], x[1], x[2], x[3]), dtype=MessagePassingInput),
                              name="Lambda_1")((inputLayer_X, inputLayer_A1, inputLayer_A2, inputLayer_A3))
rgatLayer_1 = RGAT({'aggregation_function': 'sum', 'hidden_dim': 10,
                    'message_activation_before_aggregation': False,
                    'message_activation_function': 'relu', 'num_heads': 5}, name="RGAT_1")(lambdaLayer_1)
modelo = Model([inputLayer_X, inputLayer_A1, inputLayer_A2, inputLayer_A3], rgatLayer_1, name="The_model")
modelo.summary()

And this gives me:

TypeError: Exception encountered when calling layer "Lambda_1" (type Lambda).

Cannot convert value <class 'tf2_gnn.layers.message_passing.message_passing.MessagePassingInput'> to a TensorFlow DType.

Call arguments received:
  • inputs=('tf.Tensor(shape=(None, None, 7), dtype=float32)', 'tf.Tensor(shape=(None, None, 2), dtype=int32)', 'tf.Tensor(shape=(None, None, 2), dtype=int32)', 'tf.Tensor(shape=(None, None, 2), dtype=int32)')
  • mask=None
  • training=None

I'll keep pushing, but I really hope that someone could lend me a hand 😰

mmjb commented 2 years ago

I'm not familiar with Keras and hence can't really help on that front. However, it seems to me that you're unfamiliar with how batching is usually performed in (sparse) GNN implementations: the idea is to represent a batch of graphs as a single graph of disconnected components. As information is only exchanged along edges, these two views are equivalent.

Suitable code to batch graphs like this can be found in https://github.com/microsoft/tf2-gnn/blob/master/tf2_gnn/data/graph_dataset.py#L192-L246.
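To illustrate the idea only (the linked code is the real reference), a minimal NumPy sketch of merging graphs by offsetting node indices could look like this:

import numpy as np

def batch_graphs(graphs):
    """Toy sketch: merge a list of graphs into one big disconnected graph.

    Each graph here is assumed to be a dict with
      "node_features":   float array of shape [num_nodes, feature_dim]
      "adjacency_lists": one int array of shape [num_edges, 2] per edge type
    """
    num_edge_types = len(graphs[0]["adjacency_lists"])
    features, adjacency = [], [[] for _ in range(num_edge_types)]
    node_offset = 0
    for graph in graphs:
        features.append(graph["node_features"])
        for edge_type, edges in enumerate(graph["adjacency_lists"]):
            # Shift node indices so the edges point into the merged node list;
            # components stay disconnected, so message passing is unchanged.
            adjacency[edge_type].append(edges + node_offset)
        node_offset += graph["node_features"].shape[0]
    return (np.concatenate(features, axis=0),
            [np.concatenate(edges, axis=0) for edges in adjacency])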

Jorvan758 commented 2 years ago

I'll give it a look in the future. For now, I think it's best that I work on other stuff (I'll update when I find a confirmed solution 👍)