Similar to PyTorch `Module.register_forward_hook`. In PyTorch, this registers a function so that when you call `mod(...)`, it internally calls `mod.forward` (as usual) and afterwards your hook function. The hook function can also optionally change the output. (Note there is also `register_forward_pre_hook`, where the hook is called before `mod.forward`.)
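To make the PyTorch call order concrete, here is a simplified pure-Python model of it. `ToyModule` and `Double` are hypothetical names, and this is not PyTorch's actual implementation (which also handles kwargs, global hooks, and more); it only mirrors the documented hook signatures `hook(module, args)` and `hook(module, args, output)`, where returning `None` means "keep the value".

```python
from typing import Any, Callable, List


class ToyModule:
    """Simplified stand-in for torch.nn.Module's forward-hook handling."""

    def __init__(self) -> None:
        self._forward_pre_hooks: List[Callable[..., Any]] = []
        self._forward_hooks: List[Callable[..., Any]] = []

    def register_forward_pre_hook(self, hook: Callable[..., Any]) -> None:
        self._forward_pre_hooks.append(hook)

    def register_forward_hook(self, hook: Callable[..., Any]) -> None:
        self._forward_hooks.append(hook)

    def forward(self, *args: Any) -> Any:
        raise NotImplementedError

    def __call__(self, *args: Any) -> Any:
        # Pre-hooks run before forward() and may replace the positional inputs.
        for hook in self._forward_pre_hooks:
            new_args = hook(self, args)
            if new_args is not None:
                args = new_args if isinstance(new_args, tuple) else (new_args,)
        output = self.forward(*args)
        # Forward hooks run after forward() and may replace the output.
        for hook in self._forward_hooks:
            new_output = hook(self, args, output)
            if new_output is not None:
                output = new_output
        return output


class Double(ToyModule):
    def forward(self, x: int) -> int:
        return 2 * x


mod = Double()
mod.register_forward_pre_hook(lambda m, args: (args[0] + 1,))  # input 3 -> 4
mod.register_forward_hook(lambda m, args, out: out * 10)  # output 8 -> 80
print(mod(3))  # 80
```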
In returnn-common, we don't have this special handling of the `forward`/`__call__` function. But we can still do something similar, in a more generic way.
Specifically, I propose an interface like:
```python
def register_call_hook(func_or_module, hook):
    ...
```
`register_call_hook(module, hook)` will be an alias for `register_call_hook(module.__call__, hook)`.
Internally, for methods, it can check `__self__` and `__func__.__name__` to get the module instance and the attribute name, so that it can overwrite that attribute.
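A minimal sketch of the method case, including the alias from `module` to `module.__call__`. `resolve_method_target` and the demo class `Linear` are hypothetical names. One detail worth handling: bound method objects are created anew on each attribute access, so the sanity check compares the underlying `__func__` rather than the bound methods themselves.

```python
import types
from typing import Any, Tuple


def resolve_method_target(func_or_module: Any) -> Tuple[Any, str]:
    """Return (instance, attribute name) so the attribute can be overwritten.

    Hypothetical helper; only bound methods (and callable objects) are handled.
    """
    if not isinstance(func_or_module, (types.MethodType, types.FunctionType)):
        # register_call_hook(module, hook) == register_call_hook(module.__call__, hook)
        func_or_module = func_or_module.__call__
    if not isinstance(func_or_module, types.MethodType):
        raise TypeError(f"expected a bound method, got {func_or_module!r}")
    instance = func_or_module.__self__  # the module instance
    attr = func_or_module.__func__.__name__  # e.g. "__call__"
    # Bound methods are recreated on every access, so compare the underlying
    # functions to make sure `attr` really refers to this method.
    current = getattr(instance, attr, None)
    if getattr(current, "__func__", None) is not func_or_module.__func__:
        raise ValueError(f"{attr!r} on {instance!r} is not this method")
    return instance, attr


class Linear:
    def __call__(self, x: int) -> int:
        return x + 1


lin = Linear()
print(resolve_method_target(lin))  # the Linear instance and '__call__'
print(resolve_method_target(lin.__call__)[1])  # __call__
```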
For global functions, we can check `__module__` and `__name__` in the same way, to find the Python module and the attribute name.
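The global-function case can be sketched the same way, with a hypothetical helper `resolve_function_target`. The check `getattr(py_module, name) is func` rejects lambdas, nested functions, and functions stored under a different name, since none of those are reachable as `py_module.<__name__>`. The demo builds a throwaway module (`demo_mod`, also hypothetical) so that `__module__` is a real `sys.modules` entry.

```python
import sys
import types
from typing import Tuple


def resolve_function_target(func: types.FunctionType) -> Tuple[types.ModuleType, str]:
    """Return (Python module, attribute name) for a module-level function.

    Hypothetical helper for the global-function case.
    """
    py_module = sys.modules[func.__module__]
    name = func.__name__
    # Lambdas, nested functions, or functions stored under another name are
    # not reachable as py_module.<__name__>, so reject them for now.
    if getattr(py_module, name, None) is not func:
        raise ValueError(f"{func.__module__}.{name} does not refer to {func!r}")
    return py_module, name


# Throwaway module: functions defined in its namespace get __module__ == "demo_mod".
demo = types.ModuleType("demo_mod")
sys.modules["demo_mod"] = demo
exec("def scale(x):\n    return 3 * x\n", demo.__dict__)

py_module, name = resolve_function_target(demo.scale)
print(py_module.__name__, name)  # demo_mod scale
```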
It's ok if other cases don't work for now; we can extend this later. But module methods, esp. `__call__`, are the most relevant, and they cover what you would have in PyTorch.
Then, we can construct a new wrapper function and overwrite the original function with it.
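A sketch of the wrapper-and-overwrite step, assuming the owner object (the instance for methods, the Python module for global functions) and the attribute name were already found as described. `overwrite_with_hook` and the hook signature `hook(args, kwargs, result)` are hypothetical choices. Two Python details matter here. First, an implicit call `mod(...)` looks up `__call__` on the type, not on the instance (Python's special method lookup), so overwriting the instance attribute `__call__` only affects explicit `mod.__call__(...)` calls unless the module class routes calls through the instance attribute; the demo shows this. Second, for global functions, code that already did `from some_module import func` keeps a reference to the unwrapped function.

```python
import functools
from typing import Any, Callable, Dict, Optional, Tuple

# Hypothetical hook signature: hook(args, kwargs, result) -> new result or None.
Hook = Callable[[Tuple[Any, ...], Dict[str, Any], Any], Optional[Any]]


def overwrite_with_hook(owner: Any, attr: str, hook: Hook) -> Callable[..., Any]:
    """Replace `owner.<attr>` by a wrapper that calls the original, then `hook`.

    Returns the original callable so the caller could restore it later.
    """
    original = getattr(owner, attr)  # bound method or plain function

    @functools.wraps(original)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        result = original(*args, **kwargs)
        new_result = hook(args, kwargs, result)
        # Like a PyTorch forward hook: returning None keeps the original output.
        return result if new_result is None else new_result

    setattr(owner, attr, wrapper)
    return original


class Mod:
    def __call__(self, x: int) -> int:
        return x + 1


m = Mod()
overwrite_with_hook(m, "__call__", lambda args, kwargs, out: out * 100)
print(m.__call__(1))  # 200: explicit attribute access finds the instance-level wrapper
print(m(1))  # 2: implicit m(...) looks up type(m).__call__ and skips the wrapper
```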
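We can also return something like `RemovableHandle` (which is what PyTorch's `register_forward_hook` returns), so that the hook can be removed again.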
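One way a `RemovableHandle`-like return value could work, modeled loosely on PyTorch's `torch.utils.hooks.RemovableHandle`: install a single wrapper per attribute, keep the registered hooks in a dict stored on that wrapper, and let `remove()` delete one entry. `add_call_hook`, the `_call_hooks` marker attribute, and this `RemovableHandle` class are hypothetical; the demo uses a regular method `forward` so it does not depend on how `mod(...)` looks up `__call__`.

```python
import functools
import itertools
from typing import Any, Callable, Dict, Optional, Tuple

# Hypothetical hook signature: hook(args, kwargs, result) -> new result or None.
Hook = Callable[[Tuple[Any, ...], Dict[str, Any], Any], Optional[Any]]

_HOOKS_ATTR = "_call_hooks"  # hypothetical marker attribute on installed wrappers
_ids = itertools.count()


class RemovableHandle:
    """Minimal analog of PyTorch's RemovableHandle: remove() unregisters one hook."""

    def __init__(self, hooks: Dict[int, Hook]) -> None:
        self._hooks = hooks
        self.id = next(_ids)

    def remove(self) -> None:
        self._hooks.pop(self.id, None)


def _install_wrapper(owner: Any, attr: str, original: Callable[..., Any]) -> Dict[int, Hook]:
    hooks: Dict[int, Hook] = {}

    @functools.wraps(original)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        result = original(*args, **kwargs)
        for hook in list(hooks.values()):  # registration order
            new_result = hook(args, kwargs, result)
            if new_result is not None:
                result = new_result
        return result

    setattr(wrapper, _HOOKS_ATTR, hooks)
    setattr(owner, attr, wrapper)
    return hooks


def add_call_hook(owner: Any, attr: str, hook: Hook) -> RemovableHandle:
    """Register `hook` on owner.<attr>, reusing an already installed wrapper."""
    current = getattr(owner, attr)
    hooks = getattr(current, _HOOKS_ATTR, None)
    if hooks is None:
        hooks = _install_wrapper(owner, attr, current)
    handle = RemovableHandle(hooks)
    hooks[handle.id] = hook
    return handle


class Mod:
    def forward(self, x: int) -> int:
        return x + 1


m = Mod()
h1 = add_call_hook(m, "forward", lambda args, kwargs, out: out * 10)
h2 = add_call_hook(m, "forward", lambda args, kwargs, out: out + 5)
print(m.forward(1))  # 25: (1 + 1) * 10 + 5
h1.remove()
print(m.forward(1))  # 7: (1 + 1) + 5
h2.remove()
print(m.forward(1))  # 2: the wrapper stays installed but has no hooks left
```

Keeping the hooks in a dict, rather than restoring the original function on removal, lets several hooks be stacked and removed in any order; the trade-off is that the (then pass-through) wrapper itself is never uninstalled.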