salesforce / GeDi

GeDi: Generative Discriminator Guided Sequence Generation
https://arxiv.org/abs/2009.06367
BSD 3-Clause "New" or "Revised" License


Official implementation of GeDi: Generative Discriminator Guided Sequence Generation

Blogpost here

Colab Notebook on controlling topic using GeDi here


Updates

Sept 29, 2020: Added support for GeDi-guided GPT-3 generation (API key needed)

Introduction

GeDi is a method of using class-conditional language models, which we refer to as generative discriminators (GeDis), to guide generation from other (potentially much larger) language models. This has several advantages over finetuning large language models directly, including:

GeDi is a form of discriminator-guided generation. A discriminator that can classify an attribute could be used to guide language model generation towards that attribute by classifying the sequences that result from candidate next tokens. Using a normal discriminator (such as BERT) to do this would be very computationally expensive during generation, since it would require feeding every candidate next token one by one into the discriminator to be classified. With generative discriminators, in contrast, we can classify all candidate next tokens very efficiently during generation using Bayes' rule (see Section 3.1 of the paper). As an added bonus, generative discriminators can be used as zero-shot classifiers, and can therefore be used to guide generation towards unseen topics.
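The Bayes' rule step can be illustrated with a small, self-contained sketch. It is not the repository's implementation: the function names (`class_log_posteriors`, `guided_next_token_probs`), the uniform class prior, the weighting exponent `omega`, and the toy probabilities are all assumptions made for illustration, and the additional heuristics described in the paper are left out. The idea shown is that one forward pass of the class-conditional LM per class yields next-token log-probabilities for the whole vocabulary, so every candidate token can be classified at once rather than one by one.

```python
import math

def log_softmax(xs):
    """Numerically stable log-softmax over a list of floats."""
    m = max(xs)
    lse = m + math.log(sum(math.exp(x - m) for x in xs))
    return [x - lse for x in xs]

def class_log_posteriors(prefix_logp, next_token_logp, log_prior=None):
    """log P(class | prefix + token) for every candidate token, via Bayes' rule.

    prefix_logp[c]        : log P(prefix | class c) under the class-conditional LM
    next_token_logp[c][v] : log P(token v | prefix, class c)
    log_prior[c]          : log P(class c); uniform if omitted (an assumption)
    """
    n_classes = len(prefix_logp)
    vocab_size = len(next_token_logp[0])
    if log_prior is None:
        log_prior = [-math.log(n_classes)] * n_classes
    out = []
    for v in range(vocab_size):
        # Joint log-probability of (class, prefix + token v), then normalize over classes.
        joint = [log_prior[c] + prefix_logp[c] + next_token_logp[c][v]
                 for c in range(n_classes)]
        out.append(log_softmax(joint))
    return out

def guided_next_token_probs(base_logp, log_post, desired_class, omega=1.0):
    """Reweight the base LM's next-token distribution by the class posterior.

    Scores are base_logp[v] + omega * log P(desired_class | prefix + token v),
    renormalized over the vocabulary.
    """
    scores = [base_logp[v] + omega * log_post[v][desired_class]
              for v in range(len(base_logp))]
    return [math.exp(s) for s in log_softmax(scores)]

# Toy example: 2 classes, 3-token vocabulary (all numbers are made up).
prefix_logp = [0.0, 0.0]  # both classes explain the prefix equally well
next_token_logp = [
    [math.log(0.7), math.log(0.2), math.log(0.1)],  # class 0
    [math.log(0.1), math.log(0.2), math.log(0.7)],  # class 1
]
base_logp = [math.log(1.0 / 3.0)] * 3  # uniform base LM

log_post = class_log_posteriors(prefix_logp, next_token_logp)
guided = guided_next_token_probs(base_logp, log_post, desired_class=0)
print([round(p, 4) for p in guided])
```

In this toy setup, token 0 is the one the class-0 model prefers, so steering towards class 0 shifts probability mass onto it even though the base LM is uniform. Raising `omega` would make the steering stronger.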

Dependencies

Generating from models used in paper

Important arguments include:

Running will allow you to enter control codes and prompts for generation in a continuous loop until you exit.

Topic generation (Sections 6.3 and 6.4 of the paper)

Sentiment control (Section 6.1 of the paper)

Detoxification (Section 6.2 of the paper)

Class-conditional LM and GPT-2 generation

GPT-3 generation (added after paper, API access needed)

To control sentiment from GPT-3 using your API key (should have prefix "sk"):

pip install openai

python ../generate_GeDi.py --penalize_cond --gen_length 100 --mode sentiment --gpt3_api_key sk-xxxxxxxx

You can also try changing --mode or the other arguments. To generate directly from GPT-3 without GeDi, using the same greedy decoding scheme:

python ../generate_GeDi.py --penalize_cond --gen_length 100 --mode sentiment --gen_type gpt2 --gpt3_api_key sk-xxxxxxx

Train your own GeDi

cd scripts
bash get_data.sh

Then run bash run_training.sh, which calls ../train_GeDi.py with the appropriate arguments.

Citation

@article{KrauseGeDi2020,
  title={{GeDi: Generative Discriminator Guided Sequence Generation}},
  author={Krause, Ben and Gotmare, Akhilesh Deepak and McCann, Bryan and Keskar, Nitish Shirish and Joty, Shafiq and Socher, Richard and Rajani, Nazneen Fatema},
  journal={arXiv preprint arXiv:2009.06367},
  year={2020}
}

License

The code is released under the BSD-3 License (see LICENSE.txt for details), but we also ask that users respect the following:

This software should not be used to promote or profit from violence, hate, and division, environmental destruction, abuse of human rights, or the destruction of people's physical and mental health.