wzzzd / text_classifier_pytorch

基于Pytorch的文本分类框架,支持TextCNN、Bert、Electra等。
57 stars 9 forks source link

Text_Classifier_Pytorch

Info

基于Pytorch的文本分类框架。

同时支持中英文的数据集的文本分类任务。

Model

Trianing Mode Support

Datasets

Experiments

说明:预训练模型基于transformers框架,如若想要替换成其他预训练参数,可以查看transformers官方网站

模型名称 MicroF1 LearningRate 预训练参数
FastText 0.8926 1e-3 -
TextCNN 0.9009 1e-3 -
TextRNN 0.9080 1e-3 -
TextRCNN 0.9142 1e-3 -
Tramsformer(2 layer) 0.8849 1e-3 -
Albert 0.9124 2e-5 voidful/albert_chinese_tiny
Distilbert 0.9209 2e-5 Geotrend/distilbert-base-zh-cased
Bert 0.9401 2e-5 bert-base-chinese
Roberta 0.9448 2e-5 hfl/chinese-roberta-wwm-ext
Electra 0.9377 2e-5 hfl/chinese-electra-base-discriminator
XLNet 0.9051 2e-5 无参数初始化

Requirement

Python使用的是3.6.X版本,其他依赖模块如下:

    numpy==1.19.2
    pandas==1.1.5
    scikit_learn==1.0.2
    torch==1.8.0
    tqdm==4.62.3
    transformers==4.15.0
    apex==0.1

除了apex需要额外安装(参考官网:https://github.com/NVIDIA/apex ),其他模块可通过以下命令安装依赖包

    pip install -r requirement.txt

Get Started

1. 训练

准备好训练数据后,终端可运行命令

    python3 main.py

2 测试评估

加载已训练好的模型,并使用valid set作模型测试,输出文件到 ./dataset/${your_dataset}/output/output.txt 目录下。

需要修改Config文件中的变量值mode = 'test',并保存。

终端可运行命令

    python3 main.py

Reference

[Github:transformers] https://github.com/huggingface/transformers

[Paper:Bert] https://arxiv.org/abs/1810.04805

[Paper:RDrop] https://arxiv.org/abs/2106.14448

[Paper:SimCSE] https://arxiv.org/abs/2104.08821