google-deepmind / sonnet

TensorFlow-based neural network library
https://sonnet.dev/
Apache License 2.0

Indicating training/testing modes in sonnet callbacks #103

Closed ferreirafabio closed 5 years ago

ferreirafabio commented 5 years ago

What is a convenient way of providing a boolean training flag, e.g. is_training, that indicates, for example, whether to apply batch norm, when using Sonnet callback functions?

Example:

def make_transpose_cnn_model():
    def transpose_convnet1d(inputs):
        inputs = tf.expand_dims(inputs, axis=2)

        outputs = snt.Conv1DTranspose(output_channels=2, kernel_shape=10, stride=1)(inputs)
        outputs = snt.BatchNorm()(outputs, is_training=True)  # <- want to have this as an input
        outputs = tf.nn.relu(outputs)
        outputs = snt.Conv1DTranspose(output_channels=2, kernel_shape=10, stride=1)(outputs)
        outputs = snt.BatchNorm()(outputs, is_training=True)  # <- want to have this as an input
        outputs = tf.nn.relu(outputs)
        outputs = snt.BatchFlatten()(outputs)
        # outputs = tf.nn.dropout(outputs, keep_prob=tf.constant(1.0))  # <- want to have this as an input too
        outputs = snt.Linear(output_size=128)(outputs)

        return outputs

    return transpose_convnet1d

and

self._network = modules.GraphIndependent(
                edge_model_fn=EncodeProcessDecode.make_mlp_model,
                node_model_fn=EncodeProcessDecode.make_transpose_cnn_model,
                global_model_fn=EncodeProcessDecode.make_mlp_model)

I can't pass this parameter through the _build() function as shown below, since the interface of modules.GraphIndependent won't allow it:

    def _build(self, input_op, num_processing_steps, is_training=True):
        latent = self._encoder(input_op, is_training)

It yields:

TypeError: _build() got an unexpected keyword argument 'is_training'

ferreirafabio commented 5 years ago

How are training/validation cycles distinguished in Sonnet callback functions? Is there an example of this?

malcolmreynolds commented 5 years ago

If you have a module that you want to compose out of other modules, and some of the submodules require extra arguments like is_training, the canonical thing to do would be to define a new subclass of AbstractModule, rather than writing a function as you have. In particular, defining a module allows you to explicitly reuse variables by just calling the module twice (potentially with different is_training kwarg values).

Your example would be something like this:

class TransposeCnnModel(snt.AbstractModule):
  def __init__(self, name='transpose_cnn_model'):
    super(TransposeCnnModel, self).__init__(name=name)

  def _build(self, inputs, is_training):
    inputs = tf.expand_dims(inputs, axis=2)

    outputs = snt.Conv1DTranspose(output_channels=2, kernel_shape=10, stride=1)(inputs)
    outputs = snt.BatchNorm()(outputs, is_training=is_training)
    outputs = tf.nn.relu(outputs)
    outputs = snt.Conv1DTranspose(output_channels=2, kernel_shape=10, stride=1)(outputs)
    outputs = snt.BatchNorm()(outputs, is_training=is_training)
    outputs = tf.nn.relu(outputs)
    outputs = snt.BatchFlatten()(outputs)
    # Only drop units during training; keep everything at evaluation time.
    keep_prob = 0.7 if is_training else 1.0
    outputs = tf.nn.dropout(outputs, keep_prob=keep_prob)
    outputs = snt.Linear(output_size=128)(outputs)

    return outputs    
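
As a minimal usage sketch (train_inputs and test_inputs stand in for whatever tensors your pipeline provides), connecting the same instance twice reuses its variables while letting you flip the is_training flag:

model = TransposeCnnModel()
train_outputs = model(train_inputs, is_training=True)   # first call creates the variables
test_outputs = model(test_inputs, is_training=False)    # second call reuses the same variables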

You would probably want to make more things configurable as constructor args (e.g. final output size, channels in the convolutions, dropout probability for training, etc.), but the above should fit the API that the graph_nets library expects.
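
For instance, a more configurable variant might look roughly like this (the constructor argument names are only illustrative, not part of any existing API):

class TransposeCnnModel(snt.AbstractModule):
  def __init__(self, output_channels=2, kernel_shape=10, train_keep_prob=0.7,
               output_size=128, name='transpose_cnn_model'):
    super(TransposeCnnModel, self).__init__(name=name)
    self._output_channels = output_channels
    self._kernel_shape = kernel_shape
    self._train_keep_prob = train_keep_prob
    self._output_size = output_size

  def _build(self, inputs, is_training):
    outputs = tf.expand_dims(inputs, axis=2)
    # Two transpose-conv + batch norm + ReLU blocks, as in the example above.
    for _ in range(2):
      outputs = snt.Conv1DTranspose(output_channels=self._output_channels,
                                    kernel_shape=self._kernel_shape, stride=1)(outputs)
      outputs = snt.BatchNorm()(outputs, is_training=is_training)
      outputs = tf.nn.relu(outputs)
    outputs = snt.BatchFlatten()(outputs)
    keep_prob = self._train_keep_prob if is_training else 1.0
    outputs = tf.nn.dropout(outputs, keep_prob=keep_prob)
    return snt.Linear(output_size=self._output_size)(outputs)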

As to your second point, I'm not sure exactly what you mean by Sonnet callback functions - could you be more specific?

ferreirafabio commented 5 years ago

Thank you for your answer. I was particularly interested in adapting the following deepmind/graph_nets use case from the examples:

self._network = modules.GraphNetwork(Model.make_mlp_model,
                                     Model.make_mlp_model,
                                     Model.make_mlp_model)

where make_mlp_model is a callback function. I don't think the OO solution you suggest (i.e. passing object instances instead of function references) will work in this case, since the API docs of modules.GraphNetwork state the following:

A callable that returns an edge model function. The callable must return a Sonnet module (or equivalent). If passed None, will pass through inputs (the default).

malcolmreynolds commented 5 years ago

According to the docstring of GraphNetwork, the callables passed must return a Sonnet module or equivalent (this just means something that has a __call__ method to which you pass input tensors). In this case, the delayed execution is there so that the edge_model_fn, node_model_fn, etc. all end up with a scope that is internal to the GraphNetwork. The simplest way to do this, with the Module I posted above, would be:

self._network = modules.GraphNetwork(
    lambda: TransposeCnnModel(name='edge'),
    lambda: TransposeCnnModel(name='node'),
    # ... etc.
)

Obviously, replace TransposeCnnModel with whatever Sonnet module you want.
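
If the module needs constructor arguments (for example the configurable variant sketched above), they can simply be captured in the lambdas as well:

self._network = modules.GraphNetwork(
    edge_model_fn=lambda: TransposeCnnModel(output_size=128, name='edge'),
    node_model_fn=lambda: TransposeCnnModel(output_size=128, name='node'),
    global_model_fn=lambda: TransposeCnnModel(output_size=128, name='global'))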

I hope this clears things up - please reopen if you still have questions.

ferreirafabio commented 5 years ago

Thank you for clearing things up! I'm still facing some problems with the kind of syntax you introduced, which is why I'm reopening.

In particular, I want to pass a boolean flag is_training through sess.run(). I tried using the lambda notation you introduced by passing lambda arguments (x, y) for the inputs and the is_training flag, like so: node_model_fn=lambda x, y: ConvNet1D(name='CnnNodes')(x, y)

However, this yields the error TypeError: _build() takes 2 positional arguments but 3 were given.

The code I'm using is the following:

class CNNEncoderGraphIndependent(snt.AbstractModule):

    def __init__(self, name="CNNEncoderGraphIndependent"):
        super(CNNEncoderGraphIndependent, self).__init__(name=name)

        with self._enter_variable_scope():
            self._network = modules.GraphIndependent(
              edge_model_fn=EncodeProcessDecode.make_mlp_model_edges,
              node_model_fn=lambda: ConvNet1D(name='CnnNodes'),
              global_model_fn=lambda: ConvNet1D(name='CnnGlobals')
            )

    def _build(self, inputs, is_training):
        return self._network(inputs)

class CNNDecoderGraphIndependent(snt.AbstractModule):
    ...  # omitted for brevity; same interface as CNNEncoderGraphIndependent

class EncodeProcessDecode(snt.AbstractModule, BaseModel):
    def __init__(self, config, name="EncodeProcessDecode"):
        super(EncodeProcessDecode, self).__init__(name=name)

        self._encoder = CNNEncoderGraphIndependent()
        self._decoder = CNNDecoderGraphIndependent()
        self.is_training = tf.placeholder(tf.bool, shape=(), name='is_training')

    def _build(self, input_op, num_processing_steps, is_training):
        latent = self._encoder(input_op, is_training) 
        ...
        for _ in range(num_processing_steps):
            ...
            decoded_op = self._decoder(latent, is_training)
            ...
        return output_ops

I want to feed True/False flags into the placeholder self.model.is_training via sess.run(), like so:

self.model = EncodeProcessDecode(self.config)
self.model.output_ops_train = self.model(self.model.input_ph, self.config.n_rollouts, self.model.is_training)

but this yields the aforementioned error. What am I missing?
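
For reference, the session call I have in mind looks roughly like this (the actual input data is elided):

sess.run(self.model.output_ops_train,
         feed_dict={self.model.input_ph: ...,        # actual input data
                    self.model.is_training: True})   # or False during validation/testing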

ferreirafabio commented 5 years ago

Unfortunately, this issue was never reopened although there's an open question. Could this be reopened please?

diegolascasas commented 5 years ago

This code should work. Can you narrow it down to a minimal reproducible example?

ferreirafabio commented 5 years ago

It worked, thank you!