lancopku / SGM

Sequence Generation Model for Multi-label Classification (COLING 2018)

Sequence Generation Model for Multi-label Classification

This is the code for our paper SGM: Sequence Generation Model for Multi-label Classification [pdf]


Note

In general, this code is more suitable for the following application scenarios:


Requirements


Dataset

The RCV1-V2 dataset we used can be downloaded from Google Drive with this link. The folder structure on the drive is:

Google Drive Root          # The compressed zip file
 |-- data                          # The unprocessed raw data files
 |    |-- train.src        
 |    |-- train.tgt
 |    |-- valid.src
 |    |-- valid.tgt
 |    |-- test.src
 |    |-- test.tgt
 |    |-- topic_sorted.json        # The json file of label set for evaluation
 |-- checkpoints                   # The pre-trained model checkpoints
 |    |-- sgm.pt
 |    |-- sgmge.pt
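
After unzipping, it can be useful to confirm that nothing is missing before preprocessing. The following is a minimal sketch, not part of the repository: the file list is copied from the folder structure above, while the helper name `missing_files` and the assumption that the archive unpacks into a single root directory are ours.

```python
from pathlib import Path

# Relative paths copied from the folder listing in this README.
EXPECTED_FILES = [
    "data/train.src", "data/train.tgt",
    "data/valid.src", "data/valid.tgt",
    "data/test.src", "data/test.tgt",
    "data/topic_sorted.json",
    "checkpoints/sgm.pt", "checkpoints/sgmge.pt",
]


def missing_files(root):
    """Return the expected files that do not exist under `root`.

    `root` is assumed to be the directory the zip file was extracted into
    (hypothetical; adjust it to wherever you unpacked the archive).
    """
    root = Path(root)
    return [rel for rel in EXPECTED_FILES if not (root / rel).is_file()]
```

Calling `missing_files("path/to/unzipped/root")` returns an empty list when all files from the listing are present.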

We found that the validation set in the previous version was so small that the model tended to overfit it, resulting in unstable performance. We have therefore expanded the validation set. In addition, we filtered out samples that contain more than 500 words from the original RCV1-V2 dataset.
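
The provided data already has this length filter applied. If you want to apply a similar filter to your own data, the sketch below shows one possible way. It is an illustration rather than the script we used, and it assumes one sample per line, whitespace-separated words, and line-aligned source and target files. The function name `filter_long_samples` is hypothetical.

```python
def filter_long_samples(src_lines, tgt_lines, max_words=500):
    """Drop (source, target) pairs whose source text has more than `max_words` words.

    Assumptions (not confirmed by the repository): each list element is one
    sample, words are separated by whitespace, and the two lists are aligned.
    """
    if len(src_lines) != len(tgt_lines):
        raise ValueError("src and tgt must have the same number of lines")
    kept = [
        (src, tgt)
        for src, tgt in zip(src_lines, tgt_lines)
        if len(src.split()) <= max_words
    ]
    kept_src = [src for src, _ in kept]
    kept_tgt = [tgt for _, tgt in kept]
    return kept_src, kept_tgt
```

Samples with exactly 500 words are kept, matching the wording "more than 500 words" above.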


Reproducibility

We provide the pretrained checkpoints of the SGM model and the SGM+GE model on the RCV1-V2 dataset to help you to reproduce our reported experimental results. The detailed reproduction steps are as follows:


Training from scratch

Preprocessing

You can preprocess the dataset with the following command:

# -load_data: input directory of the raw data
# -save_data: output directory for the processed data
# -src_vocab_size: size of the source vocabulary
python3 preprocess.py \
    -load_data load_data_path \
    -save_data save_data_path \
    -src_vocab_size 50000

Note that all data paths must end with /. Descriptions of the other parameters can be found in preprocess.py.
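
Because a missing trailing / is an easy mistake to make, the preprocessing call can be wrapped in a small script that normalizes the paths first. This is a sketch under our own assumptions: the helper names are hypothetical, and only the three flags shown in the command above are used.

```python
def with_trailing_slash(path):
    """Return `path` unchanged if it ends with '/', otherwise append one."""
    return path if path.endswith("/") else path + "/"


def build_preprocess_command(load_data, save_data, src_vocab_size=50000):
    """Build the argument list for preprocess.py with normalized data paths."""
    return [
        "python3", "preprocess.py",
        "-load_data", with_trailing_slash(load_data),
        "-save_data", with_trailing_slash(save_data),
        "-src_vocab_size", str(src_vocab_size),
    ]
```

The resulting list can be passed to `subprocess.run(..., check=True)` from the repository root to start preprocessing.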


Training

You can perform model training with the following command:

python3 train.py -gpus gpu_id -config model_config -log save_path

All log files and checkpoints produced during training will be saved in save_path. Detailed parameter descriptions can be found in train.py.


Testing

You can perform testing with the following command:

python3 predict.py -gpus gpu_id -data save_data_path -batch_size batch_size -log log_path

The predicted labels and evaluation scores will be stored in the folder log_path. Detailed parameter descriptions can be found in predict.py.
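
If you want to double-check scores on your own predictions, the sketch below computes two metrics that are commonly used for multi-label classification: micro-averaged F1 and Hamming loss. It is an independent illustration, not the evaluation code in predict.py. This README does not specify which scores are written to log_path or in what format, so the functions simply take gold and predicted label sets as Python sets.

```python
def micro_f1(gold, pred):
    """Micro-averaged F1 over two aligned lists of label sets."""
    if len(gold) != len(pred):
        raise ValueError("gold and pred must have the same length")
    true_positives = sum(len(g & p) for g, p in zip(gold, pred))
    if true_positives == 0:
        return 0.0
    precision = true_positives / sum(len(p) for p in pred)
    recall = true_positives / sum(len(g) for g in gold)
    return 2 * precision * recall / (precision + recall)


def hamming_loss(gold, pred, num_labels):
    """Fraction of (sample, label) decisions that differ between gold and pred."""
    if not gold or len(gold) != len(pred):
        raise ValueError("gold and pred must be non-empty and the same length")
    # The symmetric difference counts both missed and spurious labels.
    errors = sum(len(g ^ p) for g, p in zip(gold, pred))
    return errors / (len(gold) * num_labels)
```

Here `num_labels` is the size of the full label set, which you would need to determine from your own data.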


Citation

If you use the above code for your research, please cite the paper:

@inproceedings{YangCOLING2018,
  author    = {Pengcheng Yang and
               Xu Sun and
               Wei Li and
               Shuming Ma and
               Wei Wu and
               Houfeng Wang},
  title     = {{SGM:} Sequence Generation Model for Multi-label Classification},
  booktitle = {Proceedings of the 27th International Conference on Computational
               Linguistics, {COLING} 2018, Santa Fe, New Mexico, USA, August 20-26,
               2018},
  pages     = {3915--3926},
  year      = {2018}
}