kohjingyu / gill

🐟 Code and models for the NeurIPS 2023 paper "Generating Images with Multimodal Language Models".
https://jykoh.com/gill
Apache License 2.0
433 stars 38 forks source link
computer-vision large-language-models machine-learning natural-language-processing

Generating Images with Multimodal Language Models

GILL chat animation

This repository hosts the code and model weights for the GILL model.

Paper | Project Webpage

HF paper page Open in Spaces

Model and Usage

GILL model architecture

GILL (Generating Images with Large Language Models) is capable of processing arbitrarily interleaved image-and-text inputs to generate text, retrieve images, and generate novel images.

Setup instructions

Environment

Set up a new virtualenv, and install required libraries:

python -m venv venv
source venv/bin/activate
pip install -r requirements.txt

Add the gill library to PYTHONPATH:

export PYTHONPATH=$PYTHONPATH:/home/path/to/gill/

Pretrained Checkpoints

The GILL model weights (linear layers and [IMG] embeddings) are small (around 96MB), and are included in this git repo. They will be in the checkpoints/gill_opt/ folder after cloning. The checkpoint and model config in checkpoints/gill_opt/ reproduce the main results reported in our paper.

Precomputed Embeddings For Image Retrieval

For image retrieval, we provide the precomputed visual embeddings for Conceptual Captions images with valid URLs. They are stored at this URL. These are used to enable the model to retrieve images. The embeddings take up around 3GB, and are compatible with both model configs we provide. Download the files and place cc3m_embeddings_urls.npy into the checkpoints/gill_opt/ directory.

Note that you can still run the model without these, but it will not produce retrieved images. It will always generate novel images!

If you wish to precompute these embeddings for a different set of image URLs or for a different model, edit scripts/extract_img_embs.py with the list of image URLs and run it:

python scripts/extract_img_embs.py

Inference

Check out GILL_example_notebook.ipynb for examples on calling the model for inference. Several of the figures presented in the paper are reproduced in this notebook using greedy decoding of the model. Note that there may be minor differences in image outputs due to CC3M images being lost over time.

The notebook also shows how to use the model for generating images and generating text.

Training

Preparing CC3M

Our model is trained on the Conceptual Captions dataset. After following the instructions on the website to download the captions and images, format it into a .tsv file as follows:

caption image
A picture of a cat  cat.png
Mountains  mountain.png

where each line contains the caption followed by the filename of the image files. Save these .tsv files into the dataset/ folder (the default names expected are cc3m_train.tsv and cc3m_val.tsv). The repo contains two placeholder files with a few examples, and you will have to replace them with the appropriate data.

The corresponding image files should be saved in the data/ directory. The directory can be changed with the --image-dir runtime flag.

If you need help downloading CC3M for GILL, this repo contains helpful step-by-step tips.

Precomputing Text Embeddings

In addition to downloading the images, GILL also requires the embeddings from the text encoder of Stable Diffusion to train. We precompute this ahead of time in order to improve training time throughput. To do so, run the following script:

python scripts/preprocess_sd_embeddings.py  datasets/cc3m_val.tsv data/cc3m/validation/clip_embs

This will precompute embeddings from the captions in cc3m_val.tsv, and save the results to data/cc3m/validation/clip_embs.

Starting a Training Job

After preprocessing the data, we can finally start a training job with the following command line flag:

randport=$(shuf -i8000-9999 -n1)  # Generate a random port number
python -u main.py \
    --dist-url "tcp://127.0.0.1:${randport}" --dist-backend 'nccl' \
    --multiprocessing-distributed --world-size 1 --rank 0 \
    --dataset=cc3m  --val-dataset=cc3m \
    --exp-name='gill_exp' --image-dir='data/'  --log-base-dir='runs/' \
    --precision='bf16'  --print-freq=100

The default hyperparameters in main.py should reproduce our main results in the paper. We train on 2 A6000 GPUs for 48 hours. For GPUs with smaller memory available, you might need to reduce the batch size, enable gradient accumulation, or adjust hyperparameters to get good performance. You may also have to disable NCCL P2P with export NCCL_P2P_DISABLE=1 if you run into issues.

You can also run a small job on CPU, for testing purposes:

python -u main.py \
    --dataset=cc3m  --val-dataset=cc3m \
    --opt-version='facebook/opt-125m' --visual-model='openai/clip-vit-base-patch16' \
    --exp-name='gill_exp'   --log-base-dir='runs/' \
    --batch-size=2  --val-batch-size=2  --precision='fp32'  --print-freq=1 \
    --epochs=2  --val_steps_per_epoch=2   --steps_per_epoch=2

