NVlabs / RADIO

Official repository for "AM-RADIO: Reduce All Domains Into One"

Replace CLIP with RADIO #68

Open StuHude opened 5 months ago

StuHude commented 5 months ago

Congrats! What fantastic work!

Now I am trying to replace CLIP with RADIO in an image-text task. Can RADIO be used with the CLIP text encoder directly? If so, are there adaptor codes and weights? Or do I need to train a projection layer myself?

gheinrich commented 5 months ago

Hello, yes you can use the CLIP adaptor and the corresponding tokenizer and text encoder. There is an example on https://github.com/NVlabs/RADIO/blob/main/examples/zero_shot_imagenet.py.

mranzinger commented 5 months ago

In addition, here's some minimal pseudocode that should work:

    import torch
    import torch.nn.functional as F

    model = torch.hub.load('NVlabs/RADIO', 'radio_model', version='radio_v2', adaptor_names='clip')
    output = model(images)  # Inputs should have values between 0 and 1
    bb_summary, bb_features = output['backbone']
    clip_summary, clip_features = output['clip']  # These are the DFN CLIP embeddings

    # To get the text embeddings
    clip_adaptor = model.adaptors['clip']
    tokens = clip_adaptor.tokenizer(['foo', 'bar'])
    clip_text_embeddings = clip_adaptor.encode_text(tokens)

    # B x B compatibility matrix from each image embedding to each text embedding (e.g. the CLIP objective)
    alignment = F.normalize(clip_summary, dim=1) @ F.normalize(clip_text_embeddings.T, dim=0)
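
A common follow-up, as in the standard CLIP recipe, is to softmax the alignment scores over the text axis to get zero-shot predictions. A minimal sketch; the temperature below is illustrative, not DFN CLIP's learned logit scale:

    # Illustrative temperature; DFN CLIP's learned logit scale may differ.
    probs = (100.0 * alignment).softmax(dim=-1)
    pred = probs.argmax(dim=-1)  # best-matching text index for each image
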
StuHude commented 5 months ago

Thank you very much for your answer! In addition, I would like to ask whether you can release the model structure of RADIO; I hope to get the output of each layer in the model. If possible, that would be of great help to me. Thank you very much!

gheinrich commented 5 months ago

Hello, the model architecture is defined in https://github.com/NVlabs/RADIO/blob/main/radio/radio_model.py; however, the bulk of the instantiation is performed by the TIMM library, since we use a mostly standard VisionTransformer model.

We are contemplating adding an API to fetch intermediate activations in the future. In the meantime, assuming you are using RADIO (not E-RADIO), this can be achieved by re-writing the _forward_cpe method in https://github.com/NVlabs/RADIO/blob/main/radio/enable_cpe_support.py.

For example, you might write it as:

    # Assumes: from timm.models.vision_transformer import VisionTransformer
    def forward_features(self, x):
        """Return features from the model."""
        features = []

        if isinstance(self.model, VisionTransformer):
            # Tokenize the image with the CPE patch generator.
            x = self.model.patch_generator(x)

            # Collect each block's output, passed through the final norm.
            for blk in self.model.blocks:
                x = blk(x)
                features.append(self.model.norm(x))
        else:
            raise ValueError("Only VisionTransformer is supported here")

        return features
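
If you'd rather not edit the file itself, a monkey-patch along these lines should also work. This is a sketch assuming the forward_features definition above; VisionTransformer would need to be imported from timm.models.vision_transformer for the isinstance check:

    import types

    # Bind the method above to the loaded RADIO model, then call it.
    model.forward_features = types.MethodType(forward_features, model)
    features = model.forward_features(images)  # list of per-block activations
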
mranzinger commented 4 months ago

Btw, @gheinrich has made support for intermediate activations part of the official API: https://github.com/NVlabs/RADIO?tab=readme-ov-file#intermediate-layer-activations
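
For reference, a minimal sketch of what using that API can look like, assuming it follows timm's forward_intermediates convention; check the linked README section for the exact signature and argument names:

    import torch

    model = torch.hub.load('NVlabs/RADIO', 'radio_model', version='radio_v2')
    # `indices` picks which blocks to return; the values here are illustrative.
    intermediates = model.forward_intermediates(images, indices=[7, 15, 23, 31])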

cspearl commented 1 week ago

> Hello, the model architecture is defined in https://github.com/NVlabs/RADIO/blob/main/radio/radio_model.py; however, the bulk of the instantiation is performed by the TIMM library, since we use a mostly standard VisionTransformer model. [...]

Hey, as you mentioned, the ViT model is instantiated by the timm library, but I want to play around with the layers: not just get the intermediate layer outputs, but actually modify them to see the effect on the overall model. How should I go about doing that? Do I change the _forward_cpe method as you mentioned, or some other lines in the radio_model.py file?

mranzinger commented 1 week ago

Yeah, if you want to start replacing parts of the model architecture, you could even replace the modules within the self.model.blocks list. TIMM fortunately keeps the vision transformer definition pretty self-contained, so you could pull it from https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py and start tweaking as you see fit.
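
As a concrete sketch of that idea: the wrapper and the output tweak below are purely illustrative, and only the blocks list itself comes from the comment above:

    import torch

    class TweakedBlock(torch.nn.Module):
        """Wraps a timm block so its output can be modified in place."""
        def __init__(self, block):
            super().__init__()
            self.block = block

        def forward(self, x):
            out = self.block(x)
            return out * 0.9  # illustrative tweak; put your modification here

    model = torch.hub.load('NVlabs/RADIO', 'radio_model', version='radio_v2')
    vit = model.model  # the underlying timm VisionTransformer
    vit.blocks = torch.nn.ModuleList(TweakedBlock(blk) for blk in vit.blocks)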

cspearl commented 1 week ago

Thanks, I'll try doing that. Just out of curiosity, are the adaptors you mentioned the only ones available, or is there a way to train an adaptor for a different model on top of the RADIO features using your codebase? If so, where would I need to add or modify code to implement that?

mranzinger commented 1 week ago

At the moment, we don't have a public release of the training code, so officially you can only use the existing adaptors that come with a given model. That said, I'm guessing you could get decently close to matching some new model by adding a new MLP2, freezing the RADIO backbone, and training to match the new teacher. Since the backbone is frozen, you wouldn't have to deal with catastrophic forgetting.
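
A minimal sketch of that recipe, with everything assumed rather than taken from the codebase: the plain MLP stands in for MLP2, and teacher, teacher_dim, summary_dim, and loader are hypothetical placeholders:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    radio = torch.hub.load('NVlabs/RADIO', 'radio_model', version='radio_v2')
    radio.eval()
    for p in radio.parameters():
        p.requires_grad_(False)  # frozen backbone: no catastrophic forgetting

    head = nn.Sequential(  # stand-in for the codebase's MLP2
        nn.Linear(summary_dim, 2048), nn.GELU(), nn.Linear(2048, teacher_dim),
    )
    opt = torch.optim.AdamW(head.parameters(), lr=1e-4)

    for images in loader:  # hypothetical loader; image values in [0, 1]
        with torch.no_grad():
            summary, _ = radio(images)             # (summary, features) tuple, per the README
            target = teacher.encode_image(images)  # the new teacher's embeddings
        loss = F.mse_loss(F.normalize(head(summary), dim=-1),
                          F.normalize(target, dim=-1))
        opt.zero_grad()
        loss.backward()
        opt.step()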