t6am3 / public_tianchi_yiqing_nlp

天池疫情公益文本相似对比大赛
20 stars 11 forks source link

“公益AI之星”挑战赛-新冠疫情相似句对判定大赛 解决方案

Team:tbam3 youfeng@buaa.edu.cn final LB:4th(96.30) 比赛地址

Index

  1. 算法说明
  2. 代码说明
  3. 运行环境
  4. 运行说明
  5. 参考资料

    1. 算法说明

    本解决方案使用了基于病名\药名的数据增强+模型融合+训练时-测试时增强+伪标签的解决方案

    • 基于病名\药名的数据增强 Data augmentation

根据比赛组织方的信息,总共肺炎”、“支原体肺炎”、“支气管炎”、“上呼吸道感染”、“肺结核”、“哮喘”、“胸膜炎”、“肺气肿”、“感冒”、“咳血”十个病种,但是在train和dev数据集中仅仅出现了八个病种,其他的两个“肺结核”与“支气管炎”病种并没有出现,推测在test中包括了剩下的两个病种,是这次比赛的一个关键信息。

本次比赛需要模型学习的内容主要包括以下几个点:匹配语义信息,病名信息,药名信息,病理信息,我们需要针对这四个点来进行数据增强。

在测试集中,“肺结核”和“支气管炎”两个病种的测试数据中显然含有我们已有标注数据没有的病名、药名信息,但是这些信息是较为易得的;对于语义匹配信息和病理信息,1. 其生成难度要远远高于前两者,2.且很可能改变原数据集中的语义匹配和病理信息,出于这两点考虑,本解决方案采取了替换原数据中病名\药名的数据增强。

在实现过程中,挑选了病理与“肺结核”、“支气管炎”较为接近的“支原体肺炎”与“哮喘”标注数据中的部分样本,作病名替换,添加到原始标注数据中作为训练数据集。LB上升1.9个千分点(96.10->96.29)

本解决方案使用了ernie + bert_wwm_ext + roberta_large_pair的融合模型,对最后的结果使用平均值。具体的来源和下载地址见参考资料。提升2.5个千分点(95.75->96.10)

本解决方案中,在预测时,首先用原测试集预测一遍标签;然后将原测试集的query1和query2字段交换,再次预测一遍;最后将两个结果相加作为最后的预测结果。出于训练时模型拟合方向的偏差考虑,在训练时也训练了两种模型,分别用于预测正序\逆序时的数据集,这一做法的提高非常稳定。

这样的技巧是为了让模型在学习\预测过程中看到数据的更多方面,结合数据中包含的边角信息。LB上升2个千分点(95.59->95.75)

注:这个地方的提升不仅是添加了train-test time augmentation, 另外考虑时间因素移除了pseudo_label, 故估计实际上升为2个千分点左右。

在预测完成后,使用预测结果和原训练集一起作为新的训练集再次训练一个模型做预测。LB上升1个万分点(96.29->96.30)

主要提升的过程

algo LB
bert-base 94.45
ernie 95.08
ernie + pseudo_label 95.16
ernie + bert-base + cwe + pseudo_label 95.59
ernie + bert-base + cwe + train-test_time_aug 95.75
ernie + cwe + roberta-large-pair + train-test_time_aug 96.10
erinie + cwe + roberta-large-pair + train-test_time_aug + oov_sick_data_augmentation 96.29
erinie + cwe + roberta-large-pair + train-test_time_aug + oov_sick_data_augmentation + pseudo_label 96.30

题外话

2. 代码说明

dependency version
pytorch 1.3.1
cuda 10.1.243
numpy 1.17.3
pandas 0.25.3
transformers 2.5.1
tqdm 4.36.1

4. 运行说明

5. 参考资料

预训练模型来源

先验医药知识来源