rwth-i6 / returnn

The RWTH extensible training framework for universal recurrent neural networks
http://returnn.readthedocs.io/

RF weight dropout and variational noise #1518

Open albertz opened 1 month ago

albertz commented 1 month ago

Currently we don't have weight dropout in the RF. We should add it.

(I thought there was an issue already about it, but I can't find it.)

Related:

In general, there is a whole class of similar features based on parameter reparameterization, such as weight norm, which would all require a similar mechanism.

Regarding implementation:

I think we could follow the PyTorch implementation of similar logic (e.g. weight norm) by using a forward-pre-hook. We already have support for hooks via rf.Module.register_forward_hook/rf.hooks.setup_post_hook_on_method.
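
As a rough illustration of that pre-hook mechanism, here is a minimal PyTorch-style sketch (the helper apply_weight_dropout and the "<name>_raw" parameter naming are made up for this example, not existing RETURNN or PyTorch API); an RF variant would presumably be built on the hook support mentioned above instead of register_forward_pre_hook:

import torch
from torch import nn


def apply_weight_dropout(module: nn.Module, name: str = "weight", p: float = 0.1) -> nn.Module:
    # Hypothetical helper, mirroring how the old torch.nn.utils.weight_norm worked:
    # keep the real parameter under "<name>_raw" and recreate "<name>" with dropout
    # applied in a forward-pre-hook, so the dropped-out weight only exists per forward call.
    raw = getattr(module, name)
    delattr(module, name)
    module.register_parameter(name + "_raw", raw)

    def _pre_hook(mod: nn.Module, args):
        w = getattr(mod, name + "_raw")
        setattr(mod, name, nn.functional.dropout(w, p=p, training=mod.training))

    _pre_hook(module, ())  # set "<name>" once so the attribute exists before the first forward
    module.register_forward_pre_hook(_pre_hook)
    return module


layer = apply_weight_dropout(nn.Linear(8, 8))
out = layer(torch.randn(2, 8))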

Regarding PyTorch:

So, I compared the modern parametrization API register_parametrization to the way it was done in the old torch.nn.utils.weight_norm via register_forward_pre_hook.

Concluding from that, I'm a bit unsure which way to go for RF... Using register_forward_pre_hook looks too error-prone (it only covers the hooked function, nothing else)... but the other approach looks too complicated? Maybe still better, though. The caching mechanism is maybe also not so important for now: for all our use cases (e.g. rf.Linear, rf.SelfAttention, rf.Conv, etc.), I think it would not matter.
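
For comparison, a minimal sketch (plain PyTorch, not RETURNN code) of how weight dropout could be expressed via register_parametrization; note that by default the parametrization is recomputed on every attribute access, and caching is only enabled inside the parametrize.cached() context:

import torch
from torch import nn
from torch.nn.utils import parametrize


class WeightDropout(nn.Module):
    # Parametrization module: maps the stored weight to a dropped-out weight.
    def __init__(self, p: float = 0.1):
        super().__init__()
        self.p = p

    def forward(self, weight: torch.Tensor) -> torch.Tensor:
        return nn.functional.dropout(weight, p=self.p, training=self.training)


linear = nn.Linear(8, 8)
parametrize.register_parametrization(linear, "weight", WeightDropout(p=0.1))

# Every access to linear.weight now goes through WeightDropout.forward.
# By default this is recomputed on each access; inside "with parametrize.cached():"
# it would be computed once and reused.
out = linear(torch.randn(2, 8))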

Further, we should also support this with gradient checkpointing such that the weights are not stored twice in memory. In our existing TF implementation of variational noise, we already use gradient checkpointing, where only the random number generator state is stored and not the dropout mask nor the weight. Thus there is almost no memory overhead. See gradient_checkpoint_scope and co. For PyTorch, it is currently unclear how to do this. I moved this over to a separate issue: #1552

albertz commented 1 month ago

Btw, regarding gradient checkpointing, see the current implementation of variational noise in our TF code as an example:

if param_variational_noise and param.dtype.is_floating and isinstance(param, tf.Variable):
    with default_control_flow_ctx():  # make independent from loop/cond
        with reuse_name_scope_of_tensor(param, postfix="_variational_noise", add_tensor_name=True):

            def _apply_var_noise():
                rnd_state = tf_util.StatelessRandomSeed.create(shape=tf_util.get_shape(param))
                with gradient_checkpoint_scope():
                    noise = rnd_state.normal(stddev=param_variational_noise, dtype=param.dtype.base_dtype)
                    return param + noise

            param = self.network.cond_on_train(
                fn_train=_apply_var_noise,
                fn_eval=lambda: param,
            )

Specifically, check the code of gradient_checkpoint_scope and prepare_gradient_checkpointing.

I know that people also do gradient checkpointing in PyTorch, but I don't know exactly how that works.

NeoLegends commented 4 weeks ago

There is a gradient checkpointing API in PT: https://pytorch.org/docs/stable/checkpoint.html

It even saves/restores the RNG state so we could do Dropout in there. I'm not sure the RNG state there can be made explicit, but it seems suitable in all the other ways.
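
For reference, a minimal sketch of that API applied to the weight-dropout case (noisy_linear is just an illustrative name): the dropped-out weight is only an intermediate inside the checkpointed function, so it is not kept for backward, and with the default preserve_rng_state=True the recomputed dropout mask matches the forward pass. The catch is that everything inside the checkpoint, here the linear itself, gets recomputed as well:

import torch
from torch.utils.checkpoint import checkpoint


def noisy_linear(x, weight, bias):
    # Weight dropout inside the checkpointed region: the dropped-out weight is
    # only an intermediate here, so it is recomputed in backward instead of stored.
    w = torch.nn.functional.dropout(weight, p=0.1, training=True)
    return torch.nn.functional.linear(x, w, bias)


x = torch.randn(4, 8, requires_grad=True)
weight = torch.nn.Parameter(torch.randn(16, 8))
bias = torch.nn.Parameter(torch.zeros(16))

# preserve_rng_state=True (the default) saves and restores the RNG state,
# so the recomputed dropout mask is identical to the one from the forward pass.
y = checkpoint(noisy_linear, x, weight, bias, use_reentrant=False)
y.sum().backward()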

I saw you asking in the PT issue about JAX: the RNG there is by definition stateless, and follows a design where the RNG key has to be threaded through the code and explicitly "split" to derive new keys.

Copying from https://jax.readthedocs.io/en/latest/jax.random.html:

seed = 1701
num_steps = 100
key = jax.random.key(seed)
for i in range(num_steps):
  key, subkey = jax.random.split(key)  # never reuse a key: split off a fresh subkey per step
  params = compiled_update(subkey, params, next(batches))

NeoLegends commented 4 weeks ago

It seems to me the API PT exposes for gradient checkpointing could be used as the RF frontend API, with an associated TF-backed implementation as well?

albertz commented 4 weeks ago

There is a gradient checkpointing API in PT: https://pytorch.org/docs/stable/checkpoint.html

Yea, that is what I referred to when we talked about it. But I need to check more closely how it is done there. Specifically, I'm still not exactly sure how I get what I want: that the dropout outputs are not stored but recomputed.

I saw you asking in the PT issue about JAX

No, I did not ask about JAX in there?

NeoLegends commented 4 weeks ago

Yea, that is what I referred to when we talked about it. But I need to check more closely how it is done there. Specifically, I'm still not exactly sure how I get what I want: that the dropout outputs are not stored but recomputed.

Yeah it would seem to me like applying only the dropout operation within the gradient checkpointed context might not be enough, but one would have to move more of the layer functionality into the checkpointed/recomputed area? Is this what you're referring to?

albertz commented 4 weeks ago

Yeah it would seem to me like applying only the dropout operation within the gradient checkpointed context might not be enough, but one would have to move more of the layer functionality into the checkpointed/recomputed area? Is this what you're referring to?

I don't know how this works. I don't want to recompute whatever comes after the dropout. I just don't want the dropout output to be stored in memory for the backprop, i.e. I want the dropout itself to be recomputed.

albertz commented 1 week ago

(Note, I made a separate issue just for the gradient checkpointing aspect in PyTorch: #1552. So this issue here can just focus on the RF-specific question of how to implement weight dropout (or also weight noise / variational noise).)