PyTorch re-implementation of some text classification models.
Train the following models by editing the model_name item in config files (here are some example config files). Click the link of each model for details.
Hierarchical Attention Networks (HAN) (han)
Hierarchical Attention Networks for Document Classification. Zichao Yang, et al. NAACL 2016. [Paper]
fastText (fasttext)
Bag of Tricks for Efficient Text Classification. Armand Joulin, et al. EACL 2017. [Paper] [Code]
Bi-LSTM + Attention (attbilstm)
Attention-Based Bidirectional Long Short-Term Memory Networks for Relation Classification. Peng Zhou, et al. ACL 2016. [Paper]
TextCNN (textcnn)
Convolutional Neural Networks for Sentence Classification. Yoon Kim. EMNLP 2014. [Paper] [Code]
Transformer (transformer)
Attention Is All You Need. Ashish Vaswani, et al. NIPS 2017. [Paper] [Code]
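Selecting a model comes down to a single string in the config file. As a minimal sketch, assuming the config has already been parsed into a Python dict (the helper check_model_name and the dict layout are hypothetical, not the repo's actual loader), validating model_name against the five identifiers listed above could look like this:

```python
# Identifiers accepted in the model_name config item (from the model list above).
SUPPORTED_MODELS = {
    "han": "Hierarchical Attention Networks",
    "fasttext": "fastText",
    "attbilstm": "Bi-LSTM + Attention",
    "textcnn": "TextCNN",
    "transformer": "Transformer",
}


def check_model_name(config: dict) -> str:
    """Return config["model_name"] if it is supported (hypothetical helper).

    Raising ValueError for an unknown name makes a typo in the config fail
    fast, before any data or embeddings are loaded.
    """
    name = str(config.get("model_name", "")).strip()
    if name not in SUPPORTED_MODELS:
        choices = ", ".join(sorted(SUPPORTED_MODELS))
        raise ValueError(f"unknown model_name {name!r}; expected one of: {choices}")
    return name


if __name__ == "__main__":
    config = {"model_name": "textcnn", "dataset_path": "/path/to/dataset"}
    name = check_model_name(config)
    print(name, "->", SUPPORTED_MODELS[name])  # textcnn -> TextCNN
```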
First, make sure the following are installed in your environment:
Then install requirements:
pip install -r requirements.txt
Currently, the following datasets proposed in this paper are supported:
All of them can be downloaded here (Google Drive). Click here for details of these datasets.
You should download and unzip them first, then set their path (dataset_path) in your config files. If you would like to use other datasets, they may have to be stored in the same format as the above-mentioned datasets.
If you would like to use pre-trained word embeddings (like GloVe), just set emb_pretrain to True and specify the path to the pre-trained vectors (emb_folder and emb_filename) in your config files. You can also choose whether or not to fine-tune the word embeddings by editing the fine_tune_embeddings item.
Or, if you want to randomly initialize the embedding layer's weights, set emb_pretrain to False and specify the embedding size (embed_size).
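To make the two embedding options concrete, here is a minimal, dependency-free sketch of building an embedding matrix for a vocabulary: rows for words found in a GloVe-style text file (a word followed by its space-separated vector components on each line) are copied over, and every other row is randomly initialized. The function name build_embedding_matrix, the uniform initialization range, and returning plain lists instead of a tensor are illustrative assumptions, not the repo's load_embeddings.

```python
import io
import math
import random


def build_embedding_matrix(word2ix, vector_lines=None, embed_size=None, seed=0):
    """Return one row per vocabulary index (hypothetical helper).

    Assumes word2ix maps words to contiguous indices 0..len(word2ix)-1.
    - With vector_lines (iterable of "word v1 v2 ..." lines), embed_size is
      inferred from the vectors and known words get their pre-trained row,
      similar in spirit to emb_pretrain = True.
    - Without vector_lines, embed_size must be given and every row is random,
      similar in spirit to emb_pretrain = False.
    """
    rng = random.Random(seed)
    pretrained = {}
    if vector_lines is not None:
        for line in vector_lines:
            parts = line.rstrip().split(" ")
            if len(parts) < 2:
                continue  # skip blank or malformed lines
            if parts[0] in word2ix:
                pretrained[parts[0]] = [float(v) for v in parts[1:]]
        if pretrained:
            embed_size = len(next(iter(pretrained.values())))
    if embed_size is None:
        raise ValueError("embed_size is required when no pre-trained vectors match")

    # Uniform init in [-sqrt(3/d), sqrt(3/d)] (variance 1/d) for the remaining rows.
    bound = math.sqrt(3.0 / embed_size)
    matrix = [None] * len(word2ix)
    for word, ix in word2ix.items():
        if word in pretrained:
            matrix[ix] = pretrained[word]
        else:
            matrix[ix] = [rng.uniform(-bound, bound) for _ in range(embed_size)]
    return matrix


if __name__ == "__main__":
    word2ix = {"<pad>": 0, "<unk>": 1, "good": 2, "movie": 3}
    glove = io.StringIO("good 0.1 0.2 0.3\nbad 0.4 0.5 0.6\n")
    matrix = build_embedding_matrix(word2ix, glove)
    print(len(matrix), len(matrix[0]), matrix[2])  # 4 3 [0.1, 0.2, 0.3]
```

In PyTorch, a plausible way to use such rows is nn.Embedding.from_pretrained(torch.tensor(matrix), freeze=not fine_tune_embeddings), since freeze=True keeps the weights fixed during training; whether the repo does exactly this is not stated here.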
Although torchtext can be used to preprocess data easily, it loads all the data in one go, which occupies a lot of memory and slows down training, especially when the dataset is large.
Therefore, I preprocess the data manually and store it locally first (where configs/example.yaml is the path to your config file):
python preprocess.py --config configs/example.yaml
Then I load the data dynamically using PyTorch's DataLoader during training (see datasets/dataloader.py).
The preprocessing includes encoding and padding sentences and building the word2ix map. This may take a little time, but this way training occupies less memory (which means we can use a larger batch size) and takes less time. For example, on an RTX 2080 Ti it takes me 4.6 minutes to train a fastText model on the Yahoo Answers dataset for one epoch using torchtext, but only 41 seconds using DataLoader.
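The preprocessing steps named above (building the word2ix map, encoding each text as word indices, and padding it to a fixed length) can be sketched in plain Python. This is a minimal illustration assuming whitespace tokenization, a min_word_count frequency cutoff, and reserved <pad> and <unk> indices 0 and 1; these names and choices are assumptions, not necessarily what preprocess.py does.

```python
from collections import Counter


def build_word2ix(texts, min_word_count=1):
    """Map each sufficiently frequent word to an index; 0 and 1 are reserved."""
    counts = Counter(word for text in texts for word in text.lower().split())
    word2ix = {"<pad>": 0, "<unk>": 1}
    for word, count in counts.most_common():
        if count >= min_word_count:
            word2ix[word] = len(word2ix)
    return word2ix


def encode_and_pad(text, word2ix, word_limit):
    """Encode a text as word indices, truncated or padded to exactly word_limit.

    Also returns the true (unpadded) length, which recurrent models can use
    to ignore the padding positions.
    """
    unk, pad = word2ix["<unk>"], word2ix["<pad>"]
    ids = [word2ix.get(w, unk) for w in text.lower().split()][:word_limit]
    length = len(ids)
    return ids + [pad] * (word_limit - length), length


if __name__ == "__main__":
    train = ["the movie was good", "the plot was thin"]
    word2ix = build_word2ix(train)
    print(encode_and_pad("the acting was good", word2ix, word_limit=6))
    # → ([2, 1, 3, 5, 0, 0], 4)
```

Doing this once and saving the encoded, fixed-length results is what lets each training step skip tokenization entirely, which is the trade-off described above.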
torchtext.py is the script for loading data via torchtext; you can try it if you are interested.
To train a model, just run:
python train.py --config configs/example.yaml
If you have enabled TensorBoard (tensorboard: True in config files), you can visualize the losses and accuracies during training with:
tensorboard --logdir=<your_log_dir>
To test a checkpoint and compute accuracy on the test set, run:
python test.py --config configs/example.yaml
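Test accuracy is the percentage of test examples whose highest-scoring class matches the gold label. As a framework-free sketch of just the metric (the function name accuracy and the list-of-scores input are assumptions; the repo's test.py likely operates on batched tensors instead):

```python
def accuracy(scores, labels):
    """Percentage of rows whose argmax class index equals the gold label."""
    if len(scores) != len(labels) or not labels:
        raise ValueError("need the same, non-zero number of score rows and labels")
    correct = 0
    for row, label in zip(scores, labels):
        predicted = max(range(len(row)), key=row.__getitem__)  # argmax over classes
        correct += predicted == label
    return 100.0 * correct / len(labels)


if __name__ == "__main__":
    scores = [[0.1, 2.3, -0.5], [1.2, 0.3, 0.4], [0.0, 0.2, 0.9], [0.7, 0.6, 0.1]]
    labels = [1, 0, 2, 1]
    print(f"{accuracy(scores, labels):.1f}%")  # 75.0%
```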
To predict the category for a specific sentence, first edit the following items in classify.py:
checkpoint_path = 'str: path_to_your_checkpoint'
# pad limits
# only makes sense when model_name == 'han'
sentence_limit_per_doc = 15
word_limit_per_sentence = 20
# only makes sense when model_name != 'han'
word_limit = 200
Then, run:
python classify.py
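The two kinds of pad limits in classify.py reflect two input shapes: HAN reads a document as a list of sentences, each a list of words, while the other models read one flat word sequence. A minimal sketch of both shapes, assuming naive sentence splitting on ".", "!" and "?", a regex word tokenizer, and a word2ix with <pad> = 0 and <unk> = 1 (the helper names and splitting rules are illustrative assumptions, not classify.py's actual code):

```python
import re

# Values mirror the pad limits shown in classify.py.
sentence_limit_per_doc = 15
word_limit_per_sentence = 20
word_limit = 200

PAD, UNK = 0, 1  # assumed reserved indices


def tokenize(text):
    """Lower-cased word tokens; punctuation is dropped (assumed tokenizer)."""
    return re.findall(r"\w+", text.lower())


def encode_words(words, word2ix, limit):
    """Indices for at most `limit` words, right-padded with PAD to exactly `limit`."""
    ids = [word2ix.get(w, UNK) for w in words][:limit]
    return ids + [PAD] * (limit - len(ids))


def encode_flat(text, word2ix):
    """Flat input of length word_limit (model_name != 'han')."""
    return encode_words(tokenize(text), word2ix, word_limit)


def encode_hierarchical(text, word2ix):
    """sentence_limit_per_doc x word_limit_per_sentence input (model_name == 'han')."""
    sentences = [tokenize(s) for s in re.split(r"[.!?]+", text)]
    sentences = [s for s in sentences if s][:sentence_limit_per_doc]
    doc = [encode_words(s, word2ix, word_limit_per_sentence) for s in sentences]
    # Pad with all-PAD sentences; the comprehension avoids aliasing one shared list.
    doc += [[PAD] * word_limit_per_sentence
            for _ in range(sentence_limit_per_doc - len(doc))]
    return doc


if __name__ == "__main__":
    word2ix = {"<pad>": 0, "<unk>": 1, "great": 2, "film": 3, "boring": 4}
    doc = encode_hierarchical("Great film. Boring ending!", word2ix)
    print(len(doc), len(doc[0]), doc[0][:3], doc[1][:3])  # 15 20 [2, 3, 0] [4, 1, 0]
    print(len(encode_flat("Great film.", word2ix)))  # 200
```

This is why sentence_limit_per_doc and word_limit_per_sentence only matter for han, and word_limit only matters for the other models.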
Here I report the test accuracy (%) and training time per epoch (on an RTX 2080 Ti) of each model on various datasets. Model hyperparameters were not carefully tuned, so better performance may be achievable with some tuning.
Model | AG News | DBpedia | Yahoo Answers |
---|---|---|---|
Hierarchical Attention Network | 92.7 (45s) | 98.2 (70s) | 74.5 (2.7m) |
fastText | 91.6 (8s) | 97.9 (25s) | 66.7 (41s) |
Bi-LSTM + Attention | 92.0 (50s) | 99.0 (105s) | 73.5 (3.4m) |
TextCNN | 92.2 (24s) | 98.5 (100s) | 72.8 (4m) |
Transformer | 92.2 (60s) | 98.6 (8.2m) | 72.5 (14.5m) |
The load_embeddings method (in utils/embedding.py) tries to create a cache for the loaded embeddings under the folder dataset_output_path. This dramatically speeds up loading the next time.
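The caching idea is simple: parse the large text file of vectors once, then store the parsed result in a binary form that loads much faster on later runs. A minimal sketch using the standard-library pickle module; the helper names, the cache file name, and keying the cache on the source file name are assumptions, not the actual behaviour of load_embeddings.

```python
import os
import pickle
import tempfile


def parse_vectors(path):
    """Parse a GloVe-style text file into {word: [floats]}."""
    vectors = {}
    with open(path, encoding="utf-8") as f:
        for line in f:
            parts = line.rstrip().split(" ")
            if len(parts) > 1:
                vectors[parts[0]] = [float(v) for v in parts[1:]]
    return vectors


def load_vectors_cached(path, cache_dir):
    """Return parsed vectors, using a pickle cache in cache_dir when present."""
    os.makedirs(cache_dir, exist_ok=True)
    cache_path = os.path.join(cache_dir, os.path.basename(path) + ".cache.pkl")
    if os.path.exists(cache_path):
        with open(cache_path, "rb") as f:
            return pickle.load(f)  # fast path: skip text parsing entirely
    vectors = parse_vectors(path)
    with open(cache_path, "wb") as f:
        pickle.dump(vectors, f, protocol=pickle.HIGHEST_PROTOCOL)
    return vectors


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        src = os.path.join(tmp, "toy_vectors.txt")
        with open(src, "w", encoding="utf-8") as f:
            f.write("good 0.1 0.2\nbad -0.3 0.4\n")
        out = os.path.join(tmp, "dataset_output")
        first = load_vectors_cached(src, out)   # parses the text file, writes the cache
        second = load_vectors_cached(src, out)  # served from the pickle cache
        print(first == second, sorted(os.listdir(out)))  # True ['toy_vectors.txt.cache.pkl']
```

A more careful cache would also be invalidated when the source file or the vocabulary changes, for example by comparing modification times; this sketch simply trusts an existing cache file.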
This project is based on sgrvinod/a-PyTorch-Tutorial-to-Text-Classification.