Codebase to train a CLIP conditioned Text to Image Diffusion model on Colab in Keras. See below for notebooks and examples with prompts.
Images generated for the prompt: A small village in the Alps, spring, sunset
Images generated for the prompt: Portrait of a young woman with curly red hair, photograph
(more examples below - try with your own inputs in Colab here: )
The goal of this repo is to provide a simple, self-contained codebase for Text to Image Diffusion that can be trained in Colab in a reasonable amount of time.
While there are a lot of great resources around the math and usage of diffusion models, I haven't found many specifically focused on training text-to-img diffusion models. In particular, the idea of training a DALL-E 2 or Stable Diffusion-like model feels like a daunting task requiring immense computational resources and data. It turns out you can accomplish quite a lot with few resources and without having a PhD in thermodynamics! The easiest way to get acquainted with the code is through the notebooks below.
The denoising U-NET in denoiser.py is based on this. I have added additional text/CLIP/masking embeddings/inputs and cross/self-attention.

If you are just starting out I recommend trying the next two notebooks in order. The first should be able to get you recognizable images on the Fashion MNIST dataset within minutes!
- Train Class Conditional Fashion MNIST/CIFAR - choose the dataset via file_name. You can get reasonable results after 25 epochs for CIFAR 10 and 40 epochs for CIFAR 100; training for 50-100 epochs is even better.
- Train CLIP Conditioned Text to Img Model on 115k 64x64 images + prompts sampled from the Laion Aesthetics 6.5+ dataset.
- Test Prompts on a model trained for about 60 epochs (~60 hours on 1 V100) on the entire 600k Laion Aesthetics 6.5+ dataset.
The model architecture, training parameters, and generation parameters are specified in a YAML file; see here for examples. If unsure, you can use the base_model config. The get_train_data function is built to work with various known datasets; if you have your own dataset you can just use that instead. train_label_embeddings is expected to be a matrix of embeddings the model conditions on (usually some embedding of the text, but it could be anything).
```python
config_path = "guided-diffusion-keras/guided_diffusion/configs/base_model.yaml"

trainer = Trainer(config_path)
print(trainer.__dict__)

train_data, train_label_embeddings = get_train_data(trainer.file_name)  # OR get your own images and label embeddings in matrix form.

trainer.preprocess_data(train_data, train_label_embeddings)
trainer.initialize_model()
trainer.data_checks(train_data)
trainer.train()
```
The setup is fairly simple.

We train a denoising U-NET neural net that takes the following three inputs:

- noise_level (sampled from 0 to 1, with more values concentrated close to 0)
- x_noisy, the corrupted image: for a noise_level between 0 and 1 the corruption is x_noisy = x*(1 - noise_level) + eps*noise_level, where eps ~ np.random.normal
- the label embedding the model conditions on (e.g. a CLIP embedding of the prompt)
The output is a prediction of the denoised image - call it f(x_noisy). The model is trained to minimize the absolute error |f(x_noisy) - x| between the prediction and the actual image (you can also use squared error here). Note that I don't reparametrize the loss in terms of the noise here, to keep things simple.
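For concreteness, here is a minimal sketch of this corruption and loss in NumPy (written for this README; the actual training code lives in the repo and may differ in details such as how noise_level is sampled):

```python
import numpy as np

def corrupt(x, noise_level, rng=np.random.default_rng()):
    """Linearly interpolate between the clean image batch x and Gaussian noise.
    noise_level has shape (batch, 1, 1, 1) so it broadcasts over height/width/channels."""
    eps = rng.normal(size=x.shape)
    return x * (1 - noise_level) + eps * noise_level

def l1_loss(x, x_denoised_pred):
    """Mean absolute error |f(x_noisy) - x| between the prediction and the clean image."""
    return np.mean(np.abs(x_denoised_pred - x))
```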
Using this model we then iteratively generate an image from random noise as follows:
```python
for i in range(len(self.noise_levels) - 1):

    curr_noise, next_noise = self.noise_levels[i], self.noise_levels[i + 1]

    # predict the original denoised image:
    x0_pred = self.predict_x_zero(new_img, label, curr_noise)

    # the new image at the next_noise level is a weighted average of the old image and the predicted x0:
    new_img = ((curr_noise - next_noise) * x0_pred + next_noise * new_img) / curr_noise
```
The predict_x_zero method uses classifier-free guidance by combining the conditional and unconditional predictions:

x0_pred = class_guidance * x0_pred_conditional + (1 - class_guidance) * x0_pred_unconditional
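A sketch of what such a method could look like (the argument names, input ordering, and the use of an all-zero "null" embedding for the unconditional branch are assumptions for illustration; see the repo for the real implementation):

```python
import numpy as np

def predict_x_zero_cfg(model, noisy_img, label_emb, noise_level, class_guidance=3.0):
    """Classifier-free guidance: blend conditional and unconditional x0 predictions.
    A guidance value above 1 strengthens the effect of the conditioning."""
    noise = np.full((len(noisy_img), 1), noise_level)        # per-sample noise level
    null_emb = np.zeros_like(label_emb)                      # masked / empty embedding
    x0_cond = model.predict([noisy_img, noise, label_emb], verbose=0)
    x0_uncond = model.predict([noisy_img, noise, null_emb], verbose=0)
    return class_guidance * x0_cond + (1 - class_guidance) * x0_uncond
```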
A bit of math: the approach above falls within the VDM parametrization (see Section 3.1 in Kingma et al.):

$$ z_t = \alpha_t x + \sigma_t \epsilon, \qquad \epsilon \sim \mathcal{N}(0, 1) $$

where $z_t$ is the noisy version of $x$ at time $t$.
Generally $\alpha_t$ is chosen to be $\sqrt{1-\sigma_t^2}$ so that the process is variance preserving. Here I chose $\alpha_t=1-\sigma_t$ so that we linearly interpolate between the image and random noise. Why? Honestly, I just wondered if it was going to work :) It also simplifies the updating equation quite a bit, and it's easier to understand what the noise-to-signal ratio will look like. The updating equation above is the DDIM sampler for this parametrization, which simplifies to a simple weighted average. Note that DDIM deterministically maps random normal noise to images - this has two benefits: we can interpolate in the random normal latent space, and it generally takes fewer steps to get decent image quality.
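To see why the update reduces to a weighted average, plug $\alpha_t = 1 - \sigma_t$ into the deterministic DDIM step (a quick check in the notation above, with $\hat{x}_0$ the predicted denoised image and $s$ the next, lower noise level):

$$
\begin{aligned}
z_s &= \alpha_s \hat{x}_0 + \sigma_s \hat{\epsilon}, \qquad \hat{\epsilon} = \frac{z_t - \alpha_t \hat{x}_0}{\sigma_t} \\
&= (1 - \sigma_s)\hat{x}_0 + \frac{\sigma_s}{\sigma_t}\bigl(z_t - (1 - \sigma_t)\hat{x}_0\bigr) \\
&= \frac{(\sigma_t - \sigma_s)\,\hat{x}_0 + \sigma_s\, z_t}{\sigma_t},
\end{aligned}
$$

which is exactly the new_img = ((curr_noise - next_noise) * x0_pred + next_noise * new_img) / curr_noise line in the sampling loop, with $\sigma_t$ = curr_noise and $\sigma_s$ = next_noise.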
Note that I make a lot of unorthodox choices in the modelling. Since I am fairly new to generative models, I found this to be a great way to learn what is crucial vs. what is merely nice to have. I generally did not see any divergences in training, which supports the notion that diffusion models are stable to train and fairly robust to model choices. The flip side is that if you introduce subtle bugs in your code (of which I am sure there are many in this repo) they are pretty hard to spot.
Architecture: TODO - add cross-attention description.
The text-to-img models use the Laion 6.5+ datasets. You can see some samples here. As you can see this dataset is very biased towards landscapes and portraits. Accordingly, the model does best at prompts related to art/landscapes/paintings/portraits/architecture.
The script img_load_utils.py contains some code that uses the img2dataset package to download and store images, texts, and their corresponding embeddings. The Laion datasets are still quite messy, with a lot of duplicates, bad descriptions, etc. This can be used to quickly prototype new generative models. This dataset is also used in the notebook above.
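For reference, a minimal img2dataset call looks roughly like the following (an illustrative sketch, not the exact invocation in img_load_utils.py; the metadata file name and column names are assumptions):

```python
from img2dataset import download

# Download and resize the images listed in a Laion metadata parquet file.
download(
    url_list="laion_aesthetics_6.5.parquet",  # assumed metadata file
    input_format="parquet",
    url_col="URL",
    caption_col="TEXT",
    image_size=64,
    output_format="files",
    output_folder="laion_64px",
    processes_count=4,
    thread_count=32,
)
```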
TODO: add more info and script on how to preprocess the data and link to huggingface repo. Talk about data quality issues.
If you want to train the text-to-img model I highly recommend getting at least Colab Pro or even Colab Pro+ - it's going to be hard to train the model on a K80 GPU, unfortunately. NOTE: Colab will change its setup and introduce credits at the end of September - I will update this.
Setting up this training workflow on Google Colab wasn't too bad. My approach has been very low-tech, and Google Drive played a large role: at the end of every epoch I save the model and the generated images on a small validation set (50-100 images) to Drive. This has a few advantages.

I have slowly moved some data/models to Hugging Face, but this is still a work in progress.
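As a rough sketch, the per-epoch saving can be done with a small Keras callback like the one below (the Drive path and the callback itself are illustrative assumptions, not the repo's actual training loop):

```python
import os
import tensorflow as tf

class DriveCheckpoint(tf.keras.callbacks.Callback):
    """Save the model to a Google Drive folder at the end of every epoch
    (generated validation images can be written to the same folder)."""

    def __init__(self, drive_dir="/content/drive/MyDrive/diffusion_checkpoints"):
        super().__init__()
        self.drive_dir = drive_dir
        os.makedirs(drive_dir, exist_ok=True)

    def on_epoch_end(self, epoch, logs=None):
        self.model.save(os.path.join(self.drive_dir, f"model_epoch_{epoch:03d}.h5"))
```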
In terms of speed the GPUs rank as follows: A100 > V100 > P100 > T4 > K80, with the A100 being the fastest and every subsequent GPU being roughly twice as slow as the one before it for training (e.g. the P100 is about 4x slower than the A100). While I did get the A100 a few times, the sweet spot was really the V100/P100 on Colab Pro+, since the risk of being timed out decreased. With Colab Pro+ ($50/month) I managed to train on a V100/P100 continuously for 12-24 hours at a time.
I'm not an expert here, but the validation of generative models is generally still an open question. There are metrics like Inception Score, FID, and KID that measure whether the distribution of generated images is "close" to the training distribution in some way. The main issue with all of these metrics, however, is that a model that simply memorizes the training data will get a perfect score - so they don't account for overfitting. They are also fairly hard to interpret, need large sample sizes, and are computationally intensive. For all these reasons I have chosen not to use them for now.
Instead I have focused on analyzing the visual quality of generated images by, uhm.. looking at them. This can quickly devolve into a tea-leaf reading exercise, however. To combat this, one can come up with different strategies to test for quality and diversity - for example, sampling from both the generated and ground-truth images and looking at them together.
To test for generalization I have mostly focused on interpolations in both the CLIP embedding space and the random normal latent space. Ideally, as you move from embedding to embedding, you want the generated images along the path to be meaningful in some way.
CLIP interpolation: "A lake in the forest in the summer" -> "A lake in the forest in the winter"
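A minimal sketch of such a CLIP-space interpolation (the helper below and the idea of keeping the initial noise fixed are illustrative assumptions; each interpolated embedding would be fed through the sampling loop above):

```python
import numpy as np

def interpolate_embeddings(emb_a, emb_b, n_steps=8):
    """Linearly interpolate between two CLIP text embeddings."""
    ts = np.linspace(0.0, 1.0, n_steps)
    return np.stack([(1 - t) * emb_a + t * emb_b for t in ts])

# Generate one image per interpolated embedding, keeping the initial
# random-normal latent fixed so that only the conditioning changes.
```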
Does the model memorize the training data? This is an important question with lots of implications. First of all, the models above don't have the capacity to memorize all of the training data: the model is about 150 MB but is trained on about 8 GB of data. Second, it might not be in the model's best interest to memorize things. After digging around a bit in the predictions on the training data, I did find one example where the model shamelessly copies a training example. Note this is because the image appears many times in the training data.
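One simple way to look for this kind of copying is to compare each generated image against its nearest neighbors in the training set, e.g. by pixel distance on the 64x64 images (an illustrative sketch, not how the example above was found):

```python
import numpy as np

def nearest_training_neighbors(generated, train_data, k=5):
    """Return indices of the k training images closest (in L2 pixel distance)
    to each generated image. Both arrays have shape (N, 64, 64, 3).
    For large training sets, compute the distances in batches."""
    gen_flat = generated.reshape(len(generated), -1)
    train_flat = train_data.reshape(len(train_data), -1)
    dists = np.linalg.norm(gen_flat[:, None, :] - train_flat[None, :, :], axis=-1)
    return np.argsort(dists, axis=1)[:, :k]
```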
Prompt: An Italian Village Painted by Picasso
City at night
Photograph of young woman in a field of flowers, bokeh
Street on an island in Greece
A Mountain Lake in the spring at sunset
A man in a suit in the field in wintertime
CLIP interpolation: "A minimalist living room" -> "A Field in springtime, painting"
CLIP interpolation: "A lake in the forest in the summer" -> "A lake in the forest in the winter"