google-deepmind / graph_nets

Build Graph Nets in Tensorflow
https://arxiv.org/abs/1806.01261
Apache License 2.0

Passing a training flag into EncodeProcessDecode model #51

Closed: ferreirafabio closed this issue 5 years ago

ferreirafabio commented 5 years ago

So far I have failed to adapt the EncodeProcessDecode module so that the training and testing phases can be distinguished (e.g. to deactivate the noise added to the latent representations at test time).
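As a framework-free illustration of the behaviour being asked for, here is a minimal sketch in plain Python. The function add_latent_noise and its stddev and seed parameters are hypothetical and not part of graph_nets or Sonnet; in a real model the noise would be added to a latent tensor inside a module's _build.

```python
import random
from typing import List


def add_latent_noise(latent: List[float], is_training: bool,
                     stddev: float = 0.1, seed: int = 0) -> List[float]:
    """Add Gaussian noise to a latent vector during training only.

    Hypothetical helper: stddev and seed are illustrative choices,
    not part of graph_nets.
    """
    if not is_training:
        # Test time: return the latent unchanged (as a copy).
        return list(latent)
    rng = random.Random(seed)
    return [x + rng.gauss(0.0, stddev) for x in latent]


latent = [0.5, -1.0, 2.0]
print(add_latent_noise(latent, is_training=False))  # [0.5, -1.0, 2.0]
print(add_latent_noise(latent, is_training=True))   # a perturbed copy
```

The difficulty discussed in this thread is not the noise itself but getting is_training from the top-level model down to the module that needs it.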

This is mainly because an "is_training" boolean flag is typically passed after the encoder, core and decoder have already been initialized (but not yet built), and I want to use the same EncodeProcessDecode model for training and testing (i.e. not two separate train/test model objects, but two train/test TF ops). My understanding is therefore that I somehow need to pass this flag into EncodeProcessDecode's _build() function. However, this function calls modules.GraphIndependent, and there the number of arguments passed is predefined (namely inputs and num_processing_steps).

Since this seems like a common use case to me, I was wondering how people at DeepMind deal with it. Of course I could change the Sonnet interface of GraphIndependent, but that does not seem like a scalable solution. I also could not find an example of this in the documentation. Could you perhaps provide a minimal working example for this use case, or some ideas on how to do this?

Thank you!

ferreirafabio commented 5 years ago

Example:

def _build(self, input_op, num_processing_steps, is_training):
    latent = self._encoder(input_op, is_training)

--> fails with TypeError: _build() takes 2 positional arguments but 3 were given.

While _encoder is:

self._encoder = MLPGraphIndependent()

and MLPGraphIndependent is taken from the example and inherits from snt.AbstractModule.

vbapst commented 5 years ago

Hi Fabio,

thanks for your interest in our library. Unfortunately, the matter is complicated by the fact that the _build method of GraphNetwork (or GraphIndependent) calls _build on several different Sonnet modules, which in principle may or may not accept different arguments. This is a similar design decision to the one taken in snt.Sequential (see its docstring).
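The constraint described above can be illustrated without TensorFlow. In the sketch below, FixedSignatureContainer is a hypothetical stand-in (not the real snt.Sequential or GraphIndependent) that, like those containers, calls every sub-callable with only the inputs; scale is a hypothetical layer that needs a flag. Binding the flag in a closure produces a one-argument callable that fits the container.

```python
from typing import Callable, List


class FixedSignatureContainer:
    """Stand-in for a container such as snt.Sequential or GraphIndependent:
    it calls every sub-callable with exactly one argument (the inputs)."""

    def __init__(self, layers: List[Callable]):
        self._layers = layers

    def __call__(self, inputs):
        for layer in self._layers:
            inputs = layer(inputs)  # no room for extra arguments here
        return inputs


def scale(x, is_training):
    # Hypothetical layer whose behaviour depends on is_training.
    return x * 2 if is_training else x


# Passing scale directly fails, because the container only supplies x.
try:
    FixedSignatureContainer([scale])(3)
except TypeError as err:
    print("TypeError:", err)

# Binding the flag with a closure gives a one-argument callable that fits.
train_net = FixedSignatureContainer([lambda x: scale(x, is_training=True)])
test_net = FixedSignatureContainer([lambda x: scale(x, is_training=False)])
print(train_net(3), test_net(3))  # 6 3
```

This closure trick is what the workarounds below rely on.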

If you only care about the GraphIndependent case, one workaround is to bypass it and update the edges and nodes explicitly:

outputs_test = inputs.replace(
    nodes=your_node_module(inputs.nodes, False),
    edges=your_edge_module(inputs.edges, False)
)
outputs_train = inputs.replace(
    nodes=your_node_module(inputs.nodes, True),
    edges=your_edge_module(inputs.edges, True)
)
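To see what the replace-based workaround does, here is a framework-free sketch. ToyGraphsTuple is a hypothetical stand-in for graph_nets' GraphsTuple (a namedtuple whose replace method returns a copy with some fields swapped), and ToyModule is a hypothetical node/edge module taking an is_training argument. The same module objects are used for both outputs, so in a real model their variables would be shared.

```python
import collections
from typing import List


class ToyGraphsTuple(
        collections.namedtuple("ToyGraphsTuple", ["nodes", "edges", "globals"])):
    """Tiny stand-in for graph_nets.graphs.GraphsTuple (only three fields)."""

    def replace(self, **kwargs):
        # Returns a copy with the given fields swapped, like GraphsTuple.replace.
        return self._replace(**kwargs)


class ToyModule:
    """Hypothetical node/edge module whose call takes an is_training flag."""

    def __init__(self, offset: float):
        self.offset = offset  # plays the role of a shared "variable"

    def __call__(self, features: List[float], is_training: bool) -> List[float]:
        # Train-only behaviour, e.g. a noise-like shift skipped at test time.
        extra = 1.0 if is_training else 0.0
        return [f + self.offset + extra for f in features]


your_node_module = ToyModule(offset=10.0)
your_edge_module = ToyModule(offset=100.0)
inputs = ToyGraphsTuple(nodes=[1.0, 2.0], edges=[3.0], globals=[0.0])

outputs_test = inputs.replace(
    nodes=your_node_module(inputs.nodes, False),
    edges=your_edge_module(inputs.edges, False),
)
outputs_train = inputs.replace(
    nodes=your_node_module(inputs.nodes, True),
    edges=your_edge_module(inputs.edges, True),
)
print(outputs_test.nodes, outputs_train.nodes)  # [11.0, 12.0] [12.0, 13.0]
```

Fields that are not replaced (here globals) are carried over unchanged.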

Matters are admittedly more complicated for the GraphNetwork -- and we are still thinking about what the best way to support this would be.

Finally, a slightly more complicated method would be to construct your modules outside of the GraphNetwork and wrap their build methods, something like:

with tf.variable_scope("graph_modules"):
    your_edge_module = ...  # Define your modules here.
    your_node_module = ...
    your_global_module = ...

outputs_train = modules.GraphNetwork(
    edge_model_fn=lambda: lambda x: your_edge_module(x, is_training=True),
    node_model_fn=lambda: lambda x: your_node_module(x, is_training=True),
    global_model_fn=lambda: lambda x: your_global_module(x, is_training=True),
)(inputs)
outputs_test = modules.GraphNetwork(
    edge_model_fn=lambda: lambda x: your_edge_module(x, is_training=False),
    node_model_fn=lambda: lambda x: your_node_module(x, is_training=False),
    global_model_fn=lambda: lambda x: your_global_module(x, is_training=False),
)(inputs)
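The double lambda in this pattern is needed because each *_model_fn is a factory: the outer lambda takes no arguments and returns the one-argument callable that the network actually applies. The plain-Python sketch below mimics this with a hypothetical ToyGraphNetwork (edges only, for brevity; nodes and globals work the same way) and a hypothetical SharedEdgeModule built once outside both networks, so that train and test reuse the same parameters.

```python
from typing import Callable, List


class ToyGraphNetwork:
    """Stand-in for modules.GraphNetwork: edge_model_fn is a zero-argument
    factory that is called once; the callable it returns is then applied
    to the inputs with a single argument."""

    def __init__(self, edge_model_fn: Callable[[], Callable]):
        self._edge_model = edge_model_fn()  # factory called once here

    def __call__(self, edges):
        return self._edge_model(edges)


class SharedEdgeModule:
    """Hypothetical module built once, outside the network, so that both the
    train and test networks reuse the same parameters."""

    def __init__(self):
        self.weight = 2.0
        self.seen_flags: List[bool] = []

    def __call__(self, x, is_training):
        self.seen_flags.append(is_training)
        return x * self.weight + (1.0 if is_training else 0.0)


your_edge_module = SharedEdgeModule()
outputs_train = ToyGraphNetwork(
    edge_model_fn=lambda: lambda x: your_edge_module(x, is_training=True))(3.0)
outputs_test = ToyGraphNetwork(
    edge_model_fn=lambda: lambda x: your_edge_module(x, is_training=False))(3.0)
print(outputs_train, outputs_test)  # 7.0 6.0
print(your_edge_module.seen_flags)  # [True, False]
```

If the factory constructed a new module on every call instead, each network would get its own copy of the parameters, which is exactly what building the modules up front avoids.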

Best, Victor


ferreirafabio commented 5 years ago

Thank you for getting back to me. Unfortunately, I do care about both GraphIndependent and GraphNetwork.

Would something like this work? I construct operations for a train and test graph as follows:

# train
self.model.is_training = True
self.model.output_ops_train = self.model(self.model.input_ph, self.config.n_rollouts, self.model.is_training, self.sess)
# test
self.model.is_training = False
self.model.output_ops_test = self.model(self.model.input_ph_test, self.config.n_rollouts, self.model.is_training, self.sess)

and then assign the corresponding value to the is_training variable (created with tf.get_variable() in every encoder/decoder model) within the _build() function, to modify the values in the TF graphs:

def _build(self, inputs, is_training, sess):
        out = self._network(inputs)

        # modify -is_training- flags accordingly
        with sess.as_default():
            for v in self._network.get_all_variables(collection=tf.GraphKeys.GLOBAL_VARIABLES):
                if "is_training" in v.name:
                    assign_op = v.assign(is_training)
                    sess.run(assign_op)
                    assert v.eval() == is_training

            # check if it is necessary to call _network(inputs) again
            variables = out[0].graph.get_collection("variables")
            for v in variables:
                if "is_training" in v.name:
                    assert v.eval() == is_training

        return out

As long as I maintain two different sets of operations / GraphsTuple lists (self.model.output_ops_test and self.model.output_ops_train), I believe this should work. What do you think?

Also, since the GraphsTuple lists are quite large, how can I check whether these flags were set correctly?

Edit: your solution works well for GraphIndependent modules. I would still love to see how you check whether the flags are actually being set correctly.
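One lightweight way to answer the "were the flags set correctly?" question is to log the flag each module receives when it is called. The decorator below is a hypothetical debugging aid in plain Python, not a graph_nets or Sonnet feature; it assumes is_training is passed as a keyword argument. In TF1 graph mode a Python bool is fixed when _build runs, so logging at call time records exactly what ends up baked into each graph.

```python
import functools
from typing import Callable, Dict, List


def record_flags(name: str, log: Dict[str, List[bool]]) -> Callable:
    """Hypothetical debugging decorator: logs the is_training value that each
    wrapped module receives, keyed by a human-readable name."""

    def decorator(fn: Callable) -> Callable:
        @functools.wraps(fn)
        def wrapper(*args, is_training: bool, **kwargs):
            # Assumes is_training is always passed by keyword.
            log.setdefault(name, []).append(is_training)
            return fn(*args, is_training=is_training, **kwargs)
        return wrapper
    return decorator


flag_log: Dict[str, List[bool]] = {}


@record_flags("encoder", flag_log)
def encoder(x, is_training):
    return x + (0.1 if is_training else 0.0)


encoder(1.0, is_training=True)   # e.g. building the train ops
encoder(1.0, is_training=False)  # e.g. building the test ops
print(flag_log)  # {'encoder': [True, False]}
```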

vbapst commented 5 years ago

Hi Fabio,

I think the second solution I described is exactly what you want, and it should work with both GraphNetwork and GraphIndependent.

ferreirafabio commented 5 years ago
outputs_test = inputs.replace(
    nodes=your_node_module(inputs.nodes, False),
    edges=your_edge_module(inputs.edges, False)
)
outputs_train = inputs.replace(
    nodes=your_node_module(inputs.nodes, True),
    edges=your_edge_module(inputs.edges, True)
)

--> Just to be clear, is this done in the _build() or the __init__() call? Also, does this replace call need to be inside a with self._enter_variable_scope(): block, or is it enough that the module contains it?

ferreirafabio commented 5 years ago

It is also unclear to me how I should call the replace function when the node module itself instantiates a class like this:

def __init__(self, model_id, is_training, name="a"):
    with self._enter_variable_scope():
        visual_encoder = get_model_from_config(self.model_id, model_type="visual_encoder")(is_training=self.is_training, name="visual_encoder")
        self._network = modules.GraphIndependent(
            ...
            node_model_fn=lambda: get_model_from_config(self.model_id, model_type="visual_and_latent_encoder")(visual_encoder, name="visual_and_latent_node_encoder"),
            ...
        )

How do I pass the flag in a replace statement in this case? A few more details, or a self-contained minimal example, would be helpful. Thank you

ferreirafabio commented 5 years ago

@vbapst can you comment on this?

vbapst commented 5 years ago

If you go with the first option, then wherever you build your network (probably in your main training loop, or maybe in the __init__ of your module), you would construct outputs_test and outputs_train as described. You need to enter a variable scope only in the latter case. Note that in this case you don't call modules.GraphIndependent, but directly call replace on the nodes, edges and globals of your graph.

I am not sure I understand what your last comment is trying to achieve, as we don't usually pass the is_training flag at init time, but at build time. So it would look more like this:

def __init__(self):
  with tf.variable_scope("graph_modules"):
    self._edge_module = ...  # Define your module here. Its `_build` method should take an extra is_training argument.
    self._node_module = ...
    self._global_module = ...

def _build(self, inputs, is_training):
  return modules.GraphNetwork(
      edge_model_fn=lambda: lambda x: self._edge_module(x, is_training=is_training),
      node_model_fn=lambda: lambda x: self._node_module(x, is_training=is_training),
      global_model_fn=lambda: lambda x: self._global_module(x, is_training=is_training),
  )(inputs)
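Applied to the original question, the same idea lets an EncodeProcessDecode-style model take is_training next to num_processing_steps at build time, while its submodules (and hence their parameters) are created only once. The sketch below is plain Python with hypothetical classes (NoisyLayer, ToyEncodeProcessDecode); it is a heavily simplified stand-in for the demo model, operating on scalars instead of graphs and omitting most of its details.

```python
from typing import List


class NoisyLayer:
    """Hypothetical submodule: its parameters are created once in __init__,
    and the is_training flag is only supplied when it is called."""

    def __init__(self, weight: float):
        self.weight = weight

    def __call__(self, x: float, is_training: bool) -> float:
        # Adds a fixed "noise" term only while training.
        return x * self.weight + (0.5 if is_training else 0.0)


class ToyEncodeProcessDecode:
    """Simplified stand-in for an EncodeProcessDecode-style model whose
    build step takes is_training alongside num_processing_steps."""

    def __init__(self):
        self._encoder = NoisyLayer(1.0)
        self._core = NoisyLayer(2.0)
        self._decoder = NoisyLayer(1.0)

    def __call__(self, inputs: float, num_processing_steps: int,
                 is_training: bool) -> List[float]:
        latent = self._encoder(inputs, is_training=is_training)
        outputs = []
        for _ in range(num_processing_steps):
            latent = self._core(latent, is_training=is_training)
            outputs.append(self._decoder(latent, is_training=is_training))
        return outputs


model = ToyEncodeProcessDecode()
print(model(1.0, 2, is_training=False))  # [2.0, 4.0]
print(model(1.0, 2, is_training=True))   # [4.0, 8.0]
```

Calling the same model object twice with different flags gives two sets of outputs that share one set of parameters, which matches the "one model, two train/test ops" setup asked for at the start of the thread.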
ferreirafabio commented 5 years ago

Thank you for your reply @vbapst. I meant modules.GraphIndependent. To pass the is_training flag, I now simply use the _build() functions of my models instead of the __init__() function, like so:

    def __init__(self, name="Encoder"):
        super(Encoder, self).__init__(name=name)

    def _build(self, inputs, is_training, verbose=VERBOSITY):

        self._network = modules.GraphIndependent(
            edge_model_fn=lambda: ... (x, is_training=is_training),
            node_model_fn=lambda: ... (x, is_training=is_training),
            global_model_fn=lambda: ... (x, is_training=is_training),
            )
        return self._network(inputs)

My sanity check shows that the flags are set correctly. Does this also work in your opinion, or can it have negative consequences that I currently do not foresee?

vbapst commented 5 years ago

This looks good; just check that the variables are correctly shared.

ferreirafabio commented 5 years ago

Thanks. Is there a way to check this within the GN framework?

ferreirafabio commented 5 years ago

@vbapst

vbapst commented 5 years ago

You can just list the variables with tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES) and check that they are not duplicated (i.e. you should only see one edge_module, etc.).
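Checking for duplicates by eye gets tedious with many variables, so here is a small helper that automates it. It is a hypothetical sketch operating on variable names only. It relies on the fact that TF1 and Sonnet make repeated scope names unique by appending "_1", "_2", and so on, so a correctly shared module should appear under exactly one scope path. The example names are made up for illustration.

```python
import re
from typing import Iterable, Set


def scopes_for_module(var_names: Iterable[str], module_name: str) -> Set[str]:
    """Return the distinct variable-scope paths under which a module appears.

    TF1/Sonnet make repeated scope names unique by appending "_1", "_2", ...,
    so a correctly shared module should map to exactly one path.
    """
    # Match "edge_module", "edge_module_1", ... but not e.g. "edge_module_mlp".
    pattern = re.compile(r"^" + re.escape(module_name) + r"(_\d+)?$")
    found: Set[str] = set()
    for name in var_names:
        parts = name.split(":")[0].split("/")  # drop the ":0" output suffix
        for i, part in enumerate(parts):
            if pattern.match(part):
                found.add("/".join(parts[: i + 1]))
                break
    return found


# Hypothetical variable names, for illustration only.
shared = [
    "graph_modules/edge_module/linear/w:0",
    "graph_modules/edge_module/linear/b:0",
    "graph_modules/node_module/linear/w:0",
]
duplicated = [
    "graph_network/edge_block/edge_module/linear/w:0",
    "graph_network_1/edge_block/edge_module/linear/w:0",
]
print(scopes_for_module(shared, "edge_module"))           # {'graph_modules/edge_module'}
print(len(scopes_for_module(duplicated, "edge_module")))  # 2
```

In a TF1 graph, the input would be [v.name for v in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)]. More than one path for a module name suggests it was instantiated twice rather than shared.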

ferreirafabio commented 5 years ago

thanks! works!