基于Pytorch的文本分类框架。
同时支持中英文的数据集的文本分类任务。
Config.py
中的变量model_name
表示模型名称,可以更改成你想要加载的模型名称。initial_pretrain_model
和initial_pretrain_tokenizer
,修改为你想要加载的预训练参数。Config.py
中的变量fp16
值改为True
。Config.py
中的变量cuda_visible_devices
用于设置可见的GPU卡号,多卡情况下用,
间隔开。Config.py
中的变量adv_option
用于设置可见的对抗模式,目前支持FGM/PGD。Config.py
中的变量cl_option
设置为True
则表示开启对比学习模式,cl_method
用于设置计算对比损失的方法。THUCNews
加入自己的数据集
\t
分割。数据集示例
午评沪指涨0.78%逼近2800 汽车家电农业领涨 2
卡佩罗:告诉你德国脚生猛的原因 不希望英德战踢点球 7
说明:预训练模型基于transformers框架,如若想要替换成其他预训练参数,可以查看transformers官方网站。
模型名称 | MicroF1 | LearningRate | 预训练参数 |
---|---|---|---|
FastText | 0.8926 | 1e-3 | - |
TextCNN | 0.9009 | 1e-3 | - |
TextRNN | 0.9080 | 1e-3 | - |
TextRCNN | 0.9142 | 1e-3 | - |
Tramsformer(2 layer) | 0.8849 | 1e-3 | - |
Albert | 0.9124 | 2e-5 | voidful/albert_chinese_tiny |
Distilbert | 0.9209 | 2e-5 | Geotrend/distilbert-base-zh-cased |
Bert | 0.9401 | 2e-5 | bert-base-chinese |
Roberta | 0.9448 | 2e-5 | hfl/chinese-roberta-wwm-ext |
Electra | 0.9377 | 2e-5 | hfl/chinese-electra-base-discriminator |
XLNet | 0.9051 | 2e-5 | 无参数初始化 |
Python使用的是3.6.X版本,其他依赖模块如下:
numpy==1.19.2
pandas==1.1.5
scikit_learn==1.0.2
torch==1.8.0
tqdm==4.62.3
transformers==4.15.0
apex==0.1
除了apex
需要额外安装(参考官网:https://github.com/NVIDIA/apex
),其他模块可通过以下命令安装依赖包
pip install -r requirement.txt
准备好训练数据后,终端可运行命令
python3 main.py
加载已训练好的模型,并使用valid set作模型测试,输出文件到 ./dataset/${your_dataset}/output/output.txt 目录下。
需要修改Config文件中的变量值mode = 'test'
,并保存。
终端可运行命令
python3 main.py
[Github:transformers] https://github.com/huggingface/transformers
[Paper:Bert] https://arxiv.org/abs/1810.04805
[Paper:RDrop] https://arxiv.org/abs/2106.14448
[Paper:SimCSE] https://arxiv.org/abs/2104.08821