Modify train.py

1. Build the model

Set:
```python
bert = build_transformer_model(
    config_path,
    checkpoint_path,
    with_mlm='linear',
    keep_tokens=keep_tokens,  # keep only the characters in keep_tokens, trimming the original vocab
    compound_tokens=compound_tokens,  # add words, initialized with the average of their characters' embeddings
    return_keras_model=False,
)
model = bert.model
```
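Per the comment above, each new word's embedding row is initialized as the mean of the pretrained embeddings of its constituent characters. A minimal sketch of that idea (the shapes and ids below are hypothetical, not taken from the actual checkpoint):

```python
import numpy as np

# Hypothetical pretrained char-level embedding table (BERT-base Chinese shape assumed)
char_embeddings = np.random.randn(21128, 768)

# Each entry lists the char ids that make up one new word (ids here are made up)
compound_tokens = [[712, 1166], [2110, 743, 3333]]

# New word rows = mean of the constituent characters' embedding rows
new_rows = np.stack([char_embeddings[ids].mean(axis=0) for ids in compound_tokens])
word_embeddings = np.concatenate([char_embeddings, new_rows], axis=0)
```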
2. Save the model

Add:
```python
def on_epoch_end(self, epoch, logs=None):
    model.save_weights('bert_model.weights')  # save the model (Keras-format weights)
    bert.save_weights_as_checkpoint(filename='ckpt_model/bert_model.ckpt')  # export a TF checkpoint
```
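This hook belongs on a `keras.callbacks.Callback` subclass, as in the bert4keras pretraining examples. A minimal sketch of the surrounding class (the name `Evaluator` is an assumption):

```python
from bert4keras.backend import keras

class Evaluator(keras.callbacks.Callback):
    """Saves both weight formats at the end of every epoch."""
    def on_epoch_end(self, epoch, logs=None):
        model.save_weights('bert_model.weights')  # Keras-format weights
        bert.save_weights_as_checkpoint(filename='ckpt_model/bert_model.ckpt')  # TF checkpoint

# registered when training starts, e.g.:
# model.fit(dataset, epochs=epochs, callbacks=[Evaluator()])
```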
3. Save the vocabulary

Add:

```python
import os
import json

import jieba
from bert4keras.tokenizers import Tokenizer, load_vocab, save_vocab

# Load the top num_words words from the jieba vocabulary and drop some entries from
# the BERT vocab (dict_path and num_words are defined elsewhere in train.py)
if os.path.exists('tokenizer_config.json'):
    token_dict, keep_tokens, compound_tokens = json.load(
        open('tokenizer_config.json')
    )
    save_vocab("ckpt_model/vocab.txt", token_dict)
else:
    # Load and trim the vocabulary
    token_dict, keep_tokens = load_vocab(
        dict_path=dict_path,
        simplified=True,
        startswith=['[PAD]', '[UNK]', '[CLS]', '[SEP]', '[MASK]'],
    )
    pure_tokenizer = Tokenizer(token_dict.copy(), do_lower_case=True)
    user_dict = []
    for w, _ in sorted(jieba.dt.FREQ.items(), key=lambda s: -s[1]):
        if w not in token_dict:
            token_dict[w] = len(token_dict)
            user_dict.append(w)
        if len(user_dict) == num_words:
            break
    compound_tokens = [pure_tokenizer.encode(w)[0][1:-1] for w in user_dict]
    json.dump([token_dict, keep_tokens, compound_tokens],
              open('tokenizer_config.json', 'w'))
    save_vocab("ckpt_model/vocab.txt", token_dict)
```
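The `[1:-1]` slice strips the `[CLS]` and `[SEP]` ids that `Tokenizer.encode` wraps around a sequence, so each `compound_tokens` entry is just the word's character ids. Illustratively (the ids below are made up):

```python
token_ids, segment_ids = pure_tokenizer.encode('模型')
# token_ids might look like [2, 1234, 5678, 3]  -> [CLS], 模, 型, [SEP] (hypothetical ids)
char_ids = token_ids[1:-1]  # -> [1234, 5678], the per-character ids kept in compound_tokens
```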
4. Modify bert_config.json

Count the vocabulary size:
```shell
# wc -l vocab.txt
33585 vocab.txt
```
Change "vocab_size" in bert_config.json and add "model_type":
{ "attention_probs_dropout_prob": 0.1, "directionality": "bidi", "hidden_act": "gelu", "hidden_dropout_prob": 0.1, "hidden_size": 768, "initializer_range": 0.02, "intermediate_size": 3072, "max_position_embeddings": 512, "num_attention_heads": 12, "num_hidden_layers": 12, "pooler_fc_size": 768, "pooler_num_attention_heads": 12, "pooler_num_fc_layers": 3, "pooler_size_per_head": 128, "pooler_type": "first_token_transform", "type_vocab_size": 2, "vocab_size": 33585, "model_type":"bert" }
5. Convert with the conversion script from WoBERT_pytorch.
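The exact WoBERT_pytorch script is not reproduced here; as an assumption, the equivalent conversion with Hugging Face transformers' TF-checkpoint loader would look roughly like:

```python
import torch
from transformers import BertConfig, BertForPreTraining, load_tf_weights_in_bert

config = BertConfig.from_json_file('bert_config.json')
model = BertForPreTraining(config)

# Load the TF checkpoint exported by bert.save_weights_as_checkpoint above
load_tf_weights_in_bert(model, config, 'ckpt_model/bert_model.ckpt')

torch.save(model.state_dict(), 'pytorch_model.bin')
```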