This repository contains the source code and trained models for "Joint Retrieval and Generation Training for Grounded Text Generation". RetGen is a joint training framework that simultaneously optimizes a dense passage retriever and a knowledge-grounded text generator in an end-to-end fashion. It can be applied to scenarios including, but not limited to, conversational modeling, text generation and open-domain question answering. The implementation is based on DialoGPT, Huggingface Transformers, DPR and ANCE. Our human evaluation results indicate that RetGen can generate more relevant, interesting and human-like text compared to vanilla DialoGPT or GPT-2.
Figure: RetGen overview.
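At a high level, RetGen trains the retriever and generator jointly by marginalizing the generation likelihood over the top retrieved documents, so that retriever scores receive gradient signal from generation quality. Below is a minimal numeric sketch of that mixture likelihood (the function name and numbers are illustrative, not part of the repo's API):

```python
import math

def grounded_log_likelihood(doc_scores, gen_log_probs):
    """Sketch of RAG-style marginalization over retrieved documents:
    log p(y|x) = log sum_z p(z|x) * p(y|x,z)."""
    # p(z|x): softmax over retriever scores (numerically stabilized)
    m = max(doc_scores)
    exps = [math.exp(s - m) for s in doc_scores]
    doc_probs = [e / sum(exps) for e in exps]
    # mixture of per-document generation likelihoods
    mixture = sum(p * math.exp(lp) for p, lp in zip(doc_probs, gen_log_probs))
    return math.log(mixture)

# Two retrieved docs: the higher-scored doc dominates the mixture.
ll = grounded_log_likelihood([2.0, 0.0], [-1.0, -5.0])
```

Because the document posterior multiplies the generation likelihood, a document that helps generation pulls up the total likelihood in proportion to its retrieval probability, which is what couples the two models during training.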
If this repo is helpful to your research, please cite our paper:
```
@article{zhang2021joint,
  title={Joint Retrieval and Generation Training for Grounded Text Generation},
  author={Zhang, Yizhe and Sun, Siqi and Gao, Xiang and Fang, Yuwei and Brockett, Chris and Galley, Michel and Gao, Jianfeng and Dolan, Bill},
  journal={arXiv preprint arXiv:2105.06597},
  year={2021}
}
```
For CUDA 10.0, run

```bash
conda env create -f RetGen.yml
conda activate RetGen
conda install pytorch=1.4.0 torchvision cudatoolkit=10.0 -c pytorch
```

Then install apex (clone it somewhere outside this repo):

```bash
git clone https://github.com/NVIDIA/apex
cd apex
pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./
```
For CUDA 10.1, simply run

```bash
conda install pytorch=1.5.0 torchvision cudatoolkit=10.1 -c pytorch
```

instead of

```bash
conda install pytorch=1.4.0 torchvision cudatoolkit=10.0 -c pytorch
```
Next, install Fairseq (again, somewhere outside this repo):

```bash
git clone https://github.com/pytorch/fairseq
cd fairseq
pip install --editable ./
```
The container was built with

```bash
docker build -f dockerfile.txt -t gdpt .
```

and can be activated with

```bash
docker run --gpus all --ipc=host --rm -it --mount src=/your_source_code_dir,dst=/code,type=bind --mount src=/gdpt,dst=/gdpt,type=bind intersun/gdpt
```
The training corpus first needs to be compressed into a *.db file using the command below:

```bash
python dialogpt/prepro.py --corpus data/train.doc_ctx_rsp.txt --max_seq_len 512
```
Example training command for Reddit data with 8 GPUs (each with 32GB VRAM):

```bash
python -m torch.distributed.launch --nproc_per_node=8 joint_training.py \
    --model_name_or_path configs \
    --init_checkpoint models/reddit_generator.pkl \
    --train_input_file data/reddit_train.db \
    --eval_input_file data/reddit_test.txt \
    --output_dir output/joint_reddit \
    --file_suffix joint_reddit \
    --train_batch_size 4 \
    --gradient_accumulation_steps 2 \
    --eval_batch_size 2 \
    --num_optim_steps 16000 \
    --encoder_model_type ance_roberta \
    --pretrained_model_cfg bert-base-uncased \
    --model_file models/reddit_retriever.pkl \
    --ctx_file data/wiki.txt \
    --num_shards 8 \
    --batch_size 128 \
    --n_docs 4 \
    --encoding \
    --load_trained_model
```
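Note that the batching flags combine multiplicatively: the effective batch size per optimizer step is train_batch_size × gradient_accumulation_steps × number of GPUs, i.e. 4 × 2 × 8 = 64 for the command above. A quick sanity check (the helper name is ours, not the repo's):

```python
def effective_batch_size(per_gpu_batch, grad_accum_steps, n_gpus):
    # optimizer-step batch = per-GPU batch * accumulation steps * GPU count
    return per_gpu_batch * grad_accum_steps * n_gpus

# 8-GPU command above: 4 * 2 * 8 = 64 examples per optimizer step
print(effective_batch_size(4, 2, 8))
```

If you change the GPU count, adjusting train_batch_size or gradient_accumulation_steps to keep this product constant should keep the optimization comparable.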
For a single GPU, please use the following command:

```bash
CUDA_VISIBLE_DEVICES=0 python joint_training.py \
    --model_name_or_path configs \
    --init_checkpoint models/reddit_generator.pkl \
    --train_input_file data/reddit_train.db \
    --eval_input_file data/reddit_test.txt \
    --output_dir output/joint_reddit \
    --file_suffix joint_reddit \
    --train_batch_size 2 \
    --gradient_accumulation_steps 2 \
    --eval_batch_size 2 \
    --num_optim_steps 16000 \
    --encoder_model_type ance_roberta \
    --pretrained_model_cfg bert-base-uncased \
    --model_file models/reddit_retriever.pkl \
    --ctx_file data/wiki.txt \
    --num_shards 1 \
    --batch_size 128 \
    --n_docs 2 \
    --encoding \
    --load_trained_model
```
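The --num_shards flag (together with --shard_id in the evaluation command below) splits passage encoding across workers, which is why it drops from 8 to 1 when moving to a single GPU. The actual splitting logic lives in the DPR/ANCE code this repo builds on; the sketch below only illustrates the idea with a round-robin assignment (function name is ours):

```python
def shard_lines(lines, shard_id, num_shards):
    """Illustrative round-robin split: worker `shard_id` of `num_shards`
    encodes every num_shards-th passage, so the shards are disjoint
    and together cover the whole context file."""
    return [ln for i, ln in enumerate(lines) if i % num_shards == shard_id]

passages = list(range(10))
print(shard_lines(passages, 0, 8))  # shard 0 of 8 gets items 0 and 8
```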
We release model checkpoints that can be used directly or further fine-tuned on customized datasets.
Model | Reddit | arXiv
---|---|---
RetGen generator | [link] | [link]
RetGen retriever | [link] | [link]
For the generator, you can find the corresponding configuration files (`merges.txt`, `config.json`, `vocab.json`) in `./configs/*`.
The preprocessed Wikipedia dump used in our work can be downloaded below. We also provide the raw text data for arXiv. For Reddit, due to copyright issues, we will instead release a script that automatically extracts the training, development and test data.
Data | Link
---|---
Wiki (2.5GB) | [link]
arXiv (2.4GB) | [link]
Reddit | TBD
We note that even with a properly filtered Reddit dataset, our model can still sometimes generate moderately toxic or inappropriate responses. For this reason, we are unable to provide the inference code at this time. We are currently working on a controlled decoding method to prevent the system from generating toxic responses. Please stay tuned.
Please follow the DialoGPT evaluation script in dialogpt/README.md.
We provide an example for evaluating the retriever. Evaluation estimates recall@10, recall@30 and recall@50 over 10K samples. The required files can be downloaded here.
```bash
CUDA_VISIBLE_DEVICES=0 python eval_checkpoint.py \
    --eval_mode rank \
    --encoder_model_type ance_roberta \
    --pretrained_model_cfg bert-base-uncased \
    --model_file models/reddit_retriever.pkl \
    --qa_file data/2k_positive.txt \
    --ctx_file data/10k.txt \
    --n_docs 50 \
    --batch_size 64 \
    --shard_id 0 \
    --num_shards 1 \
    --load_trained_model \
    --encoding
```
If running correctly, the system should output results like the following:

```
Validation results: recall@10:0.617
Validation results: recall@30:0.703
Validation results: recall@50:0.744
```
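Recall@k here measures the fraction of queries whose gold passage appears among the top-k retrieved passages. A minimal sketch of the metric (function names are ours, not the repo's):

```python
def recall_at_k(ranked_ids, positive_id, k):
    # 1 if the gold passage id appears in the top-k retrieved ids, else 0
    return int(positive_id in ranked_ids[:k])

def mean_recall_at_k(all_ranked, all_positives, k):
    # average the per-query hit indicator over the evaluation set
    hits = [recall_at_k(r, p, k) for r, p in zip(all_ranked, all_positives)]
    return sum(hits) / len(hits)

# One of two queries finds its gold passage within the top 2.
print(mean_recall_at_k([[1, 2, 3], [4, 5, 6]], [2, 7], 2))  # 0.5
```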
We further conducted human evaluations (500 examples for each method, with each example evaluated by 3 human judges). The results show strong evidence that our generation quality is better than that of vanilla DialoGPT/GPT-2 under this non-interactive Turing test:
Coherence: A and B, which is more relevant to, and coherent with, the context?

Dataset | System A | A Wins (%) | Ties (%) | B Wins (%) | System B
---|---|---|---|---|---
Reddit | RetGen 345M | 43.7 | 28.3 | 28.0 | DialoGPT 345M
arXiv | RetGen 345M | 32.1 | 41.7 | 26.3 | GPT-2 345M
Informativeness: A and B, which is more informative (usually with more specific content)?

Dataset | System A | A Wins (%) | Ties (%) | B Wins (%) | System B
---|---|---|---|---|---
Reddit | RetGen 345M | 44.5 | 27.8 | 27.7 | DialoGPT 345M
arXiv | RetGen 345M | 36.3 | 37.2 | 26.5 | GPT-2 345M
Human-likeness: A and B, which is more likely to have been generated by a human rather than a machine?

Dataset | System A | A Wins (%) | Ties (%) | B Wins (%) | System B
---|---|---|---|---|---
Reddit | RetGen 345M | 36.4 | 34.0 | 29.6 | DialoGPT 345M
arXiv | RetGen 345M | 29.7 | 43.6 | 26.7 | GPT-2 345M
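Whether an "A Wins" vs. "B Wins" margin like those above is larger than chance can be checked with a simple two-sided sign test over the non-tied judgments (this is an illustrative check, not necessarily the significance test used in the paper):

```python
from math import comb

def sign_test_p(a_wins, b_wins):
    """Two-sided exact sign test: under the null hypothesis, each
    non-tied judgment is equally likely to favor system A or B."""
    n = a_wins + b_wins          # number of non-tied judgments
    k = max(a_wins, b_wins)      # size of the larger side
    # probability of a split at least this lopsided, one tail
    tail = sum(comb(n, i) for i in range(k, n + 1)) / 2 ** n
    return min(1.0, 2 * tail)    # two-sided p-value, capped at 1
```

With the win percentages and 1,500 judgments per comparison (500 examples × 3 judges), the observed margins would convert to win counts that can be fed straight into this function.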
Please contact DialoGPT@microsoft.com if you have any questions or suggestions. Note, however, that responses may be sporadic; please expect delays.
This project welcomes contributions and suggestions. Most contributions require you to agree to a Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us the rights to use your contribution.
This repository aims to facilitate research in large-scale pretraining for conversational data. This toolkit contains only part of the modeling machinery needed to actually produce a running dialog model. On its own, the released model provides only pretrained weights; to actually use it, researchers will need to bring conversational data of their own and decode responses from the pretrained system. We are not responsible for any generation produced through third-party utilization of the pretrained system.