Official implementation of GeDi: Generative Discriminator Guided Sequence Generation
Blogpost here
Sept 29, 2020: Adding support for GeDi-guided GPT-3 generation (API key needed)
GeDi is a method of using class-conditional language models (which we refer to as generative discriminators, or GeDis) to guide generation from other, potentially much larger, language models. This has several advantages over finetuning large language models directly.
GeDi is a form of discriminator-guided generation. A discriminator that can classify an attribute can guide language model generation towards that attribute by classifying the sequences that result from candidate next tokens. With a normal discriminator (such as BERT), this would be very computationally expensive during generation, since every candidate next token would have to be fed to the discriminator one by one. Generative discriminators, by contrast, can classify candidate next tokens very efficiently during generation using Bayes' rule (see Section 3.1 of the paper). As an added bonus, generative discriminators can be used as zero-shot classifiers, and can therefore guide generation towards unseen topics.
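The Bayes' rule step can be illustrated with a small sketch. This is not the repository's implementation: the function names are hypothetical, a uniform class prior is assumed, and any length normalization or learned scaling the paper applies in Section 3.1 is omitted. The point it shows is that one forward pass per class yields next-token scores for the whole vocabulary, so every candidate token is classified at once.

```python
import math

def log_softmax(values):
    """Numerically stable log-softmax over a list of floats."""
    m = max(values)
    log_z = m + math.log(sum(math.exp(v - m) for v in values))
    return [v - log_z for v in values]

def next_token_posteriors(prefix_logp, next_token_logits, log_prior=None):
    """Return P(c | prefix, x_t = v) for every class c and candidate token v.

    prefix_logp[c]          -- log P(prefix | c) under the class-conditional LM
    next_token_logits[c][v] -- that LM's logit for candidate v given class c
    log_prior[c]            -- log P(c); assumed uniform when omitted
    """
    num_classes = len(prefix_logp)
    if log_prior is None:
        log_prior = [0.0] * num_classes  # a uniform prior cancels on normalization
    next_logp = [log_softmax(row) for row in next_token_logits]
    vocab_size = len(next_logp[0])
    posteriors = [[0.0] * vocab_size for _ in range(num_classes)]
    for v in range(vocab_size):
        # Bayes' rule: P(c | x) is proportional to P(c) * P(prefix | c) * P(v | prefix, c)
        joint = [log_prior[c] + prefix_logp[c] + next_logp[c][v]
                 for c in range(num_classes)]
        normalized = log_softmax(joint)  # normalize over classes
        for c in range(num_classes):
            posteriors[c][v] = math.exp(normalized[c])
    return posteriors

# Toy example: 2 classes, vocabulary of 4 tokens, equally likely prefix.
toy_logits = [[2.0, 0.0, 0.0, 0.0],   # class 0 prefers token 0
              [0.0, 0.0, 0.0, 2.0]]   # class 1 prefers token 3
toy_posteriors = next_token_posteriors([0.0, 0.0], toy_logits)
```

In the toy example, token 0 receives a class-0 posterior above 0.5 and token 3 a posterior below 0.5, while the tokens both classes treat identically stay at 0.5.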
Run scripts/setup.sh:
cd scripts
bash setup.sh
This will install the following:
First download the models:
cd scripts
bash get_models.sh
This downloads and saves the topic, sentiment, and detoxifier models in the folder ../pretrained_models
To generate, use bash run_generation.sh, which calls ../generate_GeDi.py with the appropriate arguments (set for topic generation by default).
Important arguments include:
--mode: can be set to topic, sentiment, or detoxify
--gen_type: can be set to gedi for GeDi guided generation, cclm for class conditional generation, or gpt2 to generate from raw GPT-2
--gen_length: max length of generation
--gedi_model_name_or_path: path to GeDi model. If unused, will assume you ran bash get_models.sh and infer model directory from --mode argument
--filter_p: equal to 1 - \rho in Equation 7 of the paper
--target_p: equal to \tau from the paper
--disc_weight: exponent for posterior weighting (\omega in Equation 6 of the paper)
--fp16: converts GPT2-XL weights to fp16 for faster generation and less GPU memory usage

Running will allow you to enter control codes and prompts for generation in a continuous loop until you exit.
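To make the roles of --disc_weight, --filter_p, and --target_p more concrete, here is a rough sketch of posterior-weighted decoding with filtering. It is one reading of the descriptions above and of Equations 6 and 7, not the code in generate_GeDi.py: the function names are hypothetical, the exact filtering rule is an assumption, and the numeric values are made up for illustration.

```python
def weighted_distribution(lm_probs, class_posteriors, disc_weight):
    """Reweight base LM probabilities by the class posterior.

    p_w(x_t) is proportional to p_lm(x_t) * P(c | x_t, prefix) ** disc_weight,
    where disc_weight plays the role of omega.
    """
    unnormalized = [p * (q ** disc_weight)
                    for p, q in zip(lm_probs, class_posteriors)]
    total = sum(unnormalized)
    return [u / total for u in unnormalized]

def filter_tokens(weighted_probs, class_posteriors, filter_p, target_p):
    """Zero out tokens unlikely to carry the desired attribute (assumed rule).

    Tokens are ranked by class posterior and kept until their weighted
    probability mass reaches rho = 1 - filter_p. Tokens whose posterior
    exceeds target_p (tau) are kept as well.
    """
    rho = 1.0 - filter_p
    order = sorted(range(len(weighted_probs)),
                   key=lambda i: class_posteriors[i], reverse=True)
    keep = set()
    mass = 0.0
    for i in order:
        keep.add(i)
        mass += weighted_probs[i]
        if mass >= rho:
            break
    keep.update(i for i, q in enumerate(class_posteriors) if q > target_p)
    kept = [p if i in keep else 0.0 for i, p in enumerate(weighted_probs)]
    total = sum(kept)
    return [k / total for k in kept]

# Illustrative values only; these are not the script's defaults.
lm_probs = [0.5, 0.3, 0.2]           # base LM next-token probabilities
class_posteriors = [0.1, 0.9, 0.6]   # P(desired class | token, prefix)
weighted = weighted_distribution(lm_probs, class_posteriors, disc_weight=2.0)
filtered = filter_tokens(weighted, class_posteriors, filter_p=0.2, target_p=0.95)
```

With these toy numbers, the token the base model prefers but the discriminator rejects (index 0) is filtered out entirely, and the remaining mass is renormalized over the other two tokens.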
Set --mode topic in scripts/run_generation.sh. The model was trained on world, sports, business, and science, but can often generate other topics zero-shot, for instance space, fire, climate, education.
Set --mode sentiment in scripts/run_generation.sh. The control code n selects negative sentiment. Note that the negative model can be very negative, and this sometimes results in toxic or offensive samples.
Set --mode detoxify in scripts/run_generation.sh.
Set --gen_type gpt2 to generate from GPT-2, and --gen_type cclm to generate directly from the GeDi as a class-conditional language model. --gen_type cclm corresponds to all experiments in Section 5 of the paper, and the CC-LM baselines in Section 6.1.
To control sentiment from GPT-3 using your API key (should have prefix "sk"):
pip install openai
python ../generate_GeDi.py --penalize_cond --gen_length 100 --mode sentiment --gpt3_api_key sk-xxxxxxxx
You can also try changing the --mode or other arguments. To generate directly from GPT-3 without GeDi using our same greedy decoding scheme:
python ../generate_GeDi.py --penalize_cond --gen_length 100 --mode sentiment --gen_type gpt2 --gpt3_api_key sk-xxxxxxx
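If you would rather not paste the key into your shell history, a small wrapper can assemble the same command from an environment variable. This is only a convenience sketch: the helper name and the OPENAI_API_KEY variable are assumptions, and it uses only the flags shown in the commands above.

```python
import os
import shlex

def build_gpt3_command(api_key, mode="sentiment", gen_length=100, use_gedi=True):
    """Assemble the generate_GeDi.py command line for GPT-3 generation.

    Hypothetical helper; it only combines flags shown in this README.
    """
    if not api_key.startswith("sk"):
        raise ValueError("expected an API key with the 'sk' prefix")
    cmd = ["python", "../generate_GeDi.py", "--penalize_cond",
           "--gen_length", str(gen_length), "--mode", mode]
    if not use_gedi:
        # Generate directly from GPT-3 without GeDi guidance.
        cmd += ["--gen_type", "gpt2"]
    cmd += ["--gpt3_api_key", api_key]
    return cmd

if __name__ == "__main__":
    # OPENAI_API_KEY is an assumed variable name; the placeholder is not a real key.
    key = os.environ.get("OPENAI_API_KEY", "sk-xxxxxxxx")
    print(shlex.join(build_gpt3_command(key)))
```

The returned list can be passed to subprocess.run from within the scripts directory, or printed and copied as shown.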
cd scripts
bash get_data.sh
bash run_training.sh
This calls ../train_GeDi.py with the appropriate arguments.
The trained model is saved to the directory given by the output_dir argument.
To generate from it, use ../generate_GeDi.py (called from bash run_generation.sh) with --gedi_model_name_or_path set to the directory of your trained model.
@article{KrauseGeDi2020,
title={{GeDi: Generative Discriminator Guided Sequence Generation}},
author={Krause, Ben and Gotmare, Akhilesh Deepak and McCann, Bryan and Keskar, Nitish Shirish and Joty, Shafiq and Socher, Richard and Rajani, Nazneen Fatema},
journal={arXiv preprint arXiv:2009.06367},
year={2020}
}
The code is released under the BSD-3 License (see LICENSE.txt for details), but we also ask that users respect the following:
This software should not be used to promote or profit from violence, hate, and division, environmental destruction, abuse of human rights, or the destruction of people's physical and mental health.