bojone / bert4keras

keras implement of transformers for humans
https://kexue.fm/archives/6915
Apache License 2.0

load weights from pretrained weights but change segment embedding layer #207

Open 9652018 opened 4 years ago

9652018 commented 4 years ago

I have a task that needs more segment ids, which I have configured via bert_config.

However, the pretrained model's segment embedding layer then has a mismatched shape and cannot be loaded.

Is it possible to give the segment embedding a new shape while still loading the pretrained model for all the other layers? In other words, I want to create a new segment embedding layer and reuse the pretrained weights everywhere else.

Thanks

bojone commented 4 years ago

This is easy: just override the load_variable method yourself:

from bert4keras.models import BERT, build_transformer_model

class MyBERT(BERT):
    def load_variable(self, checkpoint, name):
        variable = super(MyBERT, self).load_variable(checkpoint, name)
        if name == 'bert/embeddings/token_type_embeddings':
            variable = xxxxx  # build a variable with the shape you need here
        return variable

bert = build_transformer_model(config_path, checkpoint_path, model=MyBERT)
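One way to fill in that placeholder is to keep the pretrained segment rows and append freshly initialized rows for the new segment ids. Below is a minimal NumPy sketch of that idea; the `resize_segment_embedding` helper and the shapes are illustrative, not part of bert4keras:

```python
import numpy as np

def resize_segment_embedding(variable, new_num_segments, init_std=0.02):
    """Grow a (num_segments, hidden_size) embedding matrix: keep the
    pretrained rows and append randomly initialized rows for new ids."""
    old_num, hidden_size = variable.shape
    extra = new_num_segments - old_num
    if extra <= 0:
        # Shrinking (or no change): just truncate to the requested size.
        return variable[:new_num_segments]
    new_rows = np.random.normal(0.0, init_std, (extra, hidden_size))
    return np.concatenate([variable, new_rows], axis=0)

# Example: pretrained BERT ships 2 segment rows; suppose we want 4.
pretrained = np.random.normal(0.0, 0.02, (2, 768))
resized = resize_segment_embedding(pretrained, 4)
print(resized.shape)                            # (4, 768)
print(np.array_equal(resized[:2], pretrained))  # True: old rows untouched
```

Inside `load_variable`, `variable = resize_segment_embedding(variable, n)` would then replace the `xxxxx` placeholder, with `n` taken from your modified bert_config.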

9652018 commented 4 years ago

It worked. Thanks