BNN-UPC / ignnition

Framework for fast prototyping of Graph Neural Networks
Apache License 2.0

Can I use two readout functions for two different nodes and get a loss from these two functions? #45

Closed · YangWang92 closed this issue 3 years ago

YangWang92 commented 3 years ago

Thanks for sharing such a helpful platform that frees us from the complexity of implementation.

Can I use two readout functions for two different nodes and get a loss from these two functions? For example, in RouteNet, can I read link and path states with two readout functions?

Thanks a lot!

Yang

jsuarezv commented 3 years ago

Dear Yang,

Thank you for your interest in IGNNITION. What you describe is possible as long as you use the same loss function across all the outputs. To do so, first define two separate readout functions for the path and link states respectively, and then concatenate their outputs. The loss function, defined in "train_options.yaml", is then applied directly over the concatenated readout outputs. I leave below a code snippet for the RouteNet case you mention:

readout:
- type: neural_network
  input: [path]
  nn_name: readout_paths
  output_name: out_paths
- type: neural_network
  input: [link]
  nn_name: readout_links
  output_name: out_links
- type: concat
  input: [out_paths, out_links]
  output_label: [$delay, $bandwidth]
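
For completeness, the shared loss mentioned above is set in "train_options.yaml". Below is a minimal sketch, assuming a standard Keras loss name; the optimizer block and its values are illustrative placeholders, not values prescribed by this issue:

# train_options.yaml (sketch): one loss, applied to the concatenated outputs
loss: MeanSquaredError
optimizer:
  type: Adam
  learning_rate: 0.001

Since the loss is defined once, it is applied uniformly over the concatenated vector of path-level and link-level outputs.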

Please note that in this case you would need to define the "readout_paths" and "readout_links" neural networks (a sketch is included at the end of this comment). Also, in the final "concat" operation the arguments of "input" and "output_label" must follow the same order: for example, first path-level and then link-level features (i.e., input=[out_paths, out_links]; output_label=[$delay, $bandwidth]), where "$delay" is the per-path delay and "$bandwidth" is the link-level aggregated traffic.

Note that the framework does not allow defining a denormalization function when the "concat" operation is used in the readout, as that would typically require one function per output type. One alternative is to re-generate the datasets with the output labels (e.g., $delay, $bandwidth) already normalized, so that the model learns to produce these normalized values and the loss function is applied directly over them. You can then denormalize the model's predictions to obtain the final labels.

Please note that the "concat" feature is only supported from the latest version (v1.1.0). We are working to add support for differentiated loss and denormalization functions on multiple outputs in future releases.
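As a sketch of the first point above, the two readout networks could be declared in the "neural_networks" section of the model description along these lines; the layer types, sizes and activations are illustrative assumptions, not values prescribed by the framework:

# model description (sketch): the two readout networks referenced by nn_name
neural_networks:
- nn_name: readout_paths
  nn_architecture:
  - type_layer: Dense
    units: 256
    activation: relu
  - type_layer: Dense
    units: 1
- nn_name: readout_links
  nn_architecture:
  - type_layer: Dense
    units: 256
    activation: relu
  - type_layer: Dense
    units: 1

The final Dense layer of each network fixes the dimensionality of "out_paths" and "out_links" before they are concatenated, so it should match the number of labels each readout predicts (here, one value per path and one per link).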

Regards, José

YangWang92 commented 3 years ago

Thanks a lot! I'm trying to apply your solution to my GNN model.