A simple Python package that wraps existing model fine-tuning and generation scripts for OpenAI's GPT-2 text generation model (specifically the "small" 124M and "medium" 355M hyperparameter versions). It also makes text generation easier: you can generate to a file for easy curation and supply a prefix to force the text to start with a given phrase.
This package incorporates and makes minimal low-level changes to:

* Model management from OpenAI's official GPT-2 repo (MIT License)
* Model finetuning from Neil Shepperd's fork of GPT-2 (MIT License)
* Text generation output management from textgenrnn (MIT License)
For finetuning, using a GPU is strongly recommended, although you can generate using a CPU (albeit much more slowly). If you are training in the cloud, a Colaboratory notebook or a Google Compute Engine VM with the TensorFlow Deep Learning image is strongly recommended, as the GPT-2 model is hosted on GCP.
You can use gpt-2-simple to retrain a model using a GPU for free in this Colaboratory notebook, which also demos additional features of the package.
Note: Development on gpt-2-simple has mostly been superseded by aitextgen, which has similar AI text generation capabilities with more efficient training time and resource usage. If you do not require TensorFlow, I recommend using aitextgen instead. Checkpoints trained using gpt-2-simple can be loaded using aitextgen as well.
gpt-2-simple can be installed via PyPI:

```shell
pip3 install gpt-2-simple
```

You will also need to install the corresponding TensorFlow 2.X version (minimum 2.5.1) for your system (e.g. `tensorflow` or `tensorflow-gpu`).
An example of downloading the model to the local system, finetuning it on a dataset, and generating some text:

Warning: the pretrained 124M model, and thus any finetuned model, is 500 MB! (The pretrained 355M model is 1.5 GB.)
```python
import gpt_2_simple as gpt2
import os
import requests

model_name = "124M"
if not os.path.isdir(os.path.join("models", model_name)):
    print(f"Downloading {model_name} model...")
    gpt2.download_gpt2(model_name=model_name)   # model is saved into current directory under /models/124M/

file_name = "shakespeare.txt"
if not os.path.isfile(file_name):
    url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
    data = requests.get(url)
    with open(file_name, 'w') as f:
        f.write(data.text)

sess = gpt2.start_tf_sess()
gpt2.finetune(sess,
              file_name,
              model_name=model_name,
              steps=1000)   # steps is max number of training steps

gpt2.generate(sess)
```
The generated model checkpoints are by default in `/checkpoint/run1`. If you want to load a model from that folder and generate text from it:
```python
import gpt_2_simple as gpt2

sess = gpt2.start_tf_sess()
gpt2.load_gpt2(sess)
gpt2.generate(sess)
```
As with textgenrnn, you can generate and save text for later use (e.g. an API or a bot) by using the `return_as_list` parameter.
```python
single_text = gpt2.generate(sess, return_as_list=True)[0]
print(single_text)
```
You can pass a `run_name` parameter to `finetune` and `load_gpt2` if you want to store/load multiple models in a `checkpoint` folder.
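Each run's checkpoints live side by side under the `checkpoint` folder (`checkpoint/run1`, `checkpoint/run2`, ...). As a small illustrative helper (my own, not part of the package's API), you can list which runs are available to load:

```python
import os

def list_runs(checkpoint_dir="checkpoint"):
    """Return the run_name subfolders that load_gpt2 could be pointed at."""
    if not os.path.isdir(checkpoint_dir):
        return []
    return sorted(
        d for d in os.listdir(checkpoint_dir)
        if os.path.isdir(os.path.join(checkpoint_dir, d))
    )
```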
There is also a command-line interface for both finetuning and generation, with strong defaults for running on a Cloud VM with a GPU. For finetuning (which will also download the model if not present):
```shell
gpt_2_simple finetune shakespeare.txt
```
And for generation, which generates texts to files in a `gen` folder:

```shell
gpt_2_simple generate
```
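Conceptually, file-based output just joins the generated samples with a delimiter and writes them into the `gen` folder for easy curation. A rough sketch under that assumption (the function name, file naming, and delimiter here are illustrative, not the package's actual scheme):

```python
import os

DELIMITER = "\n" + "=" * 20 + "\n"

def save_texts(texts, nfiles=1, folder="gen"):
    """Split generated texts across nfiles delimiter-separated files
    inside the output folder."""
    os.makedirs(folder, exist_ok=True)
    per_file = max(1, len(texts) // nfiles)
    paths = []
    for i in range(nfiles):
        chunk = texts[i * per_file:(i + 1) * per_file]
        path = os.path.join(folder, f"gpt2_gentext_{i}.txt")  # hypothetical naming
        with open(path, "w") as f:
            f.write(DELIMITER.join(chunk))
        paths.append(path)
    return paths
```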
Most of the same parameters available in the functions are available as CLI arguments, e.g.:
```shell
gpt_2_simple generate --temperature 1.0 --nsamples 20 --batch_size 20 --length 50 --prefix "<|startoftext|>" --truncate "<|endoftext|>" --include_prefix False --nfiles 5
```
See below for what some of the CLI arguments do.
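As an illustration of what the `prefix`/`truncate`/`include_prefix` arguments accomplish together, the post-processing is conceptually just string slicing (this sketch is mine, not the package's internal code):

```python
def postprocess(text, prefix="<|startoftext|>", truncate="<|endoftext|>",
                include_prefix=False):
    """Cut a generated sample at the end token, and optionally drop the
    start token -- mirroring the truncate/include_prefix flags above."""
    end = text.find(truncate)
    if end != -1:
        text = text[:end]              # stop at the end token
    if not include_prefix and text.startswith(prefix):
        text = text[len(prefix):]      # discard the unwanted prefix token
    return text

print(postprocess("<|startoftext|>To be, or not to be.<|endoftext|>extra"))
# prints "To be, or not to be."
```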
NB: Restart the Python session first if you want to finetune on another dataset or load another model.
The method GPT-2 uses to generate text is slightly different from that of other packages like textgenrnn (specifically, generating the full text sequence purely in the GPU and decoding it later), which cannot easily be fixed without hacking the underlying model code. As a result:

* GPT-2 cannot stop early upon reaching a specific end token. (Workaround: pass the `truncate` parameter to a `generate` function to only collect text until a specified end token. You may want to reduce `length` appropriately.)
* When generating from a model finetuned on documents bounded by bespoke start/end tokens, pass a `prefix` targeting the beginning token sequences, and a `truncate` targeting the end token sequence. You can also set `include_prefix=False` to discard the prefix token while generating (e.g. if it's something unwanted like `<|startoftext|>`).
* If you pass a single-column `.csv` file to `finetune()`, it will automatically parse the CSV into a format ideal for training with GPT-2 (including prepending `<|startoftext|>` and suffixing `<|endoftext|>` to every text document, so the `truncate` tricks above are helpful when generating output). This is necessary to handle both quotes and newlines in each text document correctly.
* GPT-2 allows you to generate texts in parallel by setting a `batch_size` that is divisible into `nsamples`, resulting in much faster generation. This works very well with a GPU (you can set `batch_size` up to 20 on Colaboratory's K80)!
* Finetuning the 355M "medium" model currently requires `batch_size=1`, and uses about 88% of the V100 GPU.
* You can pass `overwrite=True` to `finetune`, which will continue training and remove the previous iteration of the model without creating a duplicate copy. This can be especially useful for transfer learning (e.g. heavily finetune GPT-2 on one dataset, then finetune on another dataset to get a "merging" of both datasets).
* If your input dataset is large, you can pre-encode and compress it with `gpt2.encode_dataset(file_path)`. The output is a compressed `.npz` file which will load much faster into the GPU for finetuning.
* The 774M "large" model may not support finetuning on typical GPUs (it can run out of memory), but you can still generate from the default pretrained model using `gpt2.load_gpt2(sess, model_name='774M')` and `gpt2.generate(sess, model_name='774M')`.
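The CSV handling described above can be pictured as the following conversion, a sketch of the target format rather than the package's own code (Python's standard `csv` module takes care of quoted fields):

```python
import csv
import io

def csv_to_gpt2_dataset(csv_text):
    """Wrap each single-column CSV row in start/end tokens; csv.reader
    correctly handles quoted fields containing commas and newlines."""
    rows = csv.reader(io.StringIO(csv_text))
    return "\n".join(
        "<|startoftext|>" + row[0] + "<|endoftext|>" for row in rows if row
    )
```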
Maintainer: Max Woolf (@minimaxir)
Max's open-source projects are supported by his Patreon. If you found this project helpful, any monetary contributions to the Patreon are appreciated and will be put to good creative use.
License: MIT
This repo has no affiliation or relationship with OpenAI.