google / objax

Apache License 2.0

Training state as a Module attribute #29

Closed: rwightman closed this issue 3 years ago

rwightman commented 4 years ago

As mentioned in a Twitter thread, I am curious about the decision to propagate training state through the call() chain. From my perspective this approach adds more boilerplate code and more chances of making a mistake (e.g. not propagating the state to a few instances of a module with a BN or dropout layer). If the state changed with every call, like the input data, it would make more sense to pass it with every forward, but I can't think of cases where that is common. For small models it doesn't make much difference, but as models grow in depth and breadth of submodules, the extra args become more noticeable.

I feel one of the major benefits of an OO abstraction for NNs is being able to push attributes like this into the class structure instead of forcing them to be forwarded through every call in a functional manner. I sit in the (pragmatic) middle ground between OO and functional. Hidden state can be problematic, but it's worth it if it keeps interfaces clean.

Besides TF/Keras, most DL libs manage training state as a module attribute or via some sort of context.

It should be noted that Swift for TF started out Keras- and Objax-like, with the training state passed through call().

Disclaimer: I like PyTorch, and I do quite a bit of work with that framework. It's not perfect, but I feel they did a really good job in terms of interface, usability, and evolution of the API. I've read some other comments here and acknowledge the 'we don't want to be like framework/lib X or Y just because; if you disagree, go fork yourself' stance. Understood: any suggestions I make are not just to be like X, but to bring over elements of X that work really well to improve this library.

I currently maintain some PyTorch model collections, for example https://github.com/rwightman/pytorch-image-models and https://github.com/rwightman/efficientdet-pytorch. I'm running into a cost ($$) wall with the GPU experiments supporting my open-source work, and TPU pricing is starting to look far more attractive. PyTorch XLA is not proving to be a great option, but JAX with a productive interface looks like it could be a winning solution with even more flexibility.

I'm willing to contribute code for changes like this, but at this point it's a matter of design philosophy :)

AakashKumarNain commented 4 years ago

Here are my thoughts on this one:

  1. If we consider the current design, then going at the Module level makes sense. You have a trainable attribute that can be toggled during the training/evaluation phase. The good thing about this is that a module can recursively toggle the trainable attribute for each of its members/layers. Clean and simple.

  2. But the above design has a flaw (at least IMO). Let's say you have a ResNet50 model defined as a Module. The above design allows you to train from scratch, transfer learn, or fine-tune the whole network. But there are situations where we don't want to fine-tune all the layers at once. For example, if you take a pretrained ImageNet model and apply it to a medical dataset, people usually freeze the first few layers and fine-tune the rest. If you set the trainable attribute at the module level, you won't be able to achieve this unless you redesign your network as a combination of different modules and freeze a few of them.

  3. Because of 2), I prefer setting the trainable attribute at the layer level. A Module, in some sense, is just a high-level layer composed of other layers. This makes every setup possible: train from scratch, transfer learn, fine-tune, freeze some layers, etc.

If we implement 3), we can still set the trainable attribute recursively for each layer, though I believe that would require a redesign of the Module API.

rwightman commented 3 years ago

@AakashKumarNain Not sure what the distinction between module and layer is? With the design here, and in PyTorch, layers like Conv2D and BatchNorm are Modules, so the granularity is quite fine. When you call the train() or eval() method on the model-level module (i.e. `ResNet50()`), it recursively propagates to all children, down to the individual Conv2D layers. One can override these methods on any submodule to customize the behaviour.

Module-based recursion also lets you do something like the following, where apply is a member function that applies a given fn recursively to all modules in the model hierarchy:

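import torch.nn as nn

def set_bn_eval(module):
    # _BatchNorm is the common base class of BatchNorm1d/2d/3d and SyncBatchNorm
    if isinstance(module, nn.modules.batchnorm._BatchNorm):
        module.eval()

model.apply(set_bn_eval)
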
AakashKumarNain commented 3 years ago

@rwightman Oh, I see. I'm not a PyTorch user, so I thought of a Module as being different from a layer. If that's the case, then we are on the same page, and I agree that the trainable state should be at the Module level.

AlexeyKurakin commented 3 years ago

Thanks for the feedback. We do want to get the best design ideas into Objax.

Below I’ll explain the rationale behind the current design decision. I’ll also provide some additional context for people who may be reading it later.

Training / evaluation mode

Certain types of NN layers (like BatchNorm and Dropout) behave differently in training and evaluation mode. For example, in training mode Dropout randomly drops some neurons (i.e. zeros some values in the activation vector), while in eval mode Dropout acts as an identity op.
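To make the Dropout example concrete, here is a minimal pure-Python sketch of inverted dropout, assuming the common convention of scaling kept activations by 1 / (1 - rate) during training. The function name `dropout` and its parameters are illustrative, not Objax's API.

```python
import random
from typing import List, Optional


def dropout(x: List[float], rate: float, training: bool,
            rng: Optional[random.Random] = None) -> List[float]:
    """Inverted dropout on a flat list of activations (illustrative sketch)."""
    if not training or rate == 0.0:
        # Eval mode: dropout is the identity op.
        return list(x)
    rng = rng or random.Random()
    keep = 1.0 - rate
    # Training mode: zero each value with probability `rate` and scale the
    # survivors by 1/keep so the expected activation is unchanged.
    return [v / keep if rng.random() < keep else 0.0 for v in x]


acts = [1.0, 2.0, 3.0, 4.0]
print(dropout(acts, rate=0.5, training=False))  # identical to acts
print(dropout(acts, rate=0.5, training=True, rng=random.Random(0)))
```

The same layer object produces different results depending only on the mode flag, which is why every framework needs some mechanism for choosing it.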

Thus any NN library has to provide a way to switch between training and evaluation mode when calling layers. Different libraries do it differently.

NOTE: this is different from trainable/non-trainable variables, which determine which variables are updated by the optimizer.

Attribute-style of switching train/eval mode (like in PyTorch)

Each layer/module has a module.training attribute which controls whether the module runs in training or eval mode. There are also methods module.train() and module.eval() which change the value of this attribute for the current module and all its submodules:

model = create_model(...)

# training
model.train()
# … perform training step

# evaluation
model.eval()
# … run evaluation

A similar approach is used in Keras (before TF2), MXNet Gluon and Swift for TF.
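A minimal pure-Python sketch of how attribute-style switching can work: a hypothetical `Module` base class stores a `training` flag, and `train()`/`eval()` reach every submodule through a recursive `apply`, in the spirit of PyTorch's `nn.Module`. This is not PyTorch's actual implementation, and the class names are illustrative.

```python
from typing import Callable, Iterator


class Module:
    """Minimal attribute-style module: each instance carries a `training` flag."""

    def __init__(self) -> None:
        self.training = True  # PyTorch-style default: training mode

    def children(self) -> Iterator["Module"]:
        # Direct submodules are simply the attributes that are Modules.
        for value in vars(self).values():
            if isinstance(value, Module):
                yield value

    def apply(self, fn: Callable[["Module"], None]) -> "Module":
        # Apply fn to every submodule (depth-first), then to self.
        for child in self.children():
            child.apply(fn)
        fn(self)
        return self

    def train(self, mode: bool = True) -> "Module":
        def set_mode(m: "Module") -> None:
            m.training = mode
        return self.apply(set_mode)

    def eval(self) -> "Module":
        return self.train(False)


class BatchNorm(Module):
    pass


class Block(Module):
    def __init__(self) -> None:
        super().__init__()
        self.bn = BatchNorm()


class Net(Module):
    def __init__(self) -> None:
        super().__init__()
        self.block_1 = Block()
        self.block_2 = Block()


model = Net()
model.eval()             # every submodule now has training == False
model.train()            # ... and back to training mode
model.block_1.bn.eval()  # per-module override: only this BN is in eval mode
```

The convenience and the risk come from the same place: the mode lives in the objects, so nothing at the call site shows which mode a forward pass will use.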

Argument-style of switching train/eval mode (like in Objax currently)

The module's __call__ method has a training argument. To call a module in training mode, the user writes module(x, training=True); to call it in eval mode, the user writes module(x, training=False):

model = create_model(...)

# training
loss = cross_entropy(y, model(x, training=True))
# … compute gradient of loss and update weights

# evaluation
predictions = model(x, training=False)

A similar approach is used in TF2 and in some other JAX frameworks (Haiku, Flax).
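A minimal pure-Python sketch of the argument style, with hypothetical `Scale`, `Block` and `Model` classes standing in for real layers: the mode is an explicit, non-defaulted argument that every composite module forwards to its children. Because `training` has no default in this sketch, calling `model(x)` without it fails with a `TypeError` instead of silently picking a mode.

```python
from typing import List


class Scale:
    """Stand-in for a layer whose behaviour depends on the mode (e.g. BatchNorm)."""

    def __call__(self, x: List[float], training: bool) -> List[float]:
        # Illustrative only: halve activations in training, pass through in eval.
        return [v * 0.5 for v in x] if training else list(x)


class Block:
    def __init__(self) -> None:
        self.norm_1 = Scale()
        self.norm_2 = Scale()

    def __call__(self, x: List[float], training: bool) -> List[float]:
        # The mode must be forwarded explicitly to every mode-dependent child.
        x = self.norm_1(x, training=training)
        return self.norm_2(x, training=training)


class Model:
    def __init__(self) -> None:
        self.blocks = [Block(), Block()]

    def __call__(self, x: List[float], training: bool) -> List[float]:
        for block in self.blocks:
            x = block(x, training=training)
        return x


model = Model()
train_out = model([16.0], training=True)   # four halvings: [1.0]
eval_out = model([16.0], training=False)   # identity: [16.0]
```

The repeated `training=training` lines are the boilerplate discussed at the top of this issue; the missing default is what rules out silently running in the wrong mode.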

Comparison of approaches

So choosing an approach comes down to balancing the following:

  1. Safety: preventing users from making silent and hard to detect mistakes
  2. Ability to set training mode per module.
  3. Verbosity of the code.

Below I describe the considerations for each of these three factors in more detail.

Safety

We think safety is a pretty big issue for the PyTorch-style approach. It's definitely a big problem for novice users. It might be less of a problem for experienced users, but even they may accidentally forget to type model.train() or model.eval() and then spend a lot of time figuring out why model accuracy differs from what they expected.

Specifically, here is one of the official PyTorch examples, which does not set the model to evaluation mode: https://github.com/pytorch/examples/blob/master/dcgan/main.py

And here are several discussions indicating that novice users don't always understand the difference between model.train() and model.eval(), and thus always run the model in the default (training) mode:

On the other hand, if the user always has to provide the training argument, then it will:

  1. Prevent forgetting model.train() and model.eval() calls
  2. Force novice users to figure out the purpose of this argument instead of just using some default.

Per-module training/eval mode

One nice thing about the PyTorch-style training/eval mode is that it makes it easy to set the mode independently for each module:

model = Resnet50(nclasses=1000)
…
# Example: put most of the network into training mode,
# except for a few modules
model.train()
model.block_1.bn_1.eval()
model.block_2.bn_2.eval()

This could be a big advantage in favor of the PyTorch way of specifying training mode. However, it's not clear how commonly this feature is used, and it's possible to do a similar thing in Objax, just a little more verbosely:

model = Resnet50(nclasses=1000)
…
# Here is an example of how to force certain batch norms into eval mode.
# The wrapper ignores whatever training flag the parent passes.
def force_eval(module):
    return lambda x, training=None: module(x, training=False)

model.block_1.bn_1 = force_eval(model.block_1.bn_1)
model.block_2.bn_2 = force_eval(model.block_2.bn_2)

# The following line calls the model in training mode, except for block_1.bn_1 and block_2.bn_2
y = model(x, training=True)

If there is interest from users in this feature, we can also add a helper function for more comprehensive overriding of arguments of modules.
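One way such a helper could look, as a hypothetical sketch (the name `force_eval` is illustrative and not part of Objax): it wraps a callable module so that it always runs with `training=False`, whatever flag the parent passes. A plain `functools.partial(module, training=False)` is not enough when the parent forwards `training=` as a keyword, because keywords supplied at call time override the ones stored in the partial (this is documented Python behavior).

```python
import functools
from typing import Any, Callable


def force_eval(module: Callable[..., Any]) -> Callable[..., Any]:
    """Hypothetical helper: pin a module to eval mode."""
    def wrapper(x: Any, training: Any = None, **kwargs: Any) -> Any:
        # Whatever mode the parent passes (positionally or by keyword) is ignored.
        return module(x, training=False, **kwargs)
    return wrapper


def bn(x: float, training: bool) -> str:
    # Stand-in for a BatchNorm layer: just reports which mode it ran in.
    return "train" if training else "eval"


pinned = force_eval(bn)
pinned(0.0, training=True)  # "eval"
pinned(0.0, True)           # "eval"

partial_bn = functools.partial(bn, training=False)
partial_bn(0.0, training=True)  # "train": the call-time keyword wins
```

A real helper would also need to keep the wrapped module's variables visible to the parent's variable collection; this sketch does not address that.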

Verbosity

As @rwightman mentioned, passing the training argument around can be somewhat verbose.

This is somewhat alleviated by the fact that objax.Sequential automatically propagates it where needed, and most models have large sequential chunks. Thus I think that for many practical models only a moderate (or even very small) amount of manual propagation of the training argument is needed. For example, we have only one line manually propagating the training argument in our ResNet implementation.
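The sketch below shows one way a sequential container could forward `training` only to the layers that accept it, by inspecting each layer's signature with `inspect.signature`. It is an illustrative pure-Python `Sequential` with hypothetical stand-in layers, not Objax's actual implementation.

```python
import inspect
from typing import Any, Callable, List


class Sequential:
    """Illustrative container that forwards `training` only where it is accepted."""

    def __init__(self, layers: List[Callable[..., Any]]) -> None:
        self.layers = layers

    def __call__(self, x: Any, training: bool) -> Any:
        for layer in self.layers:
            params = inspect.signature(layer).parameters
            if "training" in params:
                x = layer(x, training=training)
            else:
                x = layer(x)  # mode-independent layer, e.g. an activation
        return x


def relu(x: float) -> float:
    return max(x, 0.0)


class ZeroInTraining:
    """Stand-in for a dropout-like layer: zeroes its input in training mode only."""

    def __call__(self, x: float, training: bool) -> float:
        return 0.0 if training else x


net = Sequential([relu, ZeroInTraining(), relu])
```

With a container like this, the only place the user writes `training=` is the top-level call, e.g. `net(3.0, training=False)`.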

Conclusion

Overall, the current Objax approach helps prevent user error, while the PyTorch approach can lead to silent errors.

It does allow the user to set per-module training/eval mode, though maybe not as conveniently as in PyTorch. Support for per-module training/eval mode could also be improved with additional helper methods.

While the Objax approach tends to be more verbose than PyTorch's, this can be alleviated by using objax.Sequential.

Anyway, let us know if there are other advantages of the PyTorch way of switching train/eval mode, or any other issues with how Objax handles it.

AakashKumarNain commented 3 years ago

@AlexeyKurakin thanks for the detailed feedback. I have just one question. In Keras, everything is a layer, even a model instance. Every layer has a trainable attribute. Also, a layer is a callable object and also accepts training as an argument.

training is a boolean argument in call that determines whether the call should run in inference mode or training mode. trainable is a boolean attribute that freezes a layer: if layer.trainable=False, the layer's trainable weights aren't updated by the optimizer. BN is an exception: setting layer.trainable=False also makes the layer run in inference mode.

Can we achieve the same with Objax's Sequential API?

rwightman commented 3 years ago

@AlexeyKurakin Indeed, thanks for the detailed explanation.

I do agree regarding safety/errors. I doubt anyone has dabbled in PyTorch without forgetting to toggle their model to eval mode before using it for prediction. I've certainly made the mistake. Now it's burnt into my brain's debugging checklist so it's not high on my priority list. On my first pass through this code I didn't note that the training arg isn't defaulted, so yeah, no concerns about missing it for a few layers like I'd originally thought.

Some work to make overriding a bit easier would ease that concern. It's fairly common to see BN in eval mode and/or full freezing (eval mode and no grad) of BN and other layers for specific stages when using pretrained classification models as backbones in object detection / segmentation models. Not sure if it's still popular these days, but the 'dropout as a Bayesian approximation' crowd liked to keep dropout active in eval mode once upon a time.

I often use sequential closer to the top level of a model (feature/block repeats) but tend to avoid it within the blocks themselves.

AlexeyKurakin commented 3 years ago

So it seems the main feature request is better support for per-module train/eval state. I opened a separate issue, https://github.com/google/objax/issues/37, to track the design and implementation of per-module train/eval state. I'm going to close this issue, since everything else seems to be addressed. Feel free to reopen it if there are other questions.

AlexeyKurakin commented 3 years ago

@AakashKumarNain Regarding the Keras trainable and training attributes: I actually was unaware that in Keras setting batch_norm_layer.trainable=False makes it run in inference mode. It seems this change was made only in TF2: https://github.com/tensorflow/tensorflow/blob/v2.3.0/tensorflow/python/keras/layers/normalization_v2.py#L208-L247

While the motivation for such behavior is understandable, it's inconsistent between batch norm and other types of layers, and it could be confusing. Generally we want to avoid confusing behavior in Objax. Also, Objax has no trainable property; the list of trainable variables is explicitly passed to the optimizer.

The ability to always force certain Objax modules to run in eval mode will be addressed in #37; on top of that, batch norm variables should be excluded from the list of vars passed to the optimizer. Together, these two things achieve frozen batch norm behavior.
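A hypothetical pure-Python sketch of the second half of that recipe: variables are kept in a name-to-value dict (a stand-in for an Objax variable collection), and batch-norm variables are filtered out before the trainable list reaches a toy SGD update, so they stay fixed. The naming convention (`.bn` in the variable path) and both helper functions are assumptions for illustration only.

```python
from typing import Dict


def exclude_batchnorm(variables: Dict[str, float]) -> Dict[str, float]:
    # Assumed naming convention: batch-norm variables have ".bn" in their path.
    return {name: v for name, v in variables.items() if ".bn" not in name}


def sgd_step(variables: Dict[str, float], trainable: Dict[str, float],
             grads: Dict[str, float], lr: float) -> Dict[str, float]:
    # Only variables explicitly handed to the optimizer are updated.
    updated = dict(variables)
    for name in trainable:
        updated[name] = variables[name] - lr * grads[name]
    return updated


all_vars = {
    "block_1.conv.w": 1.0,
    "block_1.bn.gamma": 1.0,
    "block_1.bn.beta": 0.0,
}
grads = {name: 0.5 for name in all_vars}
trainable = exclude_batchnorm(all_vars)
new_vars = sgd_step(all_vars, trainable, grads, lr=0.1)
```

Here only `block_1.conv.w` moves; the BN parameters keep their values. Pinning those BN modules to eval mode is still needed so their running statistics are used rather than batch statistics.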

cgarciae commented 3 years ago

Hey, this issue was super useful, thanks for sharing your ideas. I've been reimplementing Elegy's module system based on some of the ideas brought by Objax. After starting with a Haiku-based implementation, I realized that being functional makes it way harder to do things like transfer learning, or anything that involves mixing preexisting (pretrained) Modules.

Regarding the training state, I decided to go the S4TF route of having a per-thread global state plus context managers for controlling it when needed. Some notes: