NVlabs / RADIO

Official repository for "AM-RADIO: Reduce All Domains Into One"

Replace CLIP with RADIO #68

Open StuHude opened 3 days ago

StuHude commented 3 days ago

Congrats! What fantastic work!

I am now trying to replace CLIP with RADIO in an image-text task. Can RADIO be used with the CLIP text encoder directly? If so, are there adaptor codes and weights, or do I need to train the projection layer?

gheinrich commented 2 days ago

Hello, yes, you can use the CLIP adaptor and the corresponding tokenizer and text encoder. There is an example at https://github.com/NVlabs/RADIO/blob/main/examples/zero_shot_imagenet.py.

mranzinger commented 2 days ago

In addition, here's a minimal pseudocode snippet that should work:

import torch
import torch.nn.functional as F

model = torch.hub.load('NVlabs/RADIO', 'radio_model', version='radio_v2', adaptor_names='clip')

# `images` is a B x 3 x H x W float tensor with values between 0 and 1
output = model(images)
bb_summary, bb_features = output['backbone']
clip_summary, clip_features = output['clip']  # These are the DFN CLIP embeddings

# To get the text embeddings
clip_adaptor = model.adaptors['clip']
tokens = clip_adaptor.tokenizer(['foo', 'bar'])
clip_text_embeddings = clip_adaptor.encode_text(tokens)

# B x B compatibility matrix from each image embedding to each text embedding (e.g. the CLIP objective)
alignment = F.normalize(clip_summary, dim=1) @ F.normalize(clip_text_embeddings.T, dim=0)
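
To turn that compatibility matrix into zero-shot predictions, you could continue along these lines (a minimal sketch building on the snippet above; the class names and the 100x logit scale, the standard CLIP convention, are illustrative assumptions, not part of the RADIO API):

# Zero-shot classification sketch; class names and logit scale are assumptions
class_names = ['cat', 'dog']  # hypothetical labels
prompts = [f'a photo of a {c}' for c in class_names]
text_tokens = clip_adaptor.tokenizer(prompts)
text_emb = F.normalize(clip_adaptor.encode_text(text_tokens), dim=-1)
img_emb = F.normalize(clip_summary, dim=-1)

logits = 100.0 * img_emb @ text_emb.T  # B x num_classes, standard CLIP temperature
probs = logits.softmax(dim=-1)
preds = probs.argmax(dim=-1)           # predicted class index per image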
StuHude commented 1 day ago

Thank you very much for your answer! In addition, could you release the model structure of RADIO? I would like to get the output of each layer in the model. If possible, that would be a great help to me. Thank you!

gheinrich commented 1 day ago

Hello, the model architecture is defined in https://github.com/NVlabs/RADIO/blob/main/radio/radio_model.py; however, the bulk of the instantiation is performed by the TIMM library, since we use a mostly standard VisionTransformer model.
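
To see the structure for yourself, you could inspect the loaded model directly (a quick sketch; that the inner TIMM VisionTransformer lives at the `model` attribute of the torch.hub wrapper, and the `radio_v2` version string, are assumptions based on the snippet earlier in this thread):

    import torch

    model = torch.hub.load('NVlabs/RADIO', 'radio_model', version='radio_v2')
    print(type(model.model).__name__)         # the inner TIMM VisionTransformer
    print(len(model.model.blocks), 'blocks')  # transformer depth
    print(model.model)                        # full module tree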

We are contemplating adding an API to fetch intermediate activations in the future. In the meantime, assuming you are using RADIO (not E-RADIO), this can be achieved by re-writing the _forward_cpe method in https://github.com/NVlabs/RADIO/blob/main/radio/enable_cpe_support.py.

For example, you might write it as:

    def forward_features(self, x):
        """Return the normalized output of every transformer block."""
        features = []

        if isinstance(self.model, VisionTransformer):
            # The CPE patch generator replaces the standard patch/position embedding
            x = self.model.patch_generator(x)

            for blk in self.model.blocks:
                x = blk(x)
                features.append(self.model.norm(x))
        else:
            raise ValueError("Only VisionTransformer is supported here")

        return features
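
You could then call it along the following lines (a hedged sketch, not the official API: `radio_v2`, the input size, and attaching the method to the loaded wrapper via monkey-patching are assumptions; the wrapper's forward pass normalizes inputs with its input_conditioner, so we mirror that here):

    import types
    import torch
    from timm.models.vision_transformer import VisionTransformer  # used by the method above

    model = torch.hub.load('NVlabs/RADIO', 'radio_model', version='radio_v2')
    # Attach the method above to the wrapper, whose `.model` attribute is the inner
    # TIMM VisionTransformer (alternatively, edit enable_cpe_support.py directly).
    model.forward_features = types.MethodType(forward_features, model)

    x = torch.rand(1, 3, 224, 224)  # values between 0 and 1
    with torch.no_grad():
        cond = model.input_conditioner(x)        # RADIO normalizes inputs before the backbone
        per_layer = model.forward_features(cond)

    print(len(per_layer), per_layer[0].shape)    # one B x N x C tensor per block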