This repo contains code for Unified-IO 2, including code to run a demo, do training, and do inference. This codebase is modified from T5X.
[2/15/2024] We release the PyTorch code for Unified-IO 2. Details can be found here
[1/5/2024] We release the source code of ViT-VQGAN in JAX, which we used to train our audio tokenizer. Details can be found here
Install the dependencies with pip.
Note that orbax.checkpoint may cause conflicts with JAX for dtype="bfloat16" when using Python 3.9, but it still works with Python 3.8 (e.g., 3.8.10, which is the default on TPU VMs). This issue is possibly due to internal changes in orbax.checkpoint and pip.
For a TPU:
python3 -m pip install -e '.[tpu]' -f https://storage.googleapis.com/jax-releases/libtpu_releases.html -f https://storage.googleapis.com/jax-releases/jax_releases.html
For a GPU/CPU (note we have been using TPUs so GPU setups are not well tested):
python3 -m pip install -e '.' -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
Running the demo requires additional dependencies; install them with:
python3 -m pip install -e '.[demo]' -f https://storage.googleapis.com/jax-releases/libtpu_releases.html -f https://storage.googleapis.com/jax-releases/jax_releases.html
The LLaMA tokenizer also needs to be installed: download the .model file from https://github.com/facebookresearch/llama/tree/main?tab=readme-ov-file and then update t5x/examples/unified_io/config.py so that LLAMA_TOKENIZER_PATH points to the download location.
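Before launching the demo or training, it can help to sanity-check the configured tokenizer path. Below is a minimal sketch, assuming only that LLAMA_TOKENIZER_PATH should name an existing SentencePiece .model file; check_tokenizer_path is a hypothetical helper, not part of this repo.

```python
from pathlib import Path


def check_tokenizer_path(path: str) -> Path:
    """Hypothetical helper: fail early if the LLaMA tokenizer file is missing."""
    p = Path(path).expanduser()
    # The LLaMA tokenizer ships as a SentencePiece model with a .model suffix.
    if p.suffix != ".model":
        raise ValueError(f"Expected a SentencePiece .model file, got: {p}")
    if not p.is_file():
        raise FileNotFoundError(
            f"Tokenizer not found at {p}; update LLAMA_TOKENIZER_PATH in config.py"
        )
    return p
```

For example, check_tokenizer_path("~/llama/tokenizer.model") returns the resolved path, or raises with a message pointing back at config.py.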
We make checkpoints in the T5X format available on S3:
To download, copy the directory recursively. For example:
aws s3 --no-sign-request cp --recursive s3://ai2-prior-uio/public/uio2-checkpoints/large-3m large-3m --exclude "state*"
They should be copied to a local disk or to Google Cloud Storage. Here, the --exclude "state*" flag excludes the optimizer state from the download; remove it if you want to continue training the checkpoint from its current optimizer state.
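The AWS CLI applies --exclude patterns as wildcards against each object's path relative to the source prefix, where * also matches across "/". A minimal sketch of that filtering with Python's fnmatch, using hypothetical checkpoint file names (the real checkpoint layout may differ):

```python
from fnmatch import fnmatchcase
from typing import Iterable, List, Sequence


def files_to_download(keys: Iterable[str], exclude: Sequence[str] = ("state*",)) -> List[str]:
    """Keep keys that match none of the exclude patterns (case-sensitive, like S3 keys)."""
    return [k for k in keys if not any(fnmatchcase(k, pat) for pat in exclude)]


# Hypothetical relative keys under a checkpoint prefix, for illustration only.
keys = [
    "checkpoint",
    "target.encoder.layers_0.attention.query.kernel/0.0",
    "state.param_states.encoder.layers_0.attention.query.kernel.v/0.0",
]
print(files_to_download(keys))  # the optimizer-state entry is dropped
```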
To run the model interactively, use the demo notebook. Make sure the demo dependencies have been installed, then run:
jupyter notebook demo.ipynb
Set FULL_CKPT_PATH and MODEL_TYPE in the second cell to your checkpoint path and the matching model size; the notebook can then be used to start the demo.
The demo shows how to load the model and its parameters and how to run inference.
The demo will be slow the first time it is used because the inference function needs to be compiled; subsequent calls with similarly shaped inputs and outputs will be much faster.
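The slow first call is standard JAX behavior: a jax.jit-compiled function is traced and compiled the first time it sees a given combination of input shapes and dtypes, and that compiled version is reused afterwards. A minimal illustration, where toy_inference is a hypothetical stand-in for the model's inference function:

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def toy_inference(x):
    # Stand-in for an inference step; any jitted function behaves the same way.
    return jnp.tanh(x @ x.T).sum()


def timed(fn, x):
    start = time.perf_counter()
    out = fn(x).block_until_ready()  # wait for JAX's asynchronous dispatch to finish
    return out, time.perf_counter() - start


x = jnp.ones((256, 256))
_, first = timed(toy_inference, x)                          # traces and compiles
_, second = timed(toy_inference, x + 1.0)                   # same shape/dtype: cached
_, new_shape = timed(toy_inference, jnp.ones((128, 256)))   # new shape: compiles again
print(f"first={first:.3f}s second={second:.3f}s new_shape={new_shape:.3f}s")
```

This is why keeping input and output shapes fixed (for example by padding) avoids repeated compilation.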
To train and evaluate on entire datasets, the datasets need to be registered with seqio in seqio.TaskRegistry. See t5x/examples/unified_io/data/tasks.py for examples, and see the seqio documentation for more details on how seqio manages datasets.
Some datasets require running a pre-processing script before they can be used. Make sure config.MULTITASK_TFDS_DATA_DIR is updated to point to the location where the datasets will be stored.
We provide some initial datasets in t5x/examples/unified_io/data/tasks.py.
Our datasets are generally built one of three ways:
tensorflow_dataset
and then uploaded to the location specified in config.MULTITASK_TFDS_DATA_DIR
Datasets built in the first or second way require running a build script before they can be used. The create_data directory contains the needed build scripts. For example, running:
python3 create_data/tfdatasets/coco_all/build.py ~/data/tfds ~/data/vqa ~/data/coco_annotations
This will upload a TensorFlow dataset of COCO data, which allows tasks such as image_generation_coco_2017 and image_caption_coco_2017 to be used. Some datasets that use the public TensorFlow catalog, such as the refexp datasets, might also have their own manual pre-processing steps, which are specified on their webpages.
UnifiedIO 2 contains a large number of tasks. For this initial release we only include a subset, but we will add more as we test and verify additional tasks.
Pre-processing in UIO2 happens in three stages:
1. … unified_io.data.data_utils.resize_and_pad). Audio segments need to be converted to mel-spectrograms, which can also be masked if working with noised data. This stage is implemented by various preprocessing functions in unified_io.data.preprocessing. The demo shows how to do this for raw inputs. To allow this stage to do different pre-processing during training and testing, we pass an is_training field in the sequence_length dictionary to indicate whether the dataset is being used for training or testing.
2. modality_processing.unified_io_preprocessor is run. This function does various task-general pre-processing steps, such as tokenizing the text, and adds empty values for missing modalities so the output dataset has a consistent set of fields.
3. UnifiedIOFeatureConverter is applied; this can happen after multiple datasets have been combined into a seqio.Mixture. This function makes sure the output dataset has a consistent structure and is padded to fixed-size tensors, as JAX requires. The resulting dataset can be batched and passed directly into the loss or prediction functions of a UnifiedIO 2 model. The padding is determined by the sequence_len dictionary.
To add a dataset, register it with seqio and ensure the last pre-processor is modality_processing.unified_io_preprocessor. The preceding functions should make sure the dataset has the appropriate fields for that function.
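The idea behind the last two stages can be pictured with a toy, pure-Python model: first fill in missing modalities so every example has the same fields, then pad each field to a fixed length taken from a sequence-length dictionary. This is a hypothetical sketch for intuition only; MODALITIES, the field names, and both helpers are assumptions, not the repo's unified_io_preprocessor or UnifiedIOFeatureConverter.

```python
from typing import Dict, List

# Hypothetical field names, for illustration only.
MODALITIES = ("text_inputs", "image_inputs", "audio_inputs", "text_targets")


def fill_missing_modalities(example: Dict[str, List[int]]) -> Dict[str, List[int]]:
    """Stage-2 idea: keep the known fields and add empty values for absent ones."""
    return {name: list(example.get(name, [])) for name in MODALITIES}


def pad_to_fixed_length(example: Dict[str, List[int]], sequence_len: Dict[str, int],
                        pad_id: int = 0) -> Dict[str, List[int]]:
    """Stage-3 idea: give every field a static length, plus a mask of real positions.

    In this sketch, fields longer than their budget are truncated.
    """
    out: Dict[str, List[int]] = {}
    for name, values in example.items():
        n = sequence_len[name]
        clipped = values[:n]
        out[name] = clipped + [pad_id] * (n - len(clipped))
        out[name + "_mask"] = [1] * len(clipped) + [0] * (n - len(clipped))
    return out


sequence_len = {"text_inputs": 6, "image_inputs": 4, "audio_inputs": 2, "text_targets": 3}
raw = {"text_inputs": [5, 9, 2], "text_targets": [7, 1]}  # made-up token ids
print(pad_to_fixed_length(fill_missing_modalities(raw), sequence_len))
```

Every example now has identical keys and shapes, which is what allows batching for a JAX model.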
Our full set of prompts is in t5x/examples/unified_io/data/prompt_dict; we randomly select among these prompts during training.
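Random prompt selection can be sketched as follows, assuming a hypothetical dictionary that maps a task name to a list of templates with a single {} placeholder; the real prompt_dict format and wording may differ.

```python
import random
from typing import Dict, List

# Hypothetical templates for illustration; the real ones live in
# t5x/examples/unified_io/data/prompt_dict.
PROMPTS: Dict[str, List[str]] = {
    "refexp": [
        'Which region does the text "{}" describe?',
        "Find the region described by: {}",
    ],
}


def build_prompt(task: str, text: str, rng: random.Random) -> str:
    template = rng.choice(PROMPTS[task])  # uniform choice among the task's templates
    return template.format(text)


rng = random.Random(0)  # a seeded RNG keeps the choice reproducible
print(build_prompt("refexp", "the man on the left", rng))
```

Sampling a different template per example exposes the model to varied phrasings of the same task.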
We include a visualization script to show what the data looks like after post-processing:
python3 t5x/examples/unified_io/scripts/dataset_visualize.py refcoco_unc viz --override
To get a more compact view:
python3 t5x/examples/unified_io/scripts/dataset_visualize.py refcoco_unc viz --override --gin.get_target_modalities.target_modality=[\"text\"] --gin.get_input_modalities.input_modality=[\"text\",\"image\"] --nomasks
Once a checkpoint is downloaded and a dataset is ready, training can be run using train.py.
Our training strategy largely follows T5X, which is configured through gin. Follow the setup from https://github.com/google-research/t5x to train on TPUs.
For example, to fine-tune the large model on refexp:
python3 t5x/train.py --gin_file=t5x/examples/unified_io/t5_1_1/large.gin --gin_file=t5x/examples/unified_io/t5_1_1/finetune/refexp.gin --gin.INITIAL_CHECKPOINT_PATH=\"/path/to/checkpoint\" --gin.MODEL_DIR=\"path/to/output_dir\" --gin.BATCH_SIZE=8
UnifiedIO 2 can be run on a subset of the supported modalities, which makes training more efficient. This can be set through the gin-configured parameters of get_input_modalities and get_target_modalities. For example, refexp.gin only turns on the image/text inputs and text outputs.
Due to JAX's fixed-size tensor constraint, by default we pad all model inputs and targets to the maximum supported length. When training on mixtures where this is excessive, this can be tweaked by changing the sequence_lengths used by seqio. For example, refexp.gin reduces the input and output sequence lengths since refexp has little text.
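A back-of-the-envelope sketch shows why shrinking the sequence lengths matters: it measures the fraction of positions that are padding when every example is padded to a fixed length. The token counts and lengths below are made up for illustration and are not the model's actual defaults.

```python
from typing import Sequence


def padding_fraction(lengths: Sequence[int], max_len: int) -> float:
    """Fraction of positions that are padding when every example is padded to max_len."""
    if any(n > max_len for n in lengths):
        raise ValueError("an example is longer than max_len and would be truncated")
    total = max_len * len(lengths)
    return (total - sum(lengths)) / total


short_text = [12, 20, 9, 15]  # hypothetical token counts for short refexp-style text
print(padding_fraction(short_text, 512))  # most compute goes to padding
print(padding_fraction(short_text, 64))   # a tighter budget wastes far less
```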
We have modified train.py to log to wandb; just make sure the WANDB_API_KEY environment variable is set. The gin-configurable function utils.init_wandb should be modified, or configured through gin, to select the correct name/group/project/entity.
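Because a missing key is easy to overlook until a long job has started, a small preflight check can fail fast. This is a hypothetical helper, not part of train.py; it only inspects the WANDB_API_KEY variable named above.

```python
import os
from typing import Mapping


def require_wandb_key(env: Mapping[str, str] = os.environ) -> str:
    """Hypothetical preflight check: raise before training if WANDB_API_KEY is unset or blank."""
    key = env.get("WANDB_API_KEY", "").strip()
    if not key:
        raise RuntimeError("Set the WANDB_API_KEY environment variable before running t5x/train.py")
    return key
```

Passing the environment as a parameter keeps the check easy to test without touching the real process environment.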
If the training mixture contains a mix of long and short examples, packing can make training more efficient. Packing places up to two examples together into a single input sequence; it can be turned on with this flag:
--gin.PackingStrategy.pack_max_len=(864, 1280)
During training, the trainer will attempt to pack two examples into one sequence with a total input length of at most 864 and a total target length of at most 1280. A heuristic algorithm tries to find pairs of examples that fit this criterion as data is streamed to the training server; if none are found, only one example is used. If this happens too frequently, it is a good idea to increase the max lengths. Statistics are logged to wandb to track packing efficiency.
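The packing criterion can be illustrated with a simplified greedy sketch: keep a small buffer of streamed examples and, for each new example, look for a buffered partner whose combined input and target lengths fit within the budget; an example that finds no partner in time is emitted alone. This is a hypothetical stand-in for intuition, not the repo's heuristic; pack_pairs and buffer_size are assumptions, and pack_max_len merely mirrors the gin flag's (input, target) pair.

```python
from collections import deque
from typing import Deque, Iterable, Iterator, List, Tuple

Example = Tuple[int, int]  # (input_length, target_length) of one example


def pack_pairs(stream: Iterable[Example],
               pack_max_len: Tuple[int, int] = (864, 1280),
               buffer_size: int = 4) -> Iterator[List[Example]]:
    """Greedily pair streamed examples whose summed lengths fit pack_max_len."""
    max_in, max_tgt = pack_max_len
    buffer: Deque[Example] = deque()
    for ex in stream:
        # Look for a buffered example that fits alongside `ex` on both budgets.
        partner = next((b for b in buffer
                        if ex[0] + b[0] <= max_in and ex[1] + b[1] <= max_tgt), None)
        if partner is not None:
            buffer.remove(partner)
            yield [partner, ex]            # two examples share one sequence
        else:
            buffer.append(ex)
            if len(buffer) > buffer_size:
                yield [buffer.popleft()]   # no partner found in time: emit alone
    while buffer:
        yield [buffer.popleft()]           # flush leftovers unpacked


stream = [(600, 900), (200, 300), (700, 100), (500, 1000), (100, 200)]  # made-up lengths
sequences = list(pack_pairs(stream))
packed = sum(len(s) == 2 for s in sequences)
print(f"{packed} of {len(sequences)} sequences hold two examples")
```

The share of two-example sequences is the kind of packing-efficiency number worth watching; if it stays low, raising the length budget gives the pairing step more room.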
Evaluation is run using eval.py, for example:
python3 t5x/eval.py --gin_file=t5x/examples/unified_io/t5_1_1/large.gin --gin_file=t5x/examples/unified_io/t5_1_1/eval/vision_language.gin --gin.CHECKPOINT_PATH=\"large-3m\" --gin.MIXTURE_OR_TASK_NAME=\"refcoco_unc\" --gin.EVAL_OUTPUT_DIR=\"output\"
The target dataset must have metrics registered with seqio. The evaluation script can similarly be made more efficient by only using the needed modalities and choosing the sequence lengths appropriately. Note that most of our official results come from collecting outputs and then running offline evaluations; the metrics here are used mostly for validation scores.
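As an example of an offline evaluation, referring-expression benchmarks such as RefCOCO are commonly scored as accuracy at IoU >= 0.5 between predicted and ground-truth boxes. A minimal sketch, assuming boxes in (x1, y1, x2, y2) coordinates that have already been parsed from the collected outputs; this is not the repo's evaluation code.

```python
from typing import Sequence, Tuple

Box = Tuple[float, float, float, float]  # (x1, y1, x2, y2)


def area(r: Box) -> float:
    return max(0.0, r[2] - r[0]) * max(0.0, r[3] - r[1])


def iou(a: Box, b: Box) -> float:
    """Intersection over union of two axis-aligned boxes."""
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
    union = area(a) + area(b) - inter
    return inter / union if union > 0 else 0.0


def refexp_accuracy(preds: Sequence[Box], gts: Sequence[Box], thresh: float = 0.5) -> float:
    """Fraction of predictions whose IoU with the ground truth reaches `thresh`."""
    if len(preds) != len(gts) or not gts:
        raise ValueError("need equally many predictions and ground-truth boxes (at least one)")
    hits = sum(iou(p, g) >= thresh for p, g in zip(preds, gts))
    return hits / len(gts)


preds = [(0, 0, 10, 10), (0, 0, 10, 10)]
gts = [(0, 0, 10, 10), (5, 5, 15, 15)]  # second pair overlaps only 25 of 175 units
print(refexp_accuracy(preds, gts))
```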
@article{lu2023uio2,
title = {Unified-IO 2: Scaling Autoregressive Multimodal Models with Vision, Language, Audio, and Action},
author = {Jiasen Lu and Christopher Clark and Sangho Lee and Zichen Zhang and Savya Khosla and Ryan Marten and Derek Hoiem and Aniruddha Kembhavi},
journal = {arXiv preprint arXiv:2312.17172},
year = {2023},
}