google-deepmind / graph_nets

Build Graph Nets in Tensorflow
https://arxiv.org/abs/1806.01261
Apache License 2.0

How to get and set the weights of neural networks in GNN #43

Closed TingtingYuan closed 5 years ago

TingtingYuan commented 5 years ago

Take the example below. How can I get the weights of edge_model_fn after the network has been created or trained?

import graph_nets as gn
import sonnet as snt

# Provide your own functions to generate graph-structured data.
input_graphs = get_graphs()

# Create the graph network.
graph_net_module = gn.modules.GraphNetwork(
    edge_model_fn=lambda: snt.nets.MLP([32, 32]),
    node_model_fn=lambda: snt.nets.MLP([32, 32]),
    global_model_fn=lambda: snt.nets.MLP([32, 32]))

# Pass the input graphs to the graph network, and return the output graphs.
output_graphs = graph_net_module(input_graphs)

alvarosg commented 5 years ago

After you connect the module for the first time: output_graphs = graph_net_module(input_graphs)

Then you can use graph_net_module.get_variables() to recover all of the variables containing the weights, including those of the edge model.
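
A minimal sketch of how the values could then be read in the TensorFlow 1 / Sonnet 1 setting this reply refers to. The helper name fetch_weights is hypothetical; it assumes get_variables() returns the module's tf.Variable objects and that the session passed in has already initialized them.

```python
def fetch_weights(module, session):
    """Evaluate every variable of an already-connected module.

    Assumptions (not confirmed in this thread): `module.get_variables()`
    returns the module's variables, and `session` behaves like a
    tf.Session in which those variables have been initialized.
    """
    variables = list(module.get_variables())
    # Running the variables fetches one array per variable.
    values = session.run(variables)
    return [(variable.name, value) for variable, value in zip(variables, values)]
```

Called as fetch_weights(graph_net_module, sess), this would return (name, array) pairs. Printing the names is one way to look for the edge model's entries, since the exact naming depends on the library versions in use.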

alvarosg commented 5 years ago

Closing due to inactivity.

Chris6212 commented 4 years ago

Hey, thanks for your cool library! I have a question regarding the graph_net weights. I copied the example posted above and passed my own GraphsTuple as input. When I try to extract the weights as described above, the following error occurs:

AttributeError: 'GraphNetwork' object has no attribute 'get_variables'

How do I get the weights?

zafarali commented 4 years ago

Hi, are you using the TensorFlow 2 version? If so, you can access the variables using graph_net_module.trainable_variables after you connect the module for the first time.

You can access all variables using graph_net_module.variables.
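
Since the issue title also asks about setting weights, here is a minimal sketch for the TensorFlow 2 case. The helper names snapshot_weights and restore_weights are hypothetical. The sketch assumes the module has already been called once, that each variable supports numpy() and assign() as tf.Variable does in eager mode, and that trainable_variables returns the variables in the same order on every access.

```python
def snapshot_weights(module):
    """Return (name, array) pairs for every trainable variable.

    Assumes `module` exposes `trainable_variables` and has already been
    called once, so that its variables exist.
    """
    return [(v.name, v.numpy()) for v in module.trainable_variables]


def restore_weights(module, snapshot):
    """Write arrays from `snapshot_weights` back into `module`, in order."""
    variables = list(module.trainable_variables)
    if len(variables) != len(snapshot):
        raise ValueError(
            f"expected {len(variables)} arrays, got {len(snapshot)}")
    for variable, (_, value) in zip(variables, snapshot):
        # assign() overwrites the variable's current value in place.
        variable.assign(value)
```

For example, saved = snapshot_weights(graph_net_module) before a training run and restore_weights(graph_net_module, saved) afterwards would put the earlier values back. The result is a list of pairs rather than a dict because variable names are not guaranteed to be unique.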

Chris6212 commented 4 years ago

Yes, I am using TensorFlow 2, and your proposed solution solved my issue. Thank you for the fast and helpful reply!