apple / ml-4m

4M: Massively Multimodal Masked Modeling
https://4m.epfl.ch
Apache License 2.0

Examples of non-generative usage (and some additional discussion) #2

Closed · slowwavesleep closed this 2 months ago

slowwavesleep commented 2 months ago

Hi,

I would like to kindly ask you to provide some pointers/examples on using 4M for such tasks as retrieval and classification. How do you load only the encoder, for instance?

roman-bachmann commented 2 months ago

Hi @slowwavesleep ,

Thanks for your interest! I pushed some changes to make it easier to load pre-trained 4M checkpoints as a ViT backbone. Here's an example of loading 4M-21 B weights as a ViT backbone, and adding a simplified 1000-way classification head on top of it:

import torch
import torch.nn as nn
from einops.layers.torch import Reduce
from fourm.utils import load_safetensors
from fourm.models.fm_vit import FMViT

device = 'cuda' if torch.cuda.is_available() else 'cpu'

ckpt, config = load_safetensors('./4M-21_B.safetensors')

cls_head = nn.Sequential(
    Reduce('b n d -> b d', 'mean'),
    nn.LayerNorm(config['dim'], eps=1e-6),
    nn.Linear(config['dim'], 1000),
)

fmvit = FMViT(config, output_head=cls_head).to(device)
msg = fmvit.load_state_dict(ckpt, strict=False)
print(msg) # Check that the only missing keys are from the optional output_head

# Example forward pass. Input images should be ImageNet-standardized.
logits = fmvit(torch.randn(2,3,224,224).to(device)) # Returns (B, 1000)
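If you want to fine-tune this classifier, one simple option is a linear probe: freeze the backbone and train only the head. The following is a hedged sketch, not part of the original example; it only assumes that fmvit behaves like a standard nn.Module and that cls_head is the same module object passed in as output_head above, and it uses dummy data in place of a real dataloader.

import torch.nn.functional as F

# Sketch of a single linear-probe training step (dummy data, hypothetical setup).
for p in fmvit.parameters():
    p.requires_grad = False          # freeze everything, including the backbone
for p in cls_head.parameters():
    p.requires_grad = True           # re-enable gradients for the head only

optimizer = torch.optim.AdamW(cls_head.parameters(), lr=1e-3)

images = torch.randn(8, 3, 224, 224, device=device)   # stand-in for ImageNet-standardized images
labels = torch.randint(0, 1000, (8,), device=device)  # stand-in labels

optimizer.zero_grad()
loss = F.cross_entropy(fmvit(images), labels)
loss.backward()
optimizer.step()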

For retrieval, we are actually performing generation in the same manner as for all the other modalities. Given an arbitrary input conditioning, we can predict either the tokenized global DINOv2 embeddings, or the tokenized global ImageBind embeddings. These need to be decoded using their respective tokenizer decoders. Retrieval is performed in a standard manner, i.e. by retrieving the sample with the lowest cosine distance to the predicted query embedding.

Best, Roman

slowwavesleep commented 2 months ago

Thank you, @roman-bachmann!

For retrieval, we are actually performing generation in the same manner as for all the other modalities. Given an arbitrary input conditioning, we can predict either the tokenized global DINOv2 embeddings, or the tokenized global ImageBind embeddings.

Oh, I see. That makes sense now. I initially thought you were just taking the encoder's outputs, but it's actually a new paradigm (at least to me). So if you treat the embeddings as modalities, then it should be possible to convert any model's output to any other model's outputs with 4M and some additional training? Like RoBERTa to GPT-2, for instance (disregard the lack of practicality in this example).

Could you guesstimate the difference in inference time, when using an image to produce DINOv2 embeddings with DINOv2 directly VS. using let's say text input to generate DINOv2 embeddings with 4M? Is it like an order of magnitude or less?

roman-bachmann commented 2 months ago

Oh, I see. That makes sense now. I initially thought you were just taking the encoder's outputs, but it's actually a new paradigm (at least to me). So if you treat the embeddings as modalities, then it should be possible to convert any model's output to any other model's outputs with 4M and some additional training? Like RoBERTa to GPT-2, for instance (disregard the lack of practicality in this example).

Exactly! Any output from another model or any sensor can be turned into a modality if you can find a way to turn it into discrete tokens. We view modalities as any signal that describes some factor(s) of the underlying reality of the scene. This is intentionally quite vague, but it means one can include real sensory measurements (e.g. RGB), semantics and subjective expression (e.g. captions), pseudo labels from specialist networks (e.g. depth maps, bounding boxes, human poses, etc), neural network features (e.g. dense feature maps or global embeddings), etc... If you have an aligned dataset of those modalities, after training with the multimodal masked modeling objective, any subset of modalities can be mapped to any other, in a fully generative manner.

Could you guesstimate the difference in inference time, when using an image to produce DINOv2 embeddings with DINOv2 directly VS. using let's say text input to generate DINOv2 embeddings with 4M? Is it like an order of magnitude or less?

It's around the same order of magnitude if you do a single forward pass (assuming the networks are of similar sizes). We've observed that one can perform pretty decent text-to-image retrieval when using 4M to generate DINOv2 embeddings, even when predicting the 16 global tokens in a single forward pass (as opposed to generating them in multiple steps using the MaskGIT/ROAR decoding scheme). Unlike when predicting DINOv2 embeddings from images (which is a somewhat deterministic setting), text-to-DINOv2-embedding prediction is underspecified since a given caption can map to an entire distribution of DINOv2 embeddings which match that caption. It's interesting that it works quite well with a single generation step -- something to look into further in future work.

slowwavesleep commented 2 months ago

I also wanted to confirm my understanding of your example with loading 4M as ViT backbone. The loaded model is still the cross-modal one, so in principle (but not in this specific example, I guess) it can be used to classify arbitrary combinations of modalities as inputs. Does that sound about right? One usage example I can think of is something like visual NLI, where you classify text/image pairs. 4M would seem like a suitable choice for that then.

Another thing I didn't fully understand from the papers is tokenization. If you tokenize all modalities into discrete tokens, then the model should have a fixed vocabulary, right? Or do the tokenizers actually output token vectors that go straight into the model? I'm not sure how that works with text then.

roman-bachmann commented 2 months ago

I also wanted to confirm my understanding of your example with loading 4M as ViT backbone. The loaded model is still the cross-modal one, so in principle (but not in this specific example, I guess) it can be used to classify arbitrary combinations of modalities as inputs. Does that sound about right? One usage example I can think of is something like visual NLI, where you classify text/image pairs. 4M would seem like a suitable choice for that then.

For simplicity we made FMViT specific to RGB pixel inputs to serve as a ViT backbone, but you are right that the underlying model that is loaded is still the cross-modal one. It should be relatively simple to extend FMViT to multimodal inputs by adding back the other modality embeddings (as in the FM class).

Another thing I didn't fully understand from the papers is tokenization. If you tokenize all modalities into discrete tokens, then the model should have a fixed vocabulary, right? Or do the tokenizers actually output token vectors that go straight into the model? I'm not sure how that works with text then.

Each modality has a fixed vocabulary size, yes. The 4M model has learned embedding layers for each modality which transform the discrete codes into vectors. The only exceptions to that are RGB pixels and T5 embeddings, which are input-only and are fed into 4M using learned linear projections (patch-wise for RGB).
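As a purely schematic illustration of that last point (this is not the actual FM/FMViT code; the modality names, vocabulary sizes, and patch size below are made-up assumptions), per-modality embedding tables for discrete tokens plus a patch-wise linear projection for raw RGB could look roughly like this:

import torch
import torch.nn as nn

dim = 768  # illustrative model width, not the real 4M configuration

# One learned embedding table per tokenized modality, each with its own fixed vocabulary.
vocab_sizes = {'tok_rgb': 16384, 'tok_depth': 8192, 'caption': 30000}  # made-up sizes
modality_embeddings = nn.ModuleDict({
    name: nn.Embedding(size, dim) for name, size in vocab_sizes.items()
})

# Input-only RGB pixels bypass the tokenizers: a patch-wise linear projection
# (here 16x16 patches with 3 channels) maps raw pixels to the same width.
patch_proj = nn.Linear(16 * 16 * 3, dim)

# Embed some discrete depth tokens and some RGB patches, then concatenate them
# into a single token sequence for a shared Transformer encoder to consume.
depth_tokens = torch.randint(0, vocab_sizes['tok_depth'], (2, 196))  # (B, N) token IDs
rgb_patches = torch.randn(2, 196, 16 * 16 * 3)                       # (B, N, patch_dim)

seq = torch.cat([
    modality_embeddings['tok_depth'](depth_tokens),  # (B, 196, dim)
    patch_proj(rgb_patches),                         # (B, 196, dim)
], dim=1)                                            # (B, 392, dim)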

slowwavesleep commented 2 months ago

Thanks again for the discussion and awesome work, @roman-bachmann!

SecretMG commented 2 months ago

I also wanted to confirm my understanding of your example with loading 4M as ViT backbone. The loaded model is still the cross-modal one, so in principle (but not in this specific example, I guess) it can be used to classify arbitrary combinations of modalities as inputs. Does that sound about right? One usage example I can think of is something like visual NLI, where you classify text/image pairs. 4M would seem like a suitable choice for that then.

For simplicity we made FMViT specific to RGB pixel inputs to serve as a ViT backbone, but you are right that the underlying model that is loaded is still the cross-modal one. It should be relatively simple to extend FMViT to multimodal inputs by adding back the other modality embeddings (as in the FM class).

Another thing I didn't fully understand from the papers is tokenization. If you tokenize all modalities into discrete tokens, then the model should have a fixed vocabulary, right? Or do the tokenizers actually output token vectors that go straight into the model? I'm not sure how that works with text then.

Each modality has a fixed vocabulary size, yes. The 4M model has learned embedding layers for each modality which transform the discrete codes into vectors. The only exceptions to that are RGB pixels and T5 embeddings, which are input-only and are fed into 4M using learned linear projections (patch-wise for RGB).

However, I noticed that in Fig. 2 of the original 4M-21 paper, any modality can be used to generate RGB pixels. Does that mean that RGB pixels are not input-only?

roman-bachmann commented 2 months ago

@SecretMG, in this example we are actually predicting tokenized RGB. We add pixel RGB as an input-only modality to enable 4M to be used as a ViT backbone, and tokenized RGB as inputs and outputs to enable generation. Tokenization is a lossy process, and when performing perceptual tasks we don't want to be limited to using the frozen image features from a tokenizer.

amaarora commented 2 months ago

Hello @roman-bachmann

Firstly, thank you so much for the amazing work! I would also like to expand on the following point about using 4M-21 for retrieval purposes.

For retrieval, we are actually performing generation in the same manner as for all the other modalities. Given an arbitrary input conditioning, we can predict either the tokenized global DINOv2 embeddings, or the tokenized global ImageBind embeddings. These need to be decoded using their respective tokenizer decoders. Retrieval is performed in a standard manner, i.e. by retrieving the sample with the lowest cosine distance to the predicted query embedding.

To achieve the functionality shown in Fig. 5 of the paper (also shared below):

[Figure 5 of the 4M-21 paper]

Is this what we do:

  1. Use input modalities such as a query image + text (or any subset of the 21 input modalities) to generate an image.
  2. Take the generated image from step 1 and use the DINOv2 global embedding tokenizer to convert it into tokens.
  3. Use the DINOv2 global embedding tokenizer to also convert a set of images into tokens (this is the database we want to retrieve from).
  4. Finally, compute cosine similarity between the outputs of steps 2 and 3 to get the top-N retrieved results.

Does the above summary seem correct?

ofkar commented 2 months ago

Hi @amaarora ,

You don't have to predict an image as an intermediate stage. You can directly predict DINOv2 (or ImageBind) global embeddings from any subset of modalities. Then, given the global embeddings of the query and retrieval set, you can perform retrieval by comparing their cosine similarities.

So, a summary would be as follows:

  1. Load the DINOv2 (or ImageBind) global embedding tokenizer.

from fourm.vq.vqvae import VQVAE

tok_dinov2_global = VQVAE.from_pretrained('EPFL-VILAB/4M_tokenizers_DINOv2-B14-global_8k_16_224')

  2. Use any subset of input modalities to predict the tokenized DINOv2 (or ImageBind) global embedding. This will be your query embedding.

cond_domains = [COND_DOMAIN1, COND_DOMAIN2]   # condition on any subset of input modalities
target_domains = ['tok_dinov2_global']        # predict the tokenized DINOv2 global embedding
tokens_per_target = [16]                      # 16 global tokens
autoregression_schemes = ['autoregressive']

  3. Obtain the DINOv2 global embeddings for the retrieval set in the same way. Alternatively, if the retrieval set consists of RGB images, first compute their global embeddings with the pretrained DINOv2 (or ImageBind) model, then tokenize those embeddings using the tokenizer from step 1.

  4. Perform a cosine similarity comparison between the query and the retrieval set (see the sketch below).
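For step 4, here is a minimal sketch of the comparison itself. The tensors are hypothetical placeholders: query_emb stands for the decoded global embedding predicted for the query, and db_embs for the stacked decoded embeddings of the retrieval set, both obtained as in steps 1-3.

import torch
import torch.nn.functional as F

# Hypothetical placeholders: in practice these come from steps 1-3 above.
query_emb = torch.randn(768)         # decoded global embedding predicted for the query, (D,)
db_embs = torch.randn(10_000, 768)   # decoded global embeddings of the retrieval set, (N, D)

# Cosine similarity is the dot product of L2-normalized vectors, so the lowest
# cosine distance corresponds to the highest score here.
scores = F.normalize(db_embs, dim=-1) @ F.normalize(query_emb, dim=-1)  # (N,)

# Indices of the top-N retrieved samples.
top_scores, top_indices = scores.topk(k=5)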

amaarora commented 2 months ago

Hey @ofkar, thank you so much! This is very helpful. I am writing a blog post on this and will share a draft shortly. :)