rwth-i6 / returnn_common

Common building blocks for RETURNN configs, such as models, training concepts, etc

Allow `__init__` logic to work equally for graph-based and eager-based backends, specifically re-parameterization like weight norm #250

Open albertz opened 1 year ago

albertz commented 1 year ago

We want to support multiple backends for nn, such as RETURNN, TensorFlow and PyTorch (see rwth-i6/returnn#1264). This implies that we need to design our API in a way that works with both eager-mode and graph-mode backends.

This issue here follows up on a comment in https://github.com/rwth-i6/returnn/issues/1264.

It's best explained with the weight norm (#91) example, or more generally any re-parameterization. The current weight norm code only makes sense for graph mode, as it symbolically redefines the parameter in terms of some other symbolic formula. This formula then gets evaluated again and again, every time some computation is done with the model. That works naturally in graph mode. For eager mode, this must be more explicit and happen when the actual parameter is accessed, as the parameter is not a symbolic formula there.
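To make the graph-mode picture concrete, here is a minimal standalone TF1-style sketch (not the actual returnn_common weight norm code): the re-parameterization is written down once as a symbolic expression, and every `session.run` re-evaluates it, so it always reflects the current variable values.

```python
import numpy as np
import tensorflow as tf

# Graph mode (TF1-style): the re-parameterization is defined once, symbolically.
tf.compat.v1.disable_eager_execution()

v = tf.compat.v1.get_variable("weight_v", shape=(4, 8))  # direction
g = tf.compat.v1.get_variable("weight_g", shape=())      # magnitude
weight = g * v / tf.norm(v)  # symbolic formula, not a fixed value

x = tf.compat.v1.placeholder(tf.float32, shape=(None, 4))
y = tf.matmul(x, weight)  # uses the re-parameterized weight

with tf.compat.v1.Session() as sess:
    sess.run(tf.compat.v1.global_variables_initializer())
    # Every run re-evaluates `weight` from the current values of g and v.
    out = sess.run(y, feed_dict={x: np.zeros((2, 4), dtype="float32")})
```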

Note this is different from other code in `__call__`. That code should work no matter whether it is executed in graph mode or eager mode, and any control flow logic is already wrapped.

However, in `__init__`, there is an important difference. In both cases, it is executed only once. With symbolic computation, representing some value based on a parameter, for example weight-normalized parameters, this is totally fine and the right thing to do for symbolic execution. However, with eager execution, executing it only once is not enough. E.g. in PyTorch, weight normalization uses `_forward_pre_hooks` to recalculate it again and again.
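For contrast, a minimal eager-mode sketch along the lines of what PyTorch's `weight_norm` does (simplified, e.g. using a single scalar norm instead of a per-output-dim norm): the re-parameterized weight has to be recomputed explicitly before every forward call, via a forward pre-hook.

```python
import torch
from torch import nn

def apply_weight_norm(module: nn.Module, name: str = "weight") -> nn.Module:
    """Re-parameterize module.<name> as g * v / ||v||, recomputed on every forward."""
    weight = getattr(module, name).detach()
    delattr(module, name)  # remove the original parameter
    module.register_parameter(name + "_g", nn.Parameter(weight.norm()))
    module.register_parameter(name + "_v", nn.Parameter(weight.clone()))

    def hook(mod, inputs):
        g = getattr(mod, name + "_g")
        v = getattr(mod, name + "_v")
        # Plain tensor (not a Parameter), recomputed right before each forward.
        setattr(mod, name, g * v / v.norm())

    module.register_forward_pre_hook(hook)
    hook(module, ())  # compute once so the attribute exists before the first call
    return module

lin = apply_weight_norm(nn.Linear(4, 8))
y = lin(torch.randn(2, 4))  # the hook recomputes lin.weight right before forward
```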

So far we only define parameters in `__init__`, and maybe their initial values (`nn.init.ParamInit`) or things like weight decay. This is fine for both eager and symbolic mode.

However, for any computation depending on a parameter which can potentially change, we need to think about this. It's not clear yet how to solve this. This becomes relevant for example for weight norm (#91).

albertz commented 1 year ago

Actually related is rwth-i6/returnn_common#96, explicit stages. Param init is a different stage than the step loop over the dataset. I think this must be explicit to handle it properly.

albertz commented 1 year ago

I guess explicit stages are the natural way to do this. As shown in rwth-i6/returnn_common#96, those could be defined via context scopes. But we could also allow a simpler variant via an explicit function, like `nn.set_execution_stage_epoch_loop_steps()` or `nn.enter_step_loop()` or so.
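A rough sketch of what a context-scope-based stage API could look like (all names here are made up, purely to illustrate the idea, not an actual returnn_common API):

```python
import contextlib

_current_stage = "param_init"  # default stage: module construction / param init

@contextlib.contextmanager
def execution_stage(name: str):
    """Mark a region of the config as belonging to a given execution stage."""
    global _current_stage
    prev, _current_stage = _current_stage, name
    try:
        yield
    finally:
        _current_stage = prev

with execution_stage("step_loop"):
    print(_current_stage)  # "step_loop"; per-step code (and re-parameterization) would go here
print(_current_stage)      # back to "param_init"
```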

albertz commented 1 year ago

When thinking about this, just having one stage for param init and another for the training steps does not really cover one of the main aspects here, namely re-parameterization like weight norm. However we design this, we should also think about how such re-parameterization would be done then. Maybe it ends up similar to PyTorch `_forward_pre_hooks` in any case.

albertz commented 1 year ago

A problem with a forward pre-hook: when would it be executed? In PyTorch, it runs in `__call__`, before `forward` gets called. However, it is not used for any other method of the module. Do we want to have it for other methods as well?
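A minimal PyTorch example of the problem: the pre-hook fires only when the module is called via `__call__`/`forward`, not when some other method is invoked directly.

```python
import torch
from torch import nn

class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(3))

    def forward(self, x):
        return x * self.weight

    def other_method(self, x):
        # Not routed through __call__, so forward pre-hooks do not fire here.
        return x + self.weight

m = MyModule()
m.register_forward_pre_hook(lambda mod, inp: print("pre-hook fired"))
m(torch.zeros(3))               # prints "pre-hook fired"
m.other_method(torch.zeros(3))  # no hook; a re-parameterized weight would be stale here
```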

One potential solution: maybe custom Python descriptors? However, as I understand it, they cannot be installed on a module instance, only on the class, which makes this a bit problematic. Although, maybe we could also dynamically create a new class?
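A rough sketch of the descriptor idea combined with the per-instance-subclass trick (plain Python, not tied to any nn API; all names are made up): the descriptor recomputes the attribute on every access, and because descriptors only work on the class, the instance gets its own dynamically created subclass.

```python
class Reparam:
    """Data descriptor: recompute an attribute from other attributes on every access."""
    def __init__(self, compute):
        self._compute = compute

    def __get__(self, instance, owner=None):
        if instance is None:
            return self
        return self._compute(instance)

    def __set__(self, instance, value):
        raise AttributeError("re-parameterized attribute is read-only")


def install_reparam(obj, name, compute):
    """Descriptors only work on the class, so give the instance its own subclass."""
    cls = obj.__class__
    sub = type(cls.__name__ + "WithReparam", (cls,), {name: Reparam(compute)})
    obj.__class__ = sub
    return obj


class Linear:
    def __init__(self):
        self.weight_g = 2.0
        self.weight_v = 3.0

lin = install_reparam(Linear(), "weight", lambda self: self.weight_g * self.weight_v)
print(lin.weight)  # 6.0, recomputed on every attribute access
lin.weight_g = 5.0
print(lin.weight)  # 15.0
```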