NVlabs / RADIO

Official repository for "AM-RADIO: Reduce All Domains Into One"
Other

Use RADIOv2 as a VLM's vision encoder #60

Closed: echo840 closed this issue 3 months ago

echo840 commented 3 months ago

Hello, thank you for your great work! We are currently exploring RADIO as a vision encoder for vision-language models (VLMs). In our setup, we use either SigClip or RADIOv2 as the vision encoder, with Phi2 as the language model. The results are as follows: [image]

Both runs use the same data and configuration; the only difference is the vision encoder. Is it expected that RADIOv2 performs worse than SigClip?

# Feature extraction
import torch
import torch.nn as nn
from transformers import AutoConfig, AutoModel, CLIPImageProcessor


class RadioVisionTower(nn.Module):
    def __init__(self, vision_tower, args, delay_load=False):
        super().__init__()

        self.is_loaded = False

        self.vision_tower_name = vision_tower

        if not delay_load:
            self.load_model()
        else:
            self.cfg_only = AutoConfig.from_pretrained(self.vision_tower_name, trust_remote_code=True)

    def load_model(self):
        self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
        self.image_processor.do_resize = True
        self.image_processor.crop_size = self.image_processor.size
        self.vision_tower = AutoModel.from_pretrained(self.vision_tower_name, trust_remote_code=True)
        self.vision_tower.requires_grad_(False)

        self.is_loaded = True

    # device/dtype are read from the model's parameters (forward() relies on them).
    @property
    def device(self):
        return next(self.vision_tower.parameters()).device

    @property
    def dtype(self):
        return next(self.vision_tower.parameters()).dtype

    @torch.no_grad()
    def forward(self, images):
        # The model returns a (summary, spatial_features) pair; keep the spatial tokens.
        # Cast the tensor, not the tuple, back to the tower's dtype.
        if isinstance(images, list):
            image_features = []
            for image in images:
                _, image_feature = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0))
                image_features.append(image_feature.to(self.dtype))
        else:
            _, image_features = self.vision_tower(images.to(device=self.device, dtype=self.dtype))
            image_features = image_features.to(self.dtype)
        return image_features
# Image processing: we resize the image to 432, which is RADIO's preferred_resolution.
# (Fragment of the data-loading code: `image`, `processor` and `self.data_args`
# come from the surrounding code.)
if self.data_args.image_aspect_ratio == 'pad':
    def expand2square(pil_img, background_color):
        # Pad the shorter side so the image becomes square, keeping it centered.
        width, height = pil_img.size
        if width == height:
            return pil_img
        elif width > height:
            result = Image.new(pil_img.mode, (width, width), background_color)
            result.paste(pil_img, (0, (width - height) // 2))
            return result
        else:
            result = Image.new(pil_img.mode, (height, height), background_color)
            result.paste(pil_img, ((height - width) // 2, 0))
            return result
    image = expand2square(image, tuple(int(x * 255) for x in processor.image_mean))
    image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
else:
    # Stretch the image to a square (this changes the aspect ratio).
    width, height = image.size
    max_size = max(width, height)
    image = image.resize((max_size, max_size))
    image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]

Could you give me some suggestions?

gheinrich commented 3 months ago

Hello, in the experiments published in our paper, we used an image preprocessor that resizes the longest edge to 432, scales the shortest edge to preserve the original aspect ratio, and then crops the shortest edge down to the nearest multiple of the patch size. This should be mostly equivalent to expand2square followed by a resize to 432x432, except without the padding along the shortest dimension. It does require support for variable-size, non-square images.
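The preprocessing described above boils down to some size arithmetic. This is a minimal sketch assuming a patch size of 16 and round-to-nearest when scaling the shortest edge; the thread does not state the rounding rule or where the crop is taken (top-left vs. centered), and the helper name radio_resize_dims is hypothetical.

```python
def radio_resize_dims(width, height, max_size=432, patch_size=16):
    """Return ((resized_w, resized_h), (crop_w, crop_h)) for one image.

    1. Scale so the longest edge equals max_size, keeping the aspect ratio.
    2. Crop each edge down to a multiple of patch_size (only the shortest
       edge is affected when max_size is itself a multiple of patch_size).
    """
    longest = max(width, height)
    # Multiply before dividing so exact ratios stay exact.
    resized_w = round(width * max_size / longest)
    resized_h = round(height * max_size / longest)
    # Drop the remainder so each side is a whole number of patches.
    crop_w = resized_w - resized_w % patch_size
    crop_h = resized_h - resized_h % patch_size
    return (resized_w, resized_h), (crop_w, crop_h)


print(radio_resize_dims(640, 480))  # ((432, 324), (432, 320))
```

With PIL, the first pair would be passed to Image.resize and the second pair would define the box for Image.crop.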

Are you using image_aspect_ratio == 'pad'? I ask because otherwise I suspect we might end up cropping actual pixels along the longest edge.
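To put a number on that concern: a preprocessor that resizes the shortest edge and then center-crops a square throws away part of the longest edge, and the fraction lost depends only on the aspect ratio. This sketch assumes exactly that resize-then-center-crop behaviour (as with a CLIPImageProcessor configured with a shortest-edge resize and a square crop); the helper name is hypothetical.

```python
def center_crop_discard_fraction(width, height):
    """Fraction of the longest edge lost to a shortest-edge resize followed by a
    square center crop. The target size cancels out: after resizing, the longest
    edge is size * long/short, and only `size` of it is kept."""
    short_edge, long_edge = sorted((width, height))
    return 1 - short_edge / long_edge


print(center_crop_discard_fraction(640, 480))    # 0.25   (4:3 image)
print(center_crop_discard_fraction(1920, 1080))  # 0.4375 (16:9 image)
```

The 'pad' path avoids this loss by padding to a square first, at the cost of spending patches on padding.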

echo840 commented 3 months ago

Thank you for your response! Yes, during finetuning, we used image_aspect_ratio == 'pad'. I'm now trying the experiment according to your instructions. Thank you very much!

echo840 commented 3 months ago

Hello, RADIOv2 still scores lower than SigClip. Have I missed any operations in the feature extraction code below? Do I need to extract features from the second-to-last layer of the vision tower, as LLaVA does? Have I overlooked a normalization step? Or do I need to add the summary token?

(Feature extraction code identical to the RadioVisionTower class in my first post.)
gheinrich commented 3 months ago

Hello, I have not used the HuggingFace model in LLaVA, but you should be able to use the equivalent TorchHub model instead. In my LLaVA integration I used standard normalization instead of the built-in input conditioner (i.e. I call vision_tower.make_preprocessor_external()).

This is my code (pardon the untidiness):

from argparse import Namespace
import os
import torch
import torch.nn as nn
from typing import Any, Dict
import warnings

from transformers import CLIPVisionConfig
from transformers import CLIPImageProcessor, SamImageProcessor
from PIL import Image
import numpy as np

class RADIOVisionTower(nn.Module):
    """
    Vision Tower for the RADIO model.

    Args:
        vision_tower (str): Vision tower name. This is passed on
            the command line with the `--vision_tower` argument.
            The string is expected in the pattern of:
            `radio:<image_size>:<checkpoint_or_version>:<extra_config>`.
            Where <extra_config> is a comma-separated list of key=value pairs.
            <image_size> is the image resolution.
            <checkpoint> is a TorchHub version or path to a checkpoint.
        args (Namespace): Arguments.
        delay_load (bool): Delay loading the model.
    """
    def __init__(self, vision_tower, args, delay_load=False):
        """Initialization Routine."""

        super().__init__()

        self.vision_tower_name = vision_tower[len("radio:"):]
        config_items = self.vision_tower_name.split(":")
        self.image_sizes = [int(x) for x in config_items[0].split(",")]
        if len(self.image_sizes) == 0:
            raise ValueError("Expected at least one image size!")
        self.image_size = self.image_sizes[0]
        self.do_center_crop = args.mm_im_crop

        self.vision_tower_checkpoint = config_items[1]

        extra_config = {}
        if len(config_items) > 2:
            # Parse extra config items. These are provided as a comma-separated list
            # of key=value pairs.
            extra_config_items = config_items[2].split(",")

            for item in extra_config_items:
                key, value = item.split("=")
                extra_config[key] = value

        self.adaptor_name = extra_config.get("adaptor", "backbone")
        self.fuse_adaptor_with_backbone = eval(extra_config.get("fuse_adaptor_with_backbone", "False"))
        self.skip_layer_norm = eval(extra_config.get("skip_layer_norm", "False"))

        self.is_loaded = False

        if not delay_load:
            self.load_model()
        else:
            # FIXME: This is a hack to avoid having to load the config from the checkpoint.
            hidden_size = self.get_hidden_size()
            patch_size = 16

            self.cfg_only = CLIPVisionConfig(
                **{

                    "hidden_size": hidden_size,
                    "image_size": self.image_size,
                    "model_type": "radio_vision_model",
                    "num_attention_heads": None,
                    "num_channels": 3,
                    "num_hidden_layers": None,
                    "patch_size": patch_size,
                }
            )

    def get_hidden_size(self):
        if self.adaptor_name == "openai_clip":
            hidden_size = 1024
        elif self.adaptor_name == "clip":
            hidden_size = 1280
        elif self.adaptor_name == "rtx-translate":
            hidden_size = 2048
        elif self.adaptor_name == "backbone":
            hidden_size = 1280
        else:
            raise ValueError(f"Unknown adaptor name: {self.adaptor_name}")

        if self.fuse_adaptor_with_backbone:
            hidden_size += 1280

        return hidden_size

    @property
    def hidden_size(self):
        return self.get_hidden_size()

    def load_model(self):

        crop_size={'height': self.image_size, 'width': self.image_size}

        if self.do_center_crop:
            self.image_processor = CLIPImageProcessor(
                size={"shortest_edge": self.image_size},
                crop_size=crop_size,
                do_center_crop=self.do_center_crop,
                do_normalize=True,
            )
        else:
            self.image_processor = SamImageProcessor(
                    size={"longest_edge": self.image_size},
                    pad_size={'height': self.image_size, 'width': self.image_size},
                    do_pad=False,
                    do_normalize=True,
            )
            # Add a crop_size attribute to the image processor, since the
            # train.py script needs this to generate fake images of zeros
            # with the right size, when the sample does not have an
            # associated image.
            self.image_processor.crop_size = crop_size

        # For compatibility with CLIP Image Processor: the data loader uses width/height to
        # create dummy blank images for samples that don't have an image.
        self.image_processor.crop_size = {"width": self.image_size, "height": self.image_size}

        checkpoint_path_or_version = self.vision_tower_checkpoint

        # NOTE: do a lazy import of Timm to avoid issues with
        # DeepSpeed's ZeRO-3.
        from timm.models.vision_transformer import VisionTransformer

        self.vision_tower = torch.hub.load('NVlabs/RADIO',
                                           'radio_model',
                                           version=checkpoint_path_or_version,
                                           progress=True,
                                           adaptor_names=self.adaptor_name if self.adaptor_name != "backbone" else None)

        if isinstance(self.vision_tower.model, VisionTransformer):
            hidden_size = self.vision_tower.model.embed_dim
        else:
            raise ValueError(f"Unknown model type: {self.vision_tower}")

        # Override hidden size for OpenAI CLIP.
        hidden_size = self.get_hidden_size()

        if hasattr(self.vision_tower.model, "patch_generator"):
            patch_gen = self.vision_tower.model.patch_generator
            # Cropped Positional Embedding (CPE) case.
            patch_size = patch_gen.patch_size
        else:
            # Standard ViT case.
            patch_size = self.vision_tower.model.patch_embed.patch_size[0]

        self.vision_tower.config = CLIPVisionConfig(
                **{
                    "hidden_size": hidden_size,
                    "image_size": self.image_size,
                    "model_type": "radio_vision_model",
                    "num_attention_heads": None,
                    "num_channels": 3,
                    "num_hidden_layers": None,
                    "patch_size": patch_size,
                }
            )

        self.vision_tower.make_preprocessor_external()
        self.vision_tower.eval()
        self.vision_tower.requires_grad_(False)

        self.is_loaded = True
        self._to_dtype = None

        if self.skip_layer_norm:
            self.vision_tower.model.norm = torch.nn.Identity()

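    def to(self, *args, **kwargs):
        # Prevent casting the RADIO model's weights: a `dtype` keyword is recorded
        # and dropped; the remaining arguments (e.g. device) are forwarded.
        kwargs = dict(kwargs)
        self._to_dtype = kwargs.pop('dtype', None)
        # Return self, like nn.Module.to, so calls can be chained.
        return super().to(*args, **kwargs)

    def train(self, mode=True):
        """Intercept call: the encoder stays in eval mode."""
        # Drop a warning if mode is True.
        if mode:
            warnings.warn("RADIOEncoder is always in eval mode.")
        # Return self, like nn.Module.train.
        return self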

    @torch.no_grad()
    def get_features(self, x: torch.Tensor):
        output = self.vision_tower(x)
        if isinstance(output, dict):
            _, features = output[self.adaptor_name]
            if self.fuse_adaptor_with_backbone:
                _, backbone_features = output["backbone"]
                features = torch.cat([features, backbone_features], dim=2)
        else:
            _, features = output
        return features

    @torch.no_grad()
    def forward(self, images: torch.Tensor):
        """Main forward pass."""
        input_shape = images.shape

        x = images

        # Add a batch dimension if necessary.
        if len(input_shape) == 3:
            x = x.unsqueeze(0)

        # Convert the input to the model's dtype (we assume
        # that the model only has one dtype for all parameters).
        param0 = next(self.vision_tower.parameters())
        x = x.to(dtype=param0.dtype, device=param0.device)

        patch_size = self.vision_tower.config.patch_size

        if self.do_center_crop:
            # Crop the input to a multiple of patch size.
            _, _, H, W = x.shape

            H = H - (H % patch_size)
            W = W - (W % patch_size)

            x = x[:, :, :H, :W]
        else:
            # Pad to nearest multiple of patch size
            _, _, H, W = x.shape
            H = H + (patch_size - (H % patch_size)) % patch_size
            W = W + (patch_size - (W % patch_size)) % patch_size
            x = nn.functional.pad(x, (0, W - x.shape[3], 0, H - x.shape[2]), mode="constant", value=0)

        features = self.get_features(x) # B, T, C

        B, _, H, W = x.shape
        _, _, C = features.shape

        # Remove the batch dimension if we added it.
        if len(input_shape) == 3:
            features = features.squeeze(0)

        # Cast back to the input's dtype.
        features = features.to(images.dtype)

        assert features.shape[-1] == self.get_hidden_size()

        return features
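Because make_preprocessor_external() takes normalization out of the model, it has to happen in the data pipeline; in the integration above that is the Hugging Face image processor with do_normalize=True. A minimal NumPy sketch of that step, using ImageNet mean/std purely as an assumed example (the statistics must match what your processor is configured with); the helper name is hypothetical.

```python
import numpy as np

# Assumed statistics (ImageNet mean/std), for illustration only.
MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def normalize_outside_model(image_hwc_uint8, mean=MEAN, std=STD):
    """HxWx3 uint8 image -> normalized 3xHxW float32 array for the encoder."""
    x = image_hwc_uint8.astype(np.float32) / 255.0  # scale pixels to [0, 1]
    x = (x - mean) / std                             # per-channel standardization, broadcast over H and W
    return x.transpose(2, 0, 1)                      # HWC -> CHW


img = np.zeros((4, 6, 3), dtype=np.uint8)
print(normalize_outside_model(img).shape)  # (3, 4, 6)
```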
echo840 commented 3 months ago

Thank you! I'm also curious about the settings in "extra_config" and "config_items". Are the following parameters set to True or False?

self.adaptor_name = extra_config.get("adaptor", "backbone")
self.fuse_adaptor_with_backbone = eval(extra_config.get("fuse_adaptor_with_backbone", "False"))
self.skip_layer_norm = eval(extra_config.get("skip_layer_norm", "False"))
gheinrich commented 3 months ago

Hi, in my standard configuration the adaptor is backbone, and fuse_adaptor_with_backbone and skip_layer_norm are both False.
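Since those values are the defaults, the standard configuration corresponds to a tower string with no extra_config part at all. The sketch below restates the parsing from RADIOVisionTower.__init__ as a standalone function (hypothetical name) with one deliberate change: boolean flags are read by comparing against the string "True" instead of calling eval(). The version string radio_v2.1 and the checkpoint path are only examples.

```python
def parse_radio_tower(spec):
    """Split a `radio:<image_size>:<checkpoint_or_version>:<extra_config>` string."""
    if not spec.startswith("radio:"):
        raise ValueError(f"Not a RADIO tower spec: {spec}")
    items = spec[len("radio:"):].split(":")
    image_sizes = [int(x) for x in items[0].split(",")]
    checkpoint = items[1]
    # extra_config is a comma-separated list of key=value pairs.
    extra = dict(kv.split("=", 1) for kv in items[2].split(",")) if len(items) > 2 else {}
    return {
        "image_size": image_sizes[0],
        "checkpoint": checkpoint,
        "adaptor": extra.get("adaptor", "backbone"),
        "fuse_adaptor_with_backbone": extra.get("fuse_adaptor_with_backbone", "False") == "True",
        "skip_layer_norm": extra.get("skip_layer_norm", "False") == "True",
    }


print(parse_radio_tower("radio:432:radio_v2.1"))
```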

echo840 commented 3 months ago

Thank you for your prompt response and your great work!

gheinrich commented 3 months ago

Hello, have you been able to get RADIO to perform well in your VLM setup?

echo840 commented 3 months ago

I'm sorry; to be honest, I can't get better results than SigClip under the same settings. SigClip uses a resolution of 384, while RADIO's resolution is dynamic (with the maximum size set to 1280). To save time, we use Qwen2-0.5B as the LLM, and we also added some OCR data such as DocVQA and TextVQA. Both experiments use the same settings.

[image]

gheinrich commented 1 month ago

Hello, our RADIOv2.5 models substantially improve VLM metrics; see the release notes at the root of this repo. Would you like to try them?

echo840 commented 1 month ago

Hello, the RADIOv2.5 results are indeed impressive. I'm curious whether RADIOv2.5 supports dynamic resolution. Because different tasks and images may need different resolutions, recent VLMs, such as the one described in this paper, split the original image so that the input resolution stays close to the original, with very good results. Since RADIOv2.5 natively supports arbitrary resolutions, could a single RADIOv2.5 instance handle dynamic resolution on its own? Also, is RADIOv2.5 limited to a maximum resolution of 768? If it supports larger resolutions, using it as the visual encoder might improve performance on document images (DocVQA). Thank you for your great work!

mranzinger commented 1 month ago

Yes, the RADIOv2.5 family of models supports dynamic resolution. In fact, with RADIO it isn't necessary to tile, and the image doesn't need to be square. The only requirement is that each dimension is a multiple of 16.

If you check out the tech report, you'll see that the model does well all the way up to 2048px. It can go even higher, although we haven't spent much time assessing it.
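Since the only constraint is that each side be a multiple of 16, preparing an arbitrary image amounts to picking a nearby multiple-of-16 size. A minimal sketch, assuming each side is rounded down (padding up would also work) and the longest edge is optionally capped, e.g. at the 2048px mentioned above; the function name and its max_edge parameter are hypothetical.

```python
def snap_to_patch_grid(width, height, patch_size=16, max_edge=None):
    """Pick an input size whose sides are multiples of patch_size (no squaring needed)."""
    longest = max(width, height)
    if max_edge is not None and longest > max_edge:
        # Downscale so the longest edge equals max_edge; multiply first to limit float error.
        width = width * max_edge / longest
        height = height * max_edge / longest
    # Round each side down to a multiple of patch_size, keeping at least one patch.
    new_w = max(patch_size, int(width // patch_size) * patch_size)
    new_h = max(patch_size, int(height // patch_size) * patch_size)
    return new_w, new_h


print(snap_to_patch_grid(1000, 700))                  # (992, 688)
print(snap_to_patch_grid(4000, 3000, max_edge=2048))  # (2048, 1536)
```

Resizing (or cropping) the image to the returned size keeps it non-square; rounding down trades a slight aspect-ratio change for not spending patches on padding.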

echo840 commented 1 month ago

Thank you for your reply. I tried the dynamic resolution setting with RADIOv2.1 but found the performance to be poor. I'm unsure whether this has improved in RADIOv2.5.

mranzinger commented 1 month ago

Yes, "mode switching" is a major issue we fixed in the latest release, and it was probably the primary reason you were seeing odd results with dynamic resolution on RADIOv2.1. Definitely give it a try if you have the chance. The gist is that RADIOv2.1 and earlier behaved differently at resolutions below ~720px than above it: the representations changed dramatically around that threshold.

This no longer happens with the new models, and we demonstrate how increasing the resolution from 432 up to 768 dramatically improves our LLaVA metrics. Depending on the language model, you could go even higher for even better results (particularly for OCR tasks).
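Going from 432 to 768 also lengthens the visual sequence the language model has to process. A quick calculation, assuming square inputs, one token per 16x16 patch, no token merging in the projector, and no summary tokens counted; the helper name is hypothetical.

```python
def num_patch_tokens(width, height, patch_size=16):
    """Spatial tokens produced for one image by a ViT with the given patch size."""
    if width % patch_size or height % patch_size:
        raise ValueError("each side must be a multiple of the patch size")
    return (width // patch_size) * (height // patch_size)


for side in (432, 768):
    print(side, num_patch_tokens(side, side))
# 432 729
# 768 2304
```

Under these assumptions, the 768 setting costs roughly 3.2 times as many visual tokens as 432, which is worth budgeting for when choosing a resolution for a given language model.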

echo840 commented 1 month ago

Thank you for your response; I'd be very glad to give RADIOv2.5 a try.

zhudongwork commented 3 weeks ago

Is there any good news?