Closed oasis-0927 closed 4 years ago
问题描述: 你好,我在尝试用bert官方的run_classifier.py的格式,对于两个输入进行分类(即text_b不为空)时,调用 bert-base-serving-start -model_dir C:\workspace\python\BERT_Base\output\ner2 \ -bert_model_dir F:\chinese_L-12_H-768_A-12 -model_pb_dir C:\workspace\python\BERT_Base\model_pb_dir -mode CLASS -max_seq_len 202 部署服务,在测试集进行预测时,发现准确率降低,而使用官方的方法进行预测则准确率不会降低。
原因排查: 个人发现是在将与训练结果转换为.pb的二进制文件时,缺少了一个"input_type_ids" 这个feature(在训练时这个feature命名为segment_ids,用于区分text_a和text_b。可见https://github.com/google-research/bert/blob/master/run_classifier.py#L410)。 具体位置为: https://github.com/macanv/BERT-BiLSTM-CRF-NER/blob/12fb822bd847f24e4de1beca8af9de34d1a794d2/bert_base/server/graph.py#L343
解决方式:
segment_ids = tf.placeholder(tf.int32, (None, args.max_seq_len), 'segment_ids')
loss, per_example_loss, logits, probabilities = create_classification_model(bert_config=bert_config, is_training=False, input_ids=input_ids, input_mask=input_mask, segment_ids=segment_ids, labels=None, num_labels=num_labels)
segment_ids=features["input_type_ids"]
input_map = {"input_ids": input_ids, "input_mask": input_mask,"segment_ids":segment_ids}
欢迎提交PR
多谢分享和PR贡献。。已成功训练并完成部署。
问题描述: 你好,我在尝试用bert官方的run_classifier.py的格式,对于两个输入进行分类(即text_b不为空)时,调用 bert-base-serving-start -model_dir C:\workspace\python\BERT_Base\output\ner2 \ -bert_model_dir F:\chinese_L-12_H-768_A-12 -model_pb_dir C:\workspace\python\BERT_Base\model_pb_dir -mode CLASS -max_seq_len 202 部署服务,在测试集进行预测时,发现准确率降低,而使用官方的方法进行预测则准确率不会降低。
原因排查: 个人发现是在将与训练结果转换为.pb的二进制文件时,缺少了一个"input_type_ids" 这个feature(在训练时这个feature命名为segment_ids,用于区分text_a和text_b。可见https://github.com/google-research/bert/blob/master/run_classifier.py#L410)。 具体位置为: https://github.com/macanv/BERT-BiLSTM-CRF-NER/blob/12fb822bd847f24e4de1beca8af9de34d1a794d2/bert_base/server/graph.py#L343
解决方式:
segment_ids = tf.placeholder(tf.int32, (None, args.max_seq_len), 'segment_ids')
同时修改loss, per_example_loss, logits, probabilities = create_classification_model(bert_config=bert_config, is_training=False, input_ids=input_ids, input_mask=input_mask, segment_ids=segment_ids, labels=None, num_labels=num_labels)
segment_ids=features["input_type_ids"]
input_map = {"input_ids": input_ids, "input_mask": input_mask,"segment_ids":segment_ids}