rwth-i6 / returnn_common

Common building blocks for RETURNN configs, such as models, training concepts, etc.

Register forward hook #242

Closed: albertz closed this issue 1 year ago

albertz commented 1 year ago

Similar to PyTorch's Module.register_forward_hook.

In PyTorch, this registers a function so that calling mod(...) internally calls mod.forward (as usual) and then your hook function. The hook can also optionally replace the output by returning a non-None value. (There is also register_forward_pre_hook, where the hook is called before mod.forward.)
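To make the PyTorch call order concrete, here is a toy pure-Python stand-in for the dispatch inside `torch.nn.Module.__call__`. It is not PyTorch's actual code: `ToyModule` is a hypothetical class that only mirrors the documented behavior (pre-hooks may replace the inputs, post-hooks may replace the output, and a `None` return leaves things unchanged).

```python
from typing import Any, Callable, List, Tuple


class ToyModule:
    """Hypothetical stand-in mimicking torch.nn.Module hook dispatch."""

    def __init__(self) -> None:
        self._pre_hooks: List[Callable[..., Any]] = []
        self._post_hooks: List[Callable[..., Any]] = []

    def register_forward_pre_hook(self, hook: Callable[..., Any]) -> None:
        self._pre_hooks.append(hook)  # called as hook(module, args)

    def register_forward_hook(self, hook: Callable[..., Any]) -> None:
        self._post_hooks.append(hook)  # called as hook(module, args, output)

    def forward(self, *args: Any) -> Any:
        raise NotImplementedError

    def __call__(self, *args: Any) -> Any:
        # Pre-hooks run before forward; a non-None result replaces the inputs.
        for hook in self._pre_hooks:
            new_args = hook(self, args)
            if new_args is not None:
                args = new_args if isinstance(new_args, tuple) else (new_args,)
        output = self.forward(*args)
        # Post-hooks run after forward; a non-None result replaces the output.
        for hook in self._post_hooks:
            new_output = hook(self, args, output)
            if new_output is not None:
                output = new_output
        return output


class Double(ToyModule):
    def forward(self, x: int) -> int:
        return 2 * x


mod = Double()
mod.register_forward_pre_hook(lambda m, args: (args[0] + 1,))
mod.register_forward_hook(lambda m, args, out: out + 100)
print(mod(3))  # pre-hook: 4, forward: 8, post-hook: 108
```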

In returnn-common, we don't have this special handling of the forward/__call__ function, but we can still do something similar in a more generic way.

Specifically, I propose an interface like:

def register_call_hook(func_or_module, hook):
  ...

register_call_hook(module, hook) will be an alias for register_call_hook(module.__call__, hook).

Internally, for methods, it can check __self__ and __func__.__name__ to get the module instance and the attribute name, so that it can overwrite that attribute.
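A minimal sketch of the method case, including the alias from a module to its `__call__`. The helper name `resolve_method_target` is hypothetical, not the returnn_common API; it only shows how `__self__` and `__func__.__name__` yield the (instance, attribute name) pair to overwrite.

```python
import types
from typing import Any, Tuple


def resolve_method_target(func_or_module: Any) -> Tuple[Any, str]:
    """Hypothetical helper: return (owner, attribute_name) for a bound method.

    A module instance (any object that is not itself a function or method)
    is treated as an alias for its bound ``__call__`` method.
    """
    if not isinstance(func_or_module, (types.FunctionType, types.MethodType)):
        func_or_module = func_or_module.__call__  # alias: module -> module.__call__
    if isinstance(func_or_module, types.MethodType):
        # __self__ is the instance, __func__.__name__ the attribute to overwrite.
        return func_or_module.__self__, func_or_module.__func__.__name__
    raise TypeError(f"not a bound method: {func_or_module!r}")


class Linear:
    def __call__(self, x: int) -> int:
        return 2 * x


lin = Linear()
print(resolve_method_target(lin)[1])  # __call__
```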

For global functions, we can check __module__ and __name__.
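For the global-function case, a sketch under the same caveats: the hypothetical helper `resolve_global_function_target` looks up the defining Python module via `sys.modules[func.__module__]` and checks that the function is really reachable there under `func.__name__` (nested functions and lambdas are not). Overwriting that module attribute only affects callers that look the function up through the module at call time, not code that already did `from json import dumps`.

```python
import sys
import types
from typing import Callable, Tuple


def resolve_global_function_target(func: Callable) -> Tuple[types.ModuleType, str]:
    """Hypothetical helper: return (python_module, attribute_name) for a global function."""
    py_module = sys.modules.get(func.__module__)
    name = func.__name__
    # Only valid if the module really exposes this exact function under that name.
    if py_module is None or getattr(py_module, name, None) is not func:
        raise ValueError(f"{func!r} is not reachable as {func.__module__}.{name}")
    return py_module, name


import json

print(resolve_global_function_target(json.dumps)[1])  # dumps
```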

It's OK if other cases don't work for now; we can extend this later. But module methods, especially __call__, are the most relevant case, and they cover what you would have in PyTorch.

Then, we can construct a new wrapper function and overwrite the original function with it.
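A minimal sketch of the wrapping step, assuming we already have the owner object and attribute name. `install_post_hook` and the hook signature `hook(args, kwargs, result)` are assumptions for illustration, not the actual `nn.register_call_post_hook` interface; as in PyTorch, a non-None hook return replaces the result. One Python detail matters for `__call__`: `mod(...)` looks up `__call__` on the type, not the instance, so setting it as an instance attribute only affects explicit `mod.__call__(...)` calls, and a real implementation needs to account for that. The example therefore uses a regular method.

```python
import functools
from typing import Any, Callable, Optional


def install_post_hook(owner: Any, attr: str, hook: Callable[..., Optional[Any]]) -> Callable:
    """Hypothetical helper: replace owner.<attr> by original-then-hook.

    Assumed hook signature: hook(args, kwargs, result) -> None or new result.
    Returns the original callable so the caller can restore it later.
    """
    original = getattr(owner, attr)  # for an instance: the bound method

    @functools.wraps(original)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        result = original(*args, **kwargs)
        new_result = hook(args, kwargs, result)
        # As in PyTorch: a non-None return value replaces the output.
        return result if new_result is None else new_result

    setattr(owner, attr, wrapper)  # shadows the class method on this instance only
    return original


class Model:
    def forward(self, x: int) -> int:
        return x + 1


m = Model()
install_post_hook(m, "forward", lambda args, kwargs, out: out * 10)
print(m.forward(1))  # 20
```

Since the wrapper is stored on the instance, other instances of `Model` are unaffected.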

We can also return something like PyTorch's RemovableHandle, so that the hook can be removed again.
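A sketch of such a handle, in the spirit of `torch.utils.hooks.RemovableHandle` but not modeled on its internals. The class `RemovableHookHandle` is hypothetical: it snapshots what is stored directly on the owner before patching and restores exactly that state on `remove()`. For an instance method that lives on the class, the shadowing attribute is simply deleted.

```python
from typing import Any

_MISSING = object()  # sentinel: attribute was not stored directly on the owner


class RemovableHookHandle:
    """Hypothetical handle: restores owner.<attr> to its pre-hook state."""

    def __init__(self, owner: Any, attr: str) -> None:
        self.owner = owner
        self.attr = attr
        # Works for instances with __dict__ and for Python module objects.
        self._saved = vars(owner).get(attr, _MISSING)
        self._removed = False

    def remove(self) -> None:
        if self._removed:
            return  # removing twice is a no-op
        if self._saved is _MISSING:
            delattr(self.owner, self.attr)  # un-shadow the class attribute
        else:
            setattr(self.owner, self.attr, self._saved)
        self._removed = True

    def __enter__(self) -> "RemovableHookHandle":
        return self

    def __exit__(self, *exc_info: Any) -> None:
        self.remove()


class Model:
    def forward(self, x: int) -> int:
        return x + 1


m = Model()
handle = RemovableHookHandle(m, "forward")  # create before patching
m.forward = lambda x: 0  # stand-in for the hook wrapper
handle.remove()
print(m.forward(1))  # 2
```

With this snapshot approach, removing stacked hooks out of order would also drop hooks registered later; keeping a per-function list of hooks (as PyTorch does per module) avoids that.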

albertz commented 1 year ago

This is implemented now. The function is nn.register_call_post_hook.