mila-iqia / blocks

A Theano framework for building and training neural networks

Proposition: replacements without replacements #982

Open · dmitriy-serdyuk opened this issue 8 years ago

dmitriy-serdyuk commented 8 years ago

After struggling with multiple replacements in the graph, I started thinking about how to implement the same functionality without theano.clone.

Problems with the current approach

The current way of modifying the graph is used mostly for regularization methods. One usually creates an inference graph and then applies modifications to turn it into a training graph.

from theano import tensor
from blocks.graph import ComputationGraph, apply_dropout

x = tensor.matrix('x')
valid_cost = my_mlp.apply(x)
valid_cg = ComputationGraph(valid_cost)
train_cg = apply_dropout(valid_cg, ...)
train_cost, = train_cg.outputs

The current way of changing the computational graph has several downsides:

My idea is to use our tagging mechanism and the brick structure to make replacements simple. Namely, I want to be able to easily replace (where possible) the inputs, outputs, and parameters of bricks.

I propose to use a manager which tracks the brick applications and makes the substitutions defined by the user. The code might look like this:

validation_cost = my_mlp.apply(x)
manager = dropout_manager(bricks=my_mlp.linear_layers, roles=INPUT)
train_cost = my_mlp.apply(x, use_manager=manager)

Internally, the use_manager argument is passed to the application and the application performs actions on inputs/outputs of the application and parameters of the bricks. This should be enough for most types of graph modifications and it should be easy to create a new kind of modifiable brick.
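To make the data flow concrete, here is a minimal pure-Python sketch of the explicit use_manager argument. It assumes plain floats instead of Theano variables and a hypothetical Manager that filters by brick name; ToyLinear, ToyMLP, Manager.process and the string roles are illustrative names, not Blocks API.

```python
# Hypothetical roles; stand-ins for blocks.roles.INPUT / OUTPUT.
INPUT, OUTPUT = "input", "output"


class Manager(object):
    """Hypothetical manager: maps a role to a callback, for listed bricks only."""

    def __init__(self, callbacks, brick_names):
        self.callbacks = callbacks
        self.brick_names = set(brick_names)

    def process(self, brick, role, value):
        if brick.name in self.brick_names and role in self.callbacks:
            return self.callbacks[role](value)
        return value


class ToyLinear(object):
    """A toy brick computing w * x + b on plain floats."""

    def __init__(self, name, w, b):
        self.name, self.w, self.b = name, w, b

    def apply(self, x, use_manager=None):
        if use_manager is not None:
            x = use_manager.process(self, INPUT, x)   # preprocess the input
        y = self.w * x + self.b                       # the brick's computation
        if use_manager is not None:
            y = use_manager.process(self, OUTPUT, y)  # postprocess the output
        return y


class ToyMLP(object):
    def __init__(self, layers):
        self.linear_layers = layers

    def apply(self, x, use_manager=None):
        for layer in self.linear_layers:
            # The manager is forwarded to every child application.
            x = layer.apply(x, use_manager=use_manager)
        return x


mlp = ToyMLP([ToyLinear("l1", 2.0, 0.0), ToyLinear("l2", 3.0, 1.0)])
# Halve the input of l2 only (a deterministic stand-in for dropout).
manager = Manager({INPUT: lambda v: v * 0.5}, brick_names=["l2"])
validation_cost = mlp.apply(1.0)
train_cost = mlp.apply(1.0, use_manager=manager)
```

The same bricks build both graphs; only the manager passed at application time differs.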

Optionally, we could have a global manager and use a with statement like

with dropout_manager(bricks=my_mlp.linear_layers, roles=INPUT):
    train_cost = my_mlp.apply(x)

In this case, all applications inside the with block should use the dropout manager by default.
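One way the global variant could work, sketched in pure Python under the assumption that active managers are kept on a module-level stack. The names active_manager, current_manager, scale_input_manager and linear_apply are hypothetical, and a manager here is just a dict mapping a role to a callback.

```python
import contextlib

INPUT = "input"      # hypothetical role name
_manager_stack = []  # hypothetical global stack of active managers


@contextlib.contextmanager
def active_manager(manager):
    """Make `manager` the default for every application inside the block."""
    _manager_stack.append(manager)
    try:
        yield manager
    finally:
        _manager_stack.pop()  # restored even if the block raises


def current_manager():
    return _manager_stack[-1] if _manager_stack else None


def scale_input_manager(factor):
    """Hypothetical manager: a dict mapping role -> callback."""
    return {INPUT: lambda v: v * factor}


def linear_apply(x, w, use_manager=None):
    # An explicit argument wins; otherwise fall back to the global manager.
    manager = use_manager if use_manager is not None else current_manager()
    if manager is not None and INPUT in manager:
        x = manager[INPUT](x)
    return w * x


outside = linear_apply(4.0, 3.0)
with active_manager(scale_input_manager(0.5)):
    inside = linear_apply(4.0, 3.0)
after = linear_apply(4.0, 3.0)
```

A stack rather than a single global slot lets with blocks nest, with the innermost manager taking precedence.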

Managers

A manager should be a filter-like object (it cannot use VariableFilter, since the graph does not exist yet when the manager is created). A manager should accept a set of callbacks that describe how to change the inputs and outputs of an application and the parameters of the brick associated with that application. The callbacks can be associated with roles.

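class Manager(object):
    def __init__(self, callbacks, filters):
        self.filters = filters
        self.callbacks = callbacks

    def process_application(self, application, brick, var):
        # `match` (not defined here) checks the application and brick against
        # the filters; variables with no callback for their role pass through.
        if match(application, brick, self.filters) and var.role in self.callbacks:
            return self.callbacks[var.role](var)
        return var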
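Because VariableFilter is unavailable before the graph exists, the filter has to be checked against application and brick names instead. A minimal sketch of one possible matching rule, assuming a hypothetical filter format (a dict with optional "bricks" and "applications" sets of names) and a hypothetical Var class carrying a value and a single role:

```python
INPUT, OUTPUT = "input", "output"  # hypothetical roles


class Var(object):
    """Hypothetical stand-in for a tagged variable: a value plus its role."""

    def __init__(self, value, role):
        self.value, self.role = value, role


def match(application_name, brick_name, filters):
    # Every filter key present must match; missing keys match anything.
    bricks = filters.get("bricks")
    applications = filters.get("applications")
    return ((bricks is None or brick_name in bricks) and
            (applications is None or application_name in applications))


class Manager(object):
    def __init__(self, callbacks, filters):
        self.callbacks = callbacks
        self.filters = filters

    def process_application(self, application_name, brick_name, var):
        if (match(application_name, brick_name, self.filters)
                and var.role in self.callbacks):
            return Var(self.callbacks[var.role](var.value), var.role)
        return var  # untouched when the filter or the role does not match


manager = Manager(callbacks={INPUT: lambda v: -v},
                  filters={"bricks": {"linear_1"}})
hit = manager.process_application("apply", "linear_1", Var(2.0, INPUT))
miss_brick = manager.process_application("apply", "linear_2", Var(2.0, INPUT))
miss_role = manager.process_application("apply", "linear_1", Var(2.0, OUTPUT))
```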

The application class should receive the manager, preprocess the inputs, apply the computation defined in the brick, and postprocess the outputs.
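A minimal sketch of that preprocess/apply/postprocess sequence as a decorator, assuming a manager is any callable taking (brick, application name, role, value). The decorator managed_application and the Square brick are hypothetical; the real Blocks application decorator does much more than this.

```python
import functools

INPUT, OUTPUT = "input", "output"  # hypothetical roles


def managed_application(method):
    """Hypothetical decorator: preprocess inputs, run the brick's computation,
    postprocess outputs, all through the manager passed as `use_manager`."""

    @functools.wraps(method)
    def wrapper(brick, *inputs, **kwargs):
        manager = kwargs.pop("use_manager", None)
        if manager is None:
            return method(brick, *inputs, **kwargs)
        inputs = [manager(brick, method.__name__, INPUT, x) for x in inputs]
        output = method(brick, *inputs, **kwargs)
        return manager(brick, method.__name__, OUTPUT, output)

    return wrapper


class Square(object):
    name = "square"

    @managed_application
    def apply(self, x):
        return x * x


def add_one_to_outputs(brick, application_name, role, value):
    # This manager only touches outputs of the `square` brick.
    if brick.name == "square" and role == OUTPUT:
        return value + 1
    return value


plain = Square().apply(3)
managed = Square().apply(3, use_manager=add_one_to_outputs)
```

Brick authors only write the bare computation; all manager handling lives in the decorator.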

Changes to be made for bricks

The parameters of a brick often need to be modified (as in weight noise). We already have the parameter list, so we can wrap it: when the wrapper is asked for a particular parameter, it first calls the manager. It goes without saying that this should be documented: people must not access the parameters directly in their own bricks if they want to benefit from the replacements.
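A minimal pure-Python sketch of such a wrapper, assuming parameters are plain floats and a manager is a one-argument callable. ManagedParameters and weight_noise_manager are hypothetical names; the noise is additive Gaussian, as in weight noise.

```python
import random


class ManagedParameters(object):
    """Hypothetical read-only wrapper around a brick's parameter list.

    Every access goes through `manager`, so a brick that reads its parameters
    via this wrapper (and never via the raw list) picks up modifications such
    as weight noise automatically.
    """

    def __init__(self, parameters, manager=None):
        self._parameters = list(parameters)
        self.manager = manager

    def __getitem__(self, index):
        value = self._parameters[index]
        if self.manager is not None:
            value = self.manager(value)
        return value

    def __len__(self):
        return len(self._parameters)


def weight_noise_manager(std, seed=0):
    # Additive Gaussian noise on each access; the stored value is unchanged.
    rng = random.Random(seed)
    return lambda w: w + rng.gauss(0.0, std)


params = ManagedParameters([0.5, -1.0])
clean = params[0]                      # no manager: the stored value
params.manager = weight_noise_manager(std=0.1)
noisy = params[0]                      # same parameter, perturbed on access
```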

Example: dropout

For the sake of example, I simplify the code:

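from theano.sandbox.rng_mrg import MRG_RandomStreams
from blocks.roles import INPUT
def dropout_manager(filter_conditions, prob=0.5, seed=1):
    rng = MRG_RandomStreams(seed)
    def callback(var):
        # Do all the stuff we currently do in our dropout: keep each unit with
        # probability 1 - prob, then rescale to preserve the expected value.
        return var * rng.binomial(var.shape, p=1 - prob, dtype=var.dtype) / (1 - prob)
    return Manager(callbacks={INPUT: callback}, filters=filter_conditions)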

As you can see, once we implement the Manager and improve the application decorator, dropout remains as easy to implement as it was with graph.replace.
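For reference, the callback should follow the usual inverted-dropout arithmetic: drop each unit with probability prob and divide the survivors by 1 - prob, so the expected activation is unchanged. A pure-Python check of that arithmetic on a list of floats, using the standard random module instead of a Theano RNG:

```python
import random


def dropout(values, prob, rng):
    """Inverted dropout on a list of floats: drop each entry with probability
    `prob`, scale the survivors by 1 / (1 - prob)."""
    keep = 1.0 - prob
    return [v / keep if rng.random() < keep else 0.0 for v in values]


rng = random.Random(1234)
prob = 0.5
values = [1.0] * 100000
out = dropout(values, prob, rng)
mean = sum(out) / len(out)                          # close to 1.0
kept = sum(1 for v in out if v != 0.0) / len(out)   # close to 1 - prob
```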

Disadvantages

dwf commented 8 years ago

The context manager approach is quite similar to my alternative interface for managing batch normalization in graphs. It has a somewhat different focus and solves a slightly different class of problems (the meat of what's going on is in the brick).

Several people have commented that, despite having to build the graph twice, they like this context manager approach, and have asked me whether Blocks could grow a similar interface for dropout. I wasn't sure about this, since a Dropout brick carries much less value than a BatchNormalization brick (in the latter case, somebody has to own the parameters; I suppose a dropout brick could own a shared variable containing the dropout probability, which could then be modified or annealed during training).

dwf commented 8 years ago

I've also had the problems you list with graph replacements. The other problem I've had is that early in my usage of it, I wasn't quite filtering for the right variables,

There are some arguments for simply introducing bricks that contain switches in their apply method, as BatchNormalization does.
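A minimal pure-Python sketch of that switch-based alternative, assuming a hypothetical ToyDropout brick whose drop_prob attribute stands in for the shared variable mentioned above and a hypothetical training_mode context manager that flips the switch. This is not the Blocks BatchNormalization interface, only the same pattern.

```python
import contextlib
import random


class ToyDropout(object):
    """Hypothetical dropout brick with a train/inference switch in `apply`.

    `drop_prob` plays the role of a shared variable: it can be changed
    (e.g. annealed) between calls without rebuilding anything.
    """

    def __init__(self, drop_prob=0.5, seed=0):
        self.drop_prob = drop_prob
        self.training = False
        self.rng = random.Random(seed)

    def apply(self, xs):
        if not self.training or self.drop_prob == 0.0:
            return list(xs)                      # inference: identity
        keep = 1.0 - self.drop_prob
        # Inverted dropout: zero each unit with prob drop_prob, rescale the rest.
        return [x / keep if self.rng.random() < keep else 0.0 for x in xs]


@contextlib.contextmanager
def training_mode(*bricks):
    """Flip the switch on for the duration of the block, then restore it."""
    previous = [b.training for b in bricks]
    for b in bricks:
        b.training = True
    try:
        yield
    finally:
        for b, flag in zip(bricks, previous):
            b.training = flag


xs = [1.0, 2.0, 3.0, 4.0]
dropout = ToyDropout(drop_prob=0.5)
inference_out = dropout.apply(xs)
with training_mode(dropout):
    train_out = dropout.apply(xs)
dropout.drop_prob = 0.0                          # "annealed" all the way down
with training_mode(dropout):
    annealed_out = dropout.apply(xs)
```

As with the batch normalization approach, the graph (here, the call) is built twice, once per mode.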

Arguments against that I can see:

dmitriy-serdyuk commented 8 years ago

@dwf, sure, I got some inspiration from your implementation of batch normalization. I'm trying to find a more general solution that makes use of what we built earlier, like tagging and the application hierarchy.

After an offline conversation with @rizar, I can see that it might be a big problem that the replacements I propose can only be local, in the sense that one cannot use a variable from the 6th layer to make a substitution in the 1st.
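A toy trace of why construction-time callbacks are local: layers are applied in order, so when the callback for layer 1 fires, no variable from later layers exists yet, whereas a clone-based replacement runs on the finished graph. The layer function and the bookkeeping lists are hypothetical, for illustration only.

```python
built = []  # names of variables in the order they are created


def apply_layer(index, x, callback=None):
    """Toy layer: let the callback modify the input, then compute x + 1."""
    if callback is not None:
        x = callback(index, x)
    y = x + 1
    built.append("layer%d_output" % index)
    return y


seen_by_layer1 = []


def record_callback(index, x):
    # Record which variables exist at the moment layer 1 is being processed.
    if index == 1:
        seen_by_layer1.extend(built)
    return x


x = 0
for i in range(1, 7):
    x = apply_layer(i, x, record_callback)

available_after_build = list(built)  # what a post-hoc replacement could see
```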

Another concern is that this would make the most perplexing part of Blocks, the application decorator, even more complicated.

I was also thinking about how to change the parameters, and it is more complicated than I initially thought. For now, the only solution I can see involves a stateful object containing the parameters.

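class ParameterContainer(object):
    def __init__(self, brick, shared_vars):
        self.brick = brick
        # DummyManager: a no-op manager that returns variables unchanged.
        self.current_manager = DummyManager()
        self.current_application = None
        self.shared_vars = shared_vars

    def __getitem__(self, index):
        # Every parameter access goes through the currently active manager.
        return self.current_manager.process_application(
            self.current_application, self.brick, self.shared_vars[index])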

The application sets current_manager and current_application before it runs and resets them to their original values afterwards.
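A runnable pure-Python sketch of that set-then-restore step, assuming floats as parameters. DummyManager, Scale and run_application are hypothetical names; using try/finally guarantees the container's state is restored even if the application raises.

```python
class DummyManager(object):
    """Hypothetical no-op manager: returns parameters unchanged."""

    def process_application(self, application, brick, var):
        return var


class ParameterContainer(object):
    def __init__(self, brick, shared_vars):
        self.brick = brick
        self.current_manager = DummyManager()
        self.current_application = None
        self.shared_vars = shared_vars

    def __getitem__(self, index):
        return self.current_manager.process_application(
            self.current_application, self.brick, self.shared_vars[index])


class Scale(object):
    """Hypothetical manager that multiplies every parameter by a factor."""

    def __init__(self, factor):
        self.factor = factor

    def process_application(self, application, brick, var):
        return var * self.factor


def run_application(container, application, manager, body):
    """Install `manager` for the duration of `body`, then restore the old state."""
    saved = (container.current_manager, container.current_application)
    container.current_manager = manager
    container.current_application = application
    try:
        return body(container)
    finally:
        container.current_manager, container.current_application = saved


params = ParameterContainer(brick="linear", shared_vars=[2.0, 3.0])
inside = run_application(params, "apply", Scale(10.0), lambda p: p[0] + p[1])
outside = params[0] + params[1]
```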

rizar commented 8 years ago

Thanks for taking time to write down your thoughts, @dmitriy-serdyuk !

For the record, the difference between the proposal here and what @dwf used for batch norm is the following:

The latter is more challenging, but if we can reliably do it, there is no need for one-trick ponies.

> A consequence of the previous point is that the aggregation schemes inside the monitoring are broken after the replacement.

This is not true. Aggregation schemes are broken because replacements do not affect auxiliary variables. That said, it is still true that aggregation schemes and replacements do not work well together.
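A toy illustration of that point, assuming a hypothetical Node class and a clone function that only loosely mimics theano.clone with a replacement dict: the cloned cost sees the replacement, but an auxiliary variable built from the original node (such as the numerator of an aggregation scheme) is not among the cloned outputs and keeps pointing at the un-replaced graph.

```python
class Node(object):
    """Hypothetical expression node: an op name and input nodes."""

    def __init__(self, op, *inputs):
        self.op, self.inputs = op, inputs


def clone(node, replacements):
    """Rebuild `node`, substituting any node found in `replacements`."""
    if node in replacements:
        return replacements[node]
    if not node.inputs:
        return node
    return Node(node.op, *[clone(i, replacements) for i in node.inputs])


def depends_on(node, target):
    return node is target or any(depends_on(i, target) for i in node.inputs)


x = Node("x")
hidden = Node("tanh", x)
cost = Node("mean", hidden)
aux_numerator = Node("sum", hidden)  # e.g. used by an aggregation scheme

dropped = Node("dropout", hidden)
new_cost = clone(cost, {hidden: dropped})
# Only `cost` was cloned; `aux_numerator` still refers to the original graph.
```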

rizar commented 8 years ago

One way to summarize your proposal is that things are inserted into the computation graph at construction time by means of callbacks. This is a reasonable idea, but it is not entirely clear to me that it will be much easier to use, that it will scale well (see your complaints about replacements of replacements), and so on. Also, as you said, it is not clear how to replace parameters without adding even more communication between different objects.

I think we should keep this idea in mind, but so far it looks like it would not give us full relief.