lonePatient / Bert-Multi-Label-Text-Classification

This repo contains a PyTorch implementation of a pretrained BERT model for multi-label text classification.
MIT License

BERT multi-label text classification with PyTorch

This repo contains a PyTorch implementation of the pretrained BERT and XLNet models for multi-label text classification.
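To make the multi-label setting concrete: each text can carry several labels at once. One common formulation scores every label with an independent sigmoid and trains with binary cross-entropy, instead of a softmax over mutually exclusive classes. This is an assumption, since this README does not describe the model head. The sketch below uses plain Python and made-up logits. The function names and the 0.5 threshold are illustrative, and the label names are the ones that appear in the result section.

```python
import math

# Label names as they appear in the result report of this README.
LABELS = ["toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"]


def sigmoid(x):
    """Map a raw score (logit) to a probability in (0, 1)."""
    return 1.0 / (1.0 + math.exp(-x))


def binary_cross_entropy(logits, targets):
    """Mean binary cross-entropy; every label is its own yes/no problem."""
    eps = 1e-12  # guards against log(0)
    total = 0.0
    for z, t in zip(logits, targets):
        p = sigmoid(z)
        total += -(t * math.log(p + eps) + (1 - t) * math.log(1 - p + eps))
    return total / len(logits)


def predict_labels(logits, threshold=0.5):
    """Return every label whose probability reaches the (assumed) threshold."""
    return [name for name, z in zip(LABELS, logits) if sigmoid(z) >= threshold]


# Made-up logits for one comment; several labels can be active together.
example_logits = [2.0, -3.0, 1.5, -4.0, 0.5, -2.5]
example_targets = [1, 0, 1, 0, 1, 0]

predicted = predict_labels(example_logits)
loss = binary_cross_entropy(example_logits, example_targets)
```

In PyTorch, the same loss is typically obtained with torch.nn.BCEWithLogitsLoss applied to the classifier output. Whether this repo does exactly that is not stated here.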

Structure of the code

At the root of the project, you will see:

├── pybert
|  └── callback
|  |  └── lrscheduler.py  
|  |  └── trainingmonitor.py 
|  |  └── ...
|  └── config
|  |  └── basic_config.py #a configuration file for storing model parameters
|  └── dataset   
|  └── io    
|  |  └── dataset.py  
|  |  └── data_transformer.py  
|  └── model
|  |  └── nn 
|  |  └── pretrain 
|  └── output # saved model output
|  └── preprocessing #text preprocessing 
|  └── train #used for training a model
|  |  └── trainer.py 
|  |  └── ...
|  └── common # a set of utility functions
├── run_bert.py
├── run_xlnet.py

Dependencies

How to use the code

You need to download the pretrained BERT and XLNet models.

BERT: bert-base-uncased

XLNET: xlnet-base-cased

  1. Download the Bert pretrained model from s3
  2. Download the Bert config file from s3
  3. Download the Bert vocab file from s3
  4. Rename:

    • bert-base-uncased-pytorch_model.bin to pytorch_model.bin
    • bert-base-uncased-config.json to config.json
    • bert-base-uncased-vocab.txt to bert_vocab.txt
  5. Place the model, config, and vocab files into the /pybert/pretrain/bert/base-uncased directory.
  6. Install pytorch-transformers from GitHub using pip.
  7. Download the Kaggle data and place it in pybert/dataset.
    • You can modify io.task_data.py to adapt it to your own data.
  8. Modify the configuration in pybert/configs/basic_config.py (the data paths, ...).
  9. Run python run_bert.py --do_data to preprocess the data.
  10. Run python run_bert.py --do_train --save_best --do_lower_case to fine-tune the BERT model.
  11. Run python run_bert.py --do_test --do_lower_case to predict on new data.
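The rename-and-place steps (4 and 5) can also be scripted. The following is a minimal sketch, not part of the repo. The helper name place_bert_files and the download directory are hypothetical, and the target path is assumed to be relative to the project root.

```python
import shutil
from pathlib import Path

# Downloaded file name -> name expected after renaming (from step 4).
RENAMES = {
    "bert-base-uncased-pytorch_model.bin": "pytorch_model.bin",
    "bert-base-uncased-config.json": "config.json",
    "bert-base-uncased-vocab.txt": "bert_vocab.txt",
}


def place_bert_files(download_dir, target_dir):
    """Move the three downloaded files into target_dir under their new names.

    Hypothetical helper: download_dir is wherever the files were saved,
    target_dir is the pretrained-model directory named in step 5.
    """
    download_dir = Path(download_dir)
    target_dir = Path(target_dir)
    target_dir.mkdir(parents=True, exist_ok=True)

    placed = []
    for src_name, dst_name in RENAMES.items():
        src = download_dir / src_name
        if not src.exists():
            raise FileNotFoundError(f"expected downloaded file not found: {src}")
        dst = target_dir / dst_name
        shutil.move(str(src), str(dst))
        placed.append(dst)
    return placed
```

For example, place_bert_files("downloads", "pybert/pretrain/bert/base-uncased") would be called from the project root, where "downloads" is a placeholder for your own download location.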

Training

[training] 8511/8511 [>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>] -0.8s/step- loss: 0.0640
training result:
[2019-01-14 04:01:05]: bert-multi-label trainer.py[line:176] INFO  
Epoch: 2 - loss: 0.0338 - val_loss: 0.0373 - val_auc: 0.9922

training figure

Result

---- train report every label -----
Label: toxic - auc: 0.9903
Label: severe_toxic - auc: 0.9913
Label: obscene - auc: 0.9951
Label: threat - auc: 0.9898
Label: insult - auc: 0.9911
Label: identity_hate - auc: 0.9910
---- valid report every label -----
Label: toxic - auc: 0.9892
Label: severe_toxic - auc: 0.9911
Label: obscene - auc: 0.9945
Label: threat - auc: 0.9955
Label: insult - auc: 0.9903
Label: identity_hate - auc: 0.9927
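The report above lists one ROC AUC value per label. The README does not say how these values were computed, so the sketch below only illustrates the metric on tiny made-up data. It uses the pairwise definition of AUC: the fraction of (positive, negative) example pairs in which the positive one gets the higher score, with ties counted as half. The function names are hypothetical.

```python
def roc_auc(y_true, y_score):
    """ROC AUC for one label via pairwise comparison (ties count as 0.5)."""
    pos = [s for t, s in zip(y_true, y_score) if t == 1]
    neg = [s for t, s in zip(y_true, y_score) if t == 0]
    if not pos or not neg:
        raise ValueError("AUC is undefined without both positive and negative examples")
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))


def per_label_auc(labels, y_true, y_score):
    """Compute AUC separately for each label column of a multi-label problem."""
    report = {}
    for j, name in enumerate(labels):
        col_true = [row[j] for row in y_true]
        col_score = [row[j] for row in y_score]
        report[name] = roc_auc(col_true, col_score)
    return report


# Made-up data: 4 examples, 2 labels (rows are examples, columns are labels).
labels = ["toxic", "insult"]
y_true = [[1, 0], [0, 1], [1, 1], [0, 0]]
y_score = [[0.9, 0.2], [0.1, 0.7], [0.8, 0.4], [0.3, 0.6]]

report = per_label_auc(labels, y_true, y_score)
for name, auc in report.items():
    print(f"Label: {name} - auc: {auc:.4f}")
```

The pairwise loop is quadratic in the number of examples, which is fine for illustration. On a real validation set, a library routine such as scikit-learn's roc_auc_score is the more practical choice.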

Tips