asyml / texar-pytorch

Integrating the Best of TF into PyTorch, for Machine Learning, Natural Language Processing, and Text Generation. This is part of the CASL project: http://casl-project.ai/
https://asyml.io
Apache License 2.0

Refactor XLNet interface #208

Closed gpengzhi closed 5 years ago

gpengzhi commented 5 years ago

Requested in Zhiting's comments on the corresponding PR in texar-tf.

Make the interface of XLNetEncoder consistent with other modules.

gpengzhi commented 5 years ago

Looks good. I'm thinking, would this problem be solved if we had a mechanism that could:

  • Add methods that are called prior to __init__ (of the uppermost subclass).
  • Add methods that are called after __init__ (of the uppermost subclass).

I feel that such a mechanism could solve both this problem and the repeated-hparams initialization problem we've discussed before.

A hacky way of doing this is via metaclasses. When an instance of a class is created, the metaclass's __call__ method is invoked, which internally calls the class's __new__ and __init__ methods. We could create a custom metaclass for ModuleBase that constructs the hparams object before __init__ and allows sub-metaclasses to append method calls after __init__; see the sketch below.

A diagram of the instance-creation process, taken from a very good article that explains it all, was attached here. [diagram omitted]
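A rough sketch of the idea, assuming a plain-dict hparams and a hypothetical _post_init_hooks attribute (illustrative only, not texar-pytorch's actual implementation):

```python
# Minimal sketch of the metaclass idea. The dict-based hparams and the
# `_post_init_hooks` attribute are illustrative assumptions.

class ModuleMeta(type):
    def __call__(cls, *args, hparams=None, **kwargs):
        # Runs before __init__ of the uppermost subclass:
        # resolve hparams once by merging user values into the defaults.
        hparams = {**cls.default_hparams(), **(hparams or {})}
        obj = cls.__new__(cls)
        obj.__init__(*args, hparams=hparams, **kwargs)
        # Runs after __init__ of the uppermost subclass: e.g. a pre-trained
        # mixin could register a call to init_pretrained_weights() here.
        for hook in getattr(cls, "_post_init_hooks", ()):
            hook(obj)
        return obj


class ModuleBase(metaclass=ModuleMeta):
    _post_init_hooks = ()

    def __init__(self, hparams=None):
        self._hparams = hparams

    @staticmethod
    def default_hparams():
        return {}
```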

Thanks a lot. I think this modification is related to #41. I will think about it.

ZhitingHu commented 5 years ago

Why is this PR related to "repeated-hparams initialization"? In particular, what's the init argument in XLNetEncoder.__init__ for in the first place?

huzecong commented 5 years ago

It's because these two problems have similar root causes.

The init argument. Our pre-trained mixin interface requires the concrete classes to call init_pretrained_weights() at the end of their __init__ methods. This method initializes the weights of the registered module parameters from the pre-trained checkpoint file.

XLNetEncoder, which inherits PretrainedXLNetMixin, follows this pattern. However, XLNetDecoder inherits XLNetEncoder (which is not a good design, but it works for now) and has an extra parameter, lm_bias, that must be registered before init_pretrained_weights() is called, and that call happens inside the super().__init__ call. Because of how PyTorch is implemented, we can't register parameters before super().__init__ (which internally calls nn.Module.__init__), so we added the init argument to control whether XLNetEncoder.__init__ calls init_pretrained_weights().
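A hedged sketch of the pattern described above (the layer contents, hyperparameter names, and signatures are simplified placeholders, not the actual texar-pytorch code):

```python
import torch
from torch import nn


class XLNetEncoder(nn.Module):
    def __init__(self, hparams, init=True):
        # nn.Module.__init__ must run before any parameter can be
        # registered on `self`.
        super().__init__()
        self.proj = nn.Linear(hparams["dim"], hparams["dim"])  # placeholder layer
        if init:
            self.init_pretrained_weights()

    def init_pretrained_weights(self):
        # Provided by PretrainedXLNetMixin in the real code: copies weights
        # from the downloaded checkpoint into the registered parameters.
        pass


class XLNetDecoder(XLNetEncoder):
    def __init__(self, hparams):
        # init=False: lm_bias is not registered yet, so loading the checkpoint
        # inside XLNetEncoder.__init__ would miss it.
        super().__init__(hparams, init=False)
        self.lm_bias = nn.Parameter(torch.zeros(hparams["vocab_size"]))
        self.init_pretrained_weights()  # now all parameters are registered
```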

Relation to repeated-hparams creation. The hparams issue was that hparams is created in ModuleBase.__init__, which is called through super().__init__ in our built-in modules. However, certain base classes require additional arguments in their constructors (e.g., DataBase requires a data_source argument), and some of those arguments cannot be constructed without knowing hparams values. As a result, the hparams object ends up being constructed multiple times.
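As a toy illustration of how the duplication arises (the class internals, and MonoTextData as shown here, are simplified stand-ins, not the real code):

```python
class ModuleBase:
    def __init__(self, hparams=None):
        # hparams is (re)constructed here, in the base constructor.
        self._hparams = {**self.default_hparams(), **(hparams or {})}

    @staticmethod
    def default_hparams():
        return {}


class DataBase(ModuleBase):
    def __init__(self, source, hparams=None):
        self._source = source
        super().__init__(hparams)


class MonoTextData(DataBase):
    @staticmethod
    def default_hparams():
        return {"files": "data.txt"}

    def __init__(self, hparams=None):
        # The data source must exist before super().__init__, but building it
        # needs an hparams value, so hparams is merged once here ...
        full = {**self.default_hparams(), **(hparams or {})}
        source = full["files"].split(",")  # stand-in for building a real data source
        # ... and then merged again inside ModuleBase.__init__.
        super().__init__(source, hparams)
```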

The init issue above could benefit from being able to add method calls (a call to init_pretrained_weights()) at the end of object initialization. The hparams issue could similarly benefit from being able to add calls before initialization (a call to construct hparams).

ZhitingHu commented 5 years ago

Thanks for explaining, and the possible solution.

The solution is a candidate we could apply at some point. Since it's a bit intricate, and the issue has not yet occurred widely, let's just work around it case by case for now.

One of our interface design "principles" could to some extent (though not completely) avoid the issue: put hyperparameters in hparams instead of constructor arguments whenever possible. This way, it's generally unlikely (though not impossible) that a hyperparameter lives in a class's hparams while also being an argument of super().__init__.

It looks like the CharCNN example in issue https://github.com/asyml/texar-pytorch/issues/41 is out of the scope of the above principle, though. It's reasonable that in_channels is a Conv1DEncoder.__init__ argument while char_embed_dim is in CharCNN's hparams.
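For reference, a toy version of that CharCNN situation (simplified names with a minimal ModuleBase stand-in, not the code from the issue):

```python
class ModuleBase:
    def __init__(self, hparams=None):
        # hparams is constructed (again) in the base constructor.
        self._hparams = {**self.default_hparams(), **(hparams or {})}

    @staticmethod
    def default_hparams():
        return {}


class Conv1DEncoder(ModuleBase):
    def __init__(self, in_channels, hparams=None):
        self._in_channels = in_channels
        super().__init__(hparams)


class CharCNN(Conv1DEncoder):
    @staticmethod
    def default_hparams():
        return {"char_embed_dim": 100}

    def __init__(self, hparams=None):
        # in_channels must equal char_embed_dim, which lives in hparams, so
        # hparams has to be resolved before super().__init__ can be called,
        # and is then resolved again inside ModuleBase.__init__.
        full = {**self.default_hparams(), **(hparams or {})}
        super().__init__(in_channels=full["char_embed_dim"], hparams=hparams)
```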

huzecong commented 5 years ago

Yes, we could leave this for the future.

While the issue in the CharCNN example could be avoided with this principle, it only applies to values that fit in hparams, and a data source is probably not one of them. There could also be logic that has to run before calling the superclass constructor.