ZhuiyiTechnology / WoBERT

Chinese BERT with words as the basic unit

When converting to a torch model, after first exporting a ckpt checkpoint, do I need to export vocab.txt myself and modify bert_config.json? #12

Closed: baiziyuandyufei closed this issue 3 years ago

baiziyuandyufei commented 3 years ago

Modify train.py as follows.

1. Build the model

Set:

from bert4keras.models import build_transformer_model

bert = build_transformer_model(
    config_path,
    checkpoint_path,
    with_mlm='linear',
    keep_tokens=keep_tokens,  # keep only the tokens in keep_tokens, trimming the original vocab
    compound_tokens=compound_tokens,  # added words, each initialized by averaging its characters' embeddings
    return_keras_model=False,
)

model = bert.model
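
Here keep_tokens lists the ids of the original-vocab entries to retain, and each entry of compound_tokens lists the character ids of one added word, whose embeddings are averaged to initialize that word's embedding. A hypothetical illustration of the two arguments' shapes (all ids below are made up; step 3 shows how both lists are actually built):

keep_tokens = [0, 100, 101, 102, 103, 670, 671]  # original BERT vocab ids to keep, in new-vocab order

compound_tokens = [
    [670, 671],       # 2-character word: embedding = mean of the two characters' embeddings
    [672, 673, 674],  # 3-character word: embedding = mean of the three characters' embeddings
]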

2. Save the model

Add:

def on_epoch_end(self, epoch, logs=None):
    model.save_weights('bert_model.weights')  # save the Keras weights
    bert.save_weights_as_checkpoint(filename='ckpt_model/bert_model.ckpt')  # export a TF checkpoint
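
For context, on_epoch_end belongs to a Keras callback (train.py defines an Evaluator callback for this purpose); a minimal sketch, assuming the bert and model objects from step 1 are in scope:

from bert4keras.backend import keras

class Evaluator(keras.callbacks.Callback):
    # Saves both Keras weights and a TF checkpoint after every epoch
    def on_epoch_end(self, epoch, logs=None):
        model.save_weights('bert_model.weights')  # Keras weights, for resuming training
        bert.save_weights_as_checkpoint(filename='ckpt_model/bert_model.ckpt')  # TF checkpoint, for the torch conversion

The callback is then passed to training via model.fit(..., callbacks=[Evaluator()]).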

3. Save the vocabulary

Add:

import os
import json
import jieba
from bert4keras.tokenizers import Tokenizer, load_vocab, save_vocab

jieba.initialize()  # make sure jieba's dictionary (and its FREQ table) is loaded

# Load the top num_words words from jieba's frequency table and drop some
# unused tokens from the original BERT vocab
if os.path.exists('tokenizer_config.json'):
    token_dict, keep_tokens, compound_tokens = json.load(
        open('tokenizer_config.json')
    )
    save_vocab("ckpt_model/vocab.txt", token_dict)
else:
    # Load and simplify the original vocab
    token_dict, keep_tokens = load_vocab(
        dict_path=dict_path,
        simplified=True,
        startswith=['[PAD]', '[UNK]', '[CLS]', '[SEP]', '[MASK]'],
    )
    pure_tokenizer = Tokenizer(token_dict.copy(), do_lower_case=True)
    user_dict = []
    # Take words in descending jieba frequency until num_words are collected
    for w, _ in sorted(jieba.dt.FREQ.items(), key=lambda s: -s[1]):
        if w not in token_dict:
            token_dict[w] = len(token_dict)
            user_dict.append(w)
        if len(user_dict) == num_words:
            break
    # Each added word is represented by its character ids, without [CLS]/[SEP]
    compound_tokens = [pure_tokenizer.encode(w)[0][1:-1] for w in user_dict]
    json.dump([token_dict, keep_tokens, compound_tokens],
              open('tokenizer_config.json', 'w'))
    save_vocab("ckpt_model/vocab.txt", token_dict)

4. Modify bert_config.json

Count the vocabulary size:

# wc -l vocab.txt
33585 vocab.txt

Change "vocab_size" in bert_config.json to this count and add "model_type":

{
  "attention_probs_dropout_prob": 0.1, 
  "directionality": "bidi", 
  "hidden_act": "gelu", 
  "hidden_dropout_prob": 0.1, 
  "hidden_size": 768, 
  "initializer_range": 0.02, 
  "intermediate_size": 3072, 
  "max_position_embeddings": 512, 
  "num_attention_heads": 12, 
  "num_hidden_layers": 12, 
  "pooler_fc_size": 768, 
  "pooler_num_attention_heads": 12, 
  "pooler_num_fc_layers": 3, 
  "pooler_size_per_head": 128, 
  "pooler_type": "first_token_transform", 
  "type_vocab_size": 2, 
  "vocab_size": 33585,
  "model_type":"bert"
}
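
Before converting, a quick sanity check that the edited config matches the exported vocabulary (a minimal sketch using the paths from the steps above):

import json

config = json.load(open('bert_config.json'))
with open('ckpt_model/vocab.txt', encoding='utf-8') as f:
    n_tokens = sum(1 for _ in f)  # save_vocab writes one token per line

# A mismatch here would give the converted model a wrongly sized embedding matrix
assert config['vocab_size'] == n_tokens, (config['vocab_size'], n_tokens)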

5. Convert with the conversion script from WoBERT_pytorch
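
For reference, the same TF-to-PyTorch conversion can be sketched with transformers' built-in loader (load_tf_weights_in_bert is the utility behind the standard convert_bert_original_tf_checkpoint_to_pytorch.py script; the output filename here is just a convention, and TensorFlow must be installed to read the checkpoint):

import torch
from transformers import BertConfig, BertForPreTraining, load_tf_weights_in_bert

config = BertConfig.from_json_file('bert_config.json')
model = BertForPreTraining(config)
load_tf_weights_in_bert(model, config, 'ckpt_model/bert_model.ckpt')  # map TF variables onto the PyTorch model
torch.save(model.state_dict(), 'pytorch_model.bin')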