RowitZou / CG-nAR

EMNLP-2021 paper: Thinking Clearly, Talking Fast: Concept-Guided Non-Autoregressive Generation for Open-Domain Dialogue Systems.
MIT License

CG-nAR

PyTorch implementation of the EMNLP-2021 paper: Thinking Clearly, Talking Fast: Concept-Guided Non-Autoregressive Generation for Open-Domain Dialogue Systems.

Requirements

Environment

Data

Persona

Our split of the Persona dataset follows https://github.com/squareRoot3/Target-Guided-Conversation. You can get the raw dataset from Google Drive; download it and unzip it into the resource directory.

In addition, we use glove.6B.300d to initialize the graph embeddings. You can download it from fastNLP; download it and unzip it into the resource directory.

Finally, run the data preprocessing script:

bash src/persona_preprocess.sh

By default, the split dataset is written to the raw_data/persona directory and the graph-related data to the graph_data/persona directory.

Weibo

You can download the cleaned Weibo dataset from Baidu Pan (extraction code: kbiz). Download it and unzip it into the resource directory.

We use cn_bi_fastnlp_100d to initialize the graph embeddings. You can download it from fastNLP; download it and unzip it into the resource directory.

Then run the data preprocessing script to generate the data the program needs:

bash src/weibo_preprocess.sh

By default, the dataset is written to the raw_data/weibo directory and the graph-related data to the graph_data/weibo directory.
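Both datasets initialize the graph embeddings from pretrained word vectors (glove.6B.300d for Persona, cn_bi_fastnlp_100d for Weibo). The sketch below shows one common way such an initialization can work; it is not the repository's code. It assumes the vectors are stored as whitespace-separated text lines ("word v1 ... vdim"), that concepts are matched to vectors by exact string, and that concepts without a pretrained vector get a small random vector. The concept list would come from the vertex file (for example graph_data/persona/vertex.txt, assumed here to hold one concept per line). The function names are hypothetical.

```python
import numpy as np


def load_pretrained_vectors(lines, dim):
    """Parse text lines of the form "word v1 v2 ... v<dim>" into a dict.

    Assumption: vectors are whitespace-separated text (the GloVe layout).
    Lines with the wrong number of fields, such as a word2vec-style
    "count dim" header, are skipped.
    """
    vectors = {}
    for line in lines:
        parts = line.split()
        if len(parts) != dim + 1:
            continue
        vectors[parts[0]] = np.array([float(x) for x in parts[1:]], dtype=np.float32)
    return vectors


def build_graph_embedding(vertices, vectors, dim, scale=0.1, seed=0):
    """Return a (len(vertices), dim) float32 matrix and the number of hits.

    Rows for concepts found in `vectors` (exact string match, an assumption)
    copy the pretrained vector; all other rows keep a uniform random init
    in [-scale, scale).
    """
    rng = np.random.default_rng(seed)
    emb = rng.uniform(-scale, scale, size=(len(vertices), dim)).astype(np.float32)
    hits = 0
    for i, concept in enumerate(vertices):
        vec = vectors.get(concept)
        if vec is not None:
            emb[i] = vec
            hits += 1
    return emb, hits


# Tiny in-memory demo; real use would read the vector file and vertex list.
demo_vectors = load_pretrained_vectors(["cat 0.1 0.2 0.3", "dog 0.4 0.5 0.6"], dim=3)
emb, hits = build_graph_embedding(["cat", "fish"], demo_vectors, dim=3)
print(emb.shape, hits)  # (2, 3) 1
```

In practice you would read the vector file and the vertex list from disk and store the matrix with np.save, the .npy format suggested by the graph_embedding.npy path passed to -graph_emb_path.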

Usage

Persona

  1. Pre-process

    PYTHONPATH=. python ./src/preprocess.py -dataset persona -mode raw_to_json -raw_path raw_data/persona -save_path json_data/persona/persona -adj_file graph_data/persona/adj_matrix.txt -vertex_file graph_data/persona/vertex.txt -log_file logs/raw_to_json_persona.log
    PYTHONPATH=. python ./src/preprocess.py -dataset persona -mode json_to_data -type train -raw_path json_data/persona -save_path torch_data/persona -tokenizer bert-base-uncased -adj_file graph_data/persona/adj_matrix.txt -vertex_file graph_data/persona/vertex.txt -log_file logs/json_to_data_persona.log -n_cpus 4

  2. Train

    PYTHONPATH=. python ./src/main.py -mode train -data_path torch_data/persona/persona -model_path models/persona -log_file logs/persona.train.log -visible_gpus 0 -warmup_steps 8000 -lr 0.001 -train_steps 100000 -graph_emb_path graph_data/persona/graph_embedding.npy -tokenizer bert-base-uncased

  3. Validate

    PYTHONPATH=. python ./src/main.py -mode validate -data_path torch_data/persona/persona -log_file logs/persona.val.log -test_all -alpha 0.95 -model_path models/persona -result_path results/persona/persona -test_start_from 10000 -visible_gpus 0 -test_batch_ex_size 50 -graph_emb_path graph_data/persona/graph_embedding.npy -tokenizer bert-base-uncased

  4. Test

    PYTHONPATH=. python ./src/main.py -mode test -data_path torch_data/persona/persona -log_file logs/persona.test.log -alpha 0.95 -test_from models/persona/model_step_100000.pt -result_path results/persona/persona -visible_gpus 0 -test_batch_ex_size 50 -graph_emb_path graph_data/persona/graph_embedding.npy -tokenizer bert-base-uncased
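The Validate step runs with -test_all and -test_start_from 10000, and the Test step then loads a single checkpoint through -test_from. Below is a minimal Python sketch of choosing that checkpoint, assuming checkpoints follow the model_step_<N>.pt naming seen in the Test command and that lower validation scores are better (for example, a loss). list_checkpoints and pick_best are hypothetical helpers, not part of the repository, and the scores are supplied by you rather than parsed from any particular log format.

```python
import re
from pathlib import Path

# Assumption: checkpoints are saved as model_step_<N>.pt under -model_path,
# matching the file name used in the Test command.
CKPT_PATTERN = re.compile(r"^model_step_(\d+)\.pt$")


def list_checkpoints(model_dir, start_from=0):
    """Return sorted (step, path) pairs for checkpoints with step >= start_from."""
    found = []
    for path in Path(model_dir).iterdir():
        match = CKPT_PATTERN.match(path.name)
        if match is None:
            continue  # ignore files that are not checkpoints
        step = int(match.group(1))
        if step >= start_from:
            found.append((step, path))
    return sorted(found)


def pick_best(step_to_score):
    """Return the step with the lowest score.

    Assumption: lower is better (e.g. a validation loss); the caller
    supplies the scores, for instance copied from the validation log.
    """
    return min(step_to_score, key=step_to_score.get)
```

For example, list_checkpoints("models/persona", start_from=10000) mirrors the -model_path and -test_start_from values above, and the checkpoint path for the step returned by pick_best is what you would pass to -test_from.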

Weibo

  1. Pre-process

    PYTHONPATH=. python ./src/preprocess.py -dataset weibo -mode json_to_data -type train -raw_path json_data/weibo -save_path torch_data/weibo -tokenizer bert-base-chinese -adj_file graph_data/weibo/adj_matrix.txt -vertex_file graph_data/weibo/vertex.txt -log_file logs/json_to_data_weibo.log -n_cpus 8

  2. Train

    PYTHONPATH=. python ./src/main.py -mode train -data_path torch_data/weibo/weibo -model_path models/weibo -log_file logs/weibo.train.log -visible_gpus 0 -warmup_steps 8000 -lr 0.001 -train_steps 100000 -graph_emb_path graph_data/weibo/graph_embedding.npy -tokenizer bert-base-chinese

  3. Validate

    PYTHONPATH=. python ./src/main.py -mode validate -data_path torch_data/weibo/weibo -log_file logs/weibo.val.log -test_all -alpha 0.95 -model_path models/weibo -result_path results/weibo/weibo -test_start_from 10000 -visible_gpus 0 -test_batch_ex_size 50 -graph_emb_path graph_data/weibo/graph_embedding.npy -tokenizer bert-base-chinese

  4. Test

    PYTHONPATH=. python ./src/main.py -mode test -data_path torch_data/weibo/weibo -log_file logs/weibo.test.log -alpha 0.95 -test_from models/weibo/model_step_100000.pt -result_path results/weibo/weibo -visible_gpus 0 -test_batch_ex_size 50 -graph_emb_path graph_data/weibo/graph_embedding.npy -tokenizer bert-base-chinese
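The Persona and Weibo commands differ mainly in the dataset name and the tokenizer (bert-base-uncased vs. bert-base-chinese). The following hypothetical Python helper, which is not part of the repository, rebuilds the main.py argument lists for either dataset from those two settings; every flag value is copied from the commands above, while the flags themselves are defined by src/main.py.

```python
import shlex

# Per-dataset tokenizer, as used in the commands above.
TOKENIZERS = {"persona": "bert-base-uncased", "weibo": "bert-base-chinese"}


def common_flags(dataset):
    """Flags shared by the train, validate, and test commands."""
    return [
        "-data_path", f"torch_data/{dataset}/{dataset}",
        "-visible_gpus", "0",
        "-graph_emb_path", f"graph_data/{dataset}/graph_embedding.npy",
        "-tokenizer", TOKENIZERS[dataset],
    ]


def build_command(dataset, mode, step=100000):
    """Return the argv list for ./src/main.py in the given mode.

    PYTHONPATH=. is not part of argv; set it in the environment instead.
    """
    cmd = ["python", "./src/main.py", "-mode", mode] + common_flags(dataset)
    if mode == "train":
        cmd += ["-model_path", f"models/{dataset}",
                "-log_file", f"logs/{dataset}.train.log",
                "-warmup_steps", "8000", "-lr", "0.001",
                "-train_steps", "100000"]
    elif mode == "validate":
        cmd += ["-model_path", f"models/{dataset}",
                "-log_file", f"logs/{dataset}.val.log",
                "-test_all", "-alpha", "0.95",
                "-result_path", f"results/{dataset}/{dataset}",
                "-test_start_from", "10000", "-test_batch_ex_size", "50"]
    elif mode == "test":
        cmd += ["-log_file", f"logs/{dataset}.test.log", "-alpha", "0.95",
                "-test_from", f"models/{dataset}/model_step_{step}.pt",
                "-result_path", f"results/{dataset}/{dataset}",
                "-test_batch_ex_size", "50"]
    else:
        raise ValueError(f"unknown mode: {mode}")
    return cmd


# Print a shell-ready version of one command.
print(shlex.join(build_command("weibo", "test")))
```

To launch a run, pass the list to subprocess.run with PYTHONPATH set to the repository root, for example subprocess.run(build_command("weibo", "train"), env={**os.environ, "PYTHONPATH": "."}, check=True) after importing os and subprocess.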

Citation

@inproceedings{zou-etal-2021-thinking,
    title = "Thinking Clearly, Talking Fast: Concept-Guided Non-Autoregressive Generation for Open-Domain Dialogue Systems",
    author = "Zou, Yicheng  and Liu, Zhihua  and Hu, Xingwu  and Zhang, Qi",
    booktitle = "Proceedings of the 2021 Conference on Empirical Methods in Natural Language Processing",
    month = nov,
    year = "2021",
    address = "Online and Punta Cana, Dominican Republic",
    publisher = "Association for Computational Linguistics",
    url = "https://aclanthology.org/2021.emnlp-main.169",
    pages = "2215--2226"
}