Pruning the Checkpoint

As GILL only consists of a few pretrained linear layers and the [IMG] embeddings, we can discard most of the pretrained weights to save on disk space. If you have trained a new model, and wish to do so, you can use gill/prune_model_ckpt.py file to prune the model weights, and format the ckpt as required by gill/models.py:

python scripts/prune_model_ckpt.py  runs/gill_exp

We used the same script to create the weights in the checkpoints/ directory.

Training a Decision Classifier

As described in the paper (Appendix F), we annotate PartiPrompts with per-example labels to retrieve or generate. The annotations are provided in data/PartiPromptsAllDecisions.tsv. The format follows PartiPrompts, with an additional Decisions column that we introduce:

Prompt  Category    Challenge   Note    Decisions
bond    Abstract    Basic   Biology-inspired concepts with multiple meanings    ret,gen,gen,same,gen
element Abstract    Basic   Biology-inspired concepts with multiple meanings    ret,ret,ret,ret,same

this column indicates the annotations of 5 independent human evaluators. The decisions indicate whether the annotators prefer the retrieved image (ret), Stable Diffusion generated image (gen), or if both are around the same (same). The annotations released are for the query assessing which image is more relevant to the provided prompt. The annotations for the query on realism is also available at data/PartiPromptsAllDecisions_Realism.tsv, although we recommend using the text alignment annotations for training a decision classifier (as retrieved images are likely to be significantly more realistic than generated ones in general).

To train a decision classifier, first, preprocess the PartiPrompts annotations to keep only those with high interannotator agreement:

python scripts/process_p2_annotations.py

To train a decision model on these annotations, please follow the steps in TrainDecisionClassifier.ipynb. F1 scores of the model and human baselines are reported in the notebook. If you trained a GILL model from scratch, you would need to train this classifier as well, as the one provided at checkpoints/gill_opt/decision_model.pth.tar is only compatible with our original model weights.

Evaluation

We provide code to reproduce the VIST (Table 1) and VisDial (Table 2) results presented in our paper.

VIST Evaluation

To run the VIST evaluation, first download the annotations from the val set of the official VIST dataset. We will need to download and process the image files for running the evaluations presented in the paper. This can be done by running python evals/download_vist_images.py. By default, images are saved to the sis/val_images/ directory. Downloading the images should take about 1 hour on a decent connection (as images are downloaded directly from the Flickr URLs).

After the image files are downloaded, we can run the VIST generation experiment described in Section 4.1 our paper. First, we will run GILL to generate the last image in the sequence, conditioned on image + text inputs:

python evals/generate_vist_images.py  gill_vist_outputs

The generated images for each VIST example will be saved in gill_vist_outputs/. Then, to benchmark the models, we can compute the CLIP similarity scores:

python evals/compute_clip_similarity_vist.py

For the LPIPS metric, please refer to their official GitHub repo for installation instructions. Then, we can compute the results as follows:

python evals/lpips_2dirs.py -d0  sis/val_images/  -d1  gill_vist_outputs  -o results.txt --use_gpu

For LPIPS, you may have to resize the images to 256x256 to match the AlexNet model used. We have also uploaded our LPIPS eval script (gill/evals/lpips_2dirs.py) for reference.

VisDial Evaluation

Similarly, for VisDial, download the VisDial validation annotations, the dense answer annotations, and the images. Extract everything to the VisualDialog folder.

We can run the VisDial generation experiment described in Section 4.1 our paper. We run GILL to generate an image conditioned on the full text dialogue input:

python evals/generate_visdial_images.py  gill_visdial_outputs

The generated images for each VisDial example will be saved in gill_visdial_outputs/. Then, to benchmark the models, we can compute the CLIP similarity scores:

python evals/compute_clip_similarity_visdial.py

For LPIPS, please follow the VIST instructions above to compute scores using the official LPIPS GitHub repo.

Gradio Demo

You can launch your own version of the Gradio demo locally by running python demo/app_gradio.py, or duplicating the HuggingFace space.

TODOs

Citation

If you find this work or our code useful, please consider citing:

@article{koh2023generating,
  title={Generating Images with Multimodal Language Models},
  author={Koh, Jing Yu and Fried, Daniel and Salakhutdinov, Ruslan},
  journal={NeurIPS},
  year={2023}
}