Open 9652018 opened 4 years ago
这个简单,自己定义一下load_variable
方法即可:
from bert4keras.models import BERT, build_transformer_model
class MyBERT(BERT):
def load_variable(self, checkpoint, name):
variable = super(MyBERT, self).load_variable(checkpoint, name)
if name == 'bert/embeddings/token_type_embeddings':
variable = xxxxx # 自己想办法得到一个适合shape的variable
return variable
bert = build_transformer_model(config_path, checkpoint_path, model=MyBERT)
成功了 Thanks
我有個任務想要有更多的segment id, 已透過bert_config修改。
但此時會遇到 預訓練模型的segment embedding layer的shape不符,無法順利讀入。
可以讓segment embedding 變成新的shape,但在其他layer仍然load pretrained model嗎? 換句話說就是我創建新的segment embedding layer,其他部分沿用pretrained weights。
Thanks