Closed gpengzhi closed 5 years ago
Looks good. I'm thinking, would this problem be solved if we have a mechanism that can:

- Add methods that are called prior to `__init__` (of the uppermost subclass).
- Add methods that are called after `__init__` (of the uppermost subclass).

I feel that such a mechanism could solve both this problem and the repeated-hparams initialization problem we've discussed before.

A hacky way of doing this is via metaclasses. When an instance of a class is created, the metaclass `__call__` method is invoked, which internally calls the class `__new__` and class `__init__` methods. We can create a custom metaclass for `ModuleBase` that constructs the `hparams` object before `__init__`, and allows sub-metaclasses to append method calls after `__init__`. A diagram of instance creation is as follows, taken from this very good article that explains it all:
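Below is a minimal, hypothetical sketch of such a metaclass. The hook names `_pre_init_hook` / `_post_init_hook` are purely illustrative (not an existing Texar API), and this is not how `ModuleBase` is currently implemented:

```python
import torch.nn as nn


class ModuleMeta(type):
    """Sketch of a metaclass that wraps instance creation with hooks."""

    def __call__(cls, *args, **kwargs):
        # The metaclass ``__call__`` runs once per instantiation, around the
        # uppermost subclass's ``__init__``, so hooks here fire exactly once.
        instance = cls.__new__(cls)
        # Pre-init hook: e.g., build the ``hparams`` object a single time and
        # attach it so every ``__init__`` in the chain can reuse it.
        pre_hook = getattr(cls, "_pre_init_hook", None)  # illustrative name
        if pre_hook is not None:
            pre_hook(instance, *args, **kwargs)
        instance.__init__(*args, **kwargs)
        # Post-init hook: e.g., call ``init_pretrained_weights()`` once all
        # parameters, including those of subclasses, are registered.
        post_hook = getattr(cls, "_post_init_hook", None)  # illustrative name
        if post_hook is not None:
            post_hook(instance)
        return instance


class ModuleBase(nn.Module, metaclass=ModuleMeta):
    """Base class whose subclasses automatically get the hooks above."""
```

A sub-metaclass (or the class itself) could then define `_post_init_hook` to call `init_pretrained_weights()` only after the most-derived `__init__` has finished.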
Thanks a lot. I think this modification is related to #41. I will think about it.
Why is this PR related to "repeated-hparams initialization"? In particular, what is the `init` argument in `XLNetEncoder.__init__` for in the first place?
It's because these two problems have similar root causes.
**The `init` argument.** Our pre-trained mixin interfaces require the concrete classes to call `init_pretrained_weights()` at the end of their `__init__` methods. This method initializes the weights of the registered module parameters from the checkpoint file.

`XLNetEncoder`, which inherits `PretrainedXLNetMixin`, follows this pattern. However, `XLNetDecoder` inherits `XLNetEncoder` (which is not a good design, but it works at least for now) and has an extra parameter `lm_bias` that has to be registered before calling `init_pretrained_weights()`, which happens within the `super().__init__` call. Due to the PyTorch implementation, we can't register parameters before `super().__init__` (which internally calls `nn.Module.__init__`), so we added this `init` argument to control whether `init_pretrained_weights()` is called in `XLNetEncoder.__init__`.
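For reference, a simplified sketch of the pattern described above. Class and method names follow the discussion, but the signatures and placeholder parameters are illustrative rather than the actual texar-pytorch code:

```python
import torch
from torch import nn


class XLNetEncoder(nn.Module):
    """Simplified stand-in; the real class also mixes in PretrainedXLNetMixin."""

    def __init__(self, hparams=None, init=True):
        super().__init__()
        self.linear = nn.Linear(4, 4)  # placeholder for the real encoder parameters
        if init:
            # Only safe once every parameter of the final class is registered.
            self.init_pretrained_weights()

    def init_pretrained_weights(self):
        # Would copy weights from the pre-trained checkpoint; omitted here.
        pass


class XLNetDecoder(XLNetEncoder):
    def __init__(self, hparams=None):
        # ``lm_bias`` cannot be registered before ``super().__init__`` because
        # ``nn.Module.__init__`` has not run yet, so weight loading is deferred
        # by passing ``init=False`` and performed explicitly afterwards.
        super().__init__(hparams=hparams, init=False)
        self.lm_bias = nn.Parameter(torch.zeros(4))  # placeholder size
        self.init_pretrained_weights()
```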
**Relation to repeated-`hparams` creation.** The `hparams` issue was this: `hparams` is created in `ModuleBase.__init__`, which is called in the `super().__init__` of our built-in modules. However, certain base classes require additional arguments in their constructors (e.g., `DataBase` requires a `data_source` argument), and some of those arguments cannot be constructed without knowing `hparams` values. As a result, the `hparams` object is constructed multiple times.
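A stripped-down illustration of that double construction. The `HParams` stand-in, the `MonoTextData` subclass, and the data-source logic are hypothetical simplifications of the real classes:

```python
class HParams(dict):
    """Minimal stand-in for Texar's HParams: user values merged over defaults."""

    def __init__(self, hparams, default_hparams):
        super().__init__({**default_hparams, **(hparams or {})})


class DataBase:
    def __init__(self, data_source, hparams=None):
        # The "real" hparams construction, analogous to ModuleBase.__init__.
        self._hparams = HParams(hparams, self.default_hparams())
        self._source = data_source

    @staticmethod
    def default_hparams():
        return {"files": None, "batch_size": 64}


class MonoTextData(DataBase):
    def __init__(self, hparams=None):
        # The data source depends on hparams values, but super().__init__
        # (which builds self._hparams) has not run yet -- so hparams must be
        # merged a second time just to read ``files``.
        tmp_hparams = HParams(hparams, self.default_hparams())
        data_source = list(tmp_hparams["files"] or [])  # hypothetical source
        super().__init__(data_source, hparams=hparams)
```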
The `init` issue above could benefit from being able to add method calls at the end of object initialization (a call to `init_pretrained_weights()`). The `hparams` issue could similarly benefit from being able to add calls before initialization (a call to construct `hparams`).
Thanks for explaining, and for the possible solution.

The solution can be a candidate we apply at some point. Since it's a bit intricate, for now (while the issue has not yet become widespread), let's just work around it case by case.
One of our interface design "principles" could, to some extent (not completely), avoid the issue: put hyperparameters in `hparams` instead of constructor arguments whenever possible. In this way, it's in general unlikely (though not impossible) that a hyperparameter is in the `hparams` of a class and at the same time an argument of `super().__init__`.
It looks like the `CharCNN` example in issue https://github.com/asyml/texar-pytorch/issues/41 is out of the scope of the above principle, though. It's reasonable that `in_channels` is a `Conv1DEncoder.__init__` argument while `char_embed_dim` is in `CharCNN`'s `hparams`.
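For concreteness, a minimal sketch of why that example falls outside the principle (simplified, hypothetical versions of the real classes): `in_channels` is needed as an argument to `super().__init__`, but its value comes from `char_embed_dim`, which only exists once the merged `hparams` do.

```python
class Conv1DEncoder:
    """Simplified stand-in for texar's Conv1DEncoder."""

    def __init__(self, in_channels, hparams=None):
        # In the real code, ModuleBase.__init__ builds self._hparams here.
        self.in_channels = in_channels
        self._hparams = {**self.default_hparams(), **(hparams or {})}

    @staticmethod
    def default_hparams():
        return {}


class CharCNN(Conv1DEncoder):
    @staticmethod
    def default_hparams():
        return {"char_embed_dim": 30}

    def __init__(self, hparams=None):
        # ``in_channels`` must equal ``char_embed_dim``, which lives in
        # CharCNN's hparams -- so the merged hparams have to be computed
        # before super().__init__, even though the base class builds them again.
        merged = {**self.default_hparams(), **(hparams or {})}
        super().__init__(in_channels=merged["char_embed_dim"], hparams=hparams)
```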
Yes, we could leave this for the future.
While the issue in the `CharCNN` example could be avoided with this principle, the principle only applies to values that fit in `hparams` -- and a data source is probably not one of them. There could also be logic that has to run before calling the superclass constructor.
Requested by Zhiting's comments in the PR in texar-tf:

- Make the interface of `XLNetEncoder` consistent with other modules.