microsoft / tf2-gnn

TensorFlow 2 library implementing Graph Neural Networks
MIT License
371 stars, 73 forks

Performance issue in tf2_gnn/layers/message_passing/message_passing.py #50

Open JamesCao2048 opened 3 years ago

JamesCao2048 commented 3 years ago

Hello, I found that in the function calculate_type_to_num_incoming_edges in tf2_gnn/layers/message_passing/message_passing.py, tf.shape is called redundantly inside the loop and returns the same value on every iteration (code here). Moreover, if users add a @tf.function decorator to speed up this function in graph mode, many identical nodes will be created in the computation graph. Thus, I think tf.shape should be called only once, before the loop.
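A minimal sketch of the proposed change follows. It assumes the function takes node embeddings of shape [num_nodes, hidden_dim] plus a Python list of per-edge-type adjacency lists of (source, target) pairs, and that it counts incoming edges with tf.math.unsorted_segment_sum. The signature and body are assumptions for illustration, not tf2_gnn's actual code. The only point being shown is that tf.shape is evaluated once before the loop and its result is reused on every iteration.

```python
import tensorflow as tf


# Hypothetical sketch: this signature and body are assumptions based on the
# issue description, not a copy of tf2_gnn's implementation.
@tf.function
def calculate_type_to_num_incoming_edges(node_embeddings, adjacency_lists):
    """Count incoming edges per node, separately for each edge type.

    node_embeddings: float tensor of shape [num_nodes, hidden_dim].
    adjacency_lists: Python list of int32 tensors, each of shape [num_edges, 2],
        holding (source, target) pairs for one edge type.
    Returns an int32 tensor of shape [num_edge_types, num_nodes].
    """
    # Hoisted out of the loop: compute the node count once and reuse it,
    # instead of calling tf.shape again for every edge type.
    num_nodes = tf.shape(node_embeddings)[0]

    type_to_num_incoming_edges = []
    # The loop runs over a Python list, so tf.function unrolls it at trace time.
    for adjacency_list in adjacency_lists:
        targets = adjacency_list[:, 1]
        # Add a 1 to the target node's segment for every edge of this type.
        num_incoming = tf.math.unsorted_segment_sum(
            data=tf.ones_like(targets),
            segment_ids=targets,
            num_segments=num_nodes,
        )
        type_to_num_incoming_edges.append(num_incoming)
    return tf.stack(type_to_num_incoming_edges)


# Usage: 3 nodes, two edge types.
node_embeddings = tf.zeros([3, 4])
adjacency_lists = [
    tf.constant([[0, 1], [2, 1]], dtype=tf.int32),  # type 0: two edges into node 1
    tf.constant([[1, 0]], dtype=tf.int32),          # type 1: one edge into node 0
]
counts = calculate_type_to_num_incoming_edges(node_embeddings, adjacency_lists)
print(counts.numpy().tolist())  # → [[0, 2, 0], [1, 0, 0]]
```

Because the loop body only reads num_nodes, the traced graph contains a single shape computation no matter how many edge types there are.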

Looking forward to your reply. Btw, I would be glad to open a PR to fix it if you are too busy.