TRI-ML / prismatic-vlms

A flexible and efficient codebase for training visually-conditioned language models (VLMs)
MIT License

Inconsistent API for Vision Backbones? #14

Closed RylanSchaeffer closed 5 months ago

RylanSchaeffer commented 5 months ago

The base VisionBackbone defines a forward method that accepts pixel_values as a PyTorch tensor:

https://github.com/TRI-ML/prismatic-vlms/blob/main/prismatic/models/backbones/vision/base_vision.py#L70-L73

    @abstractmethod
    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Run a forward pass through the featurizer given a set of processed images, returning patch/grid features."""
        raise NotImplementedError

However, some of the derived backbones (e.g., the fused DINO+CLIP model) have a different interface:

https://github.com/TRI-ML/prismatic-vlms/blob/main/prismatic/models/backbones/vision/dinoclip_vit.py#L124-L127


    def forward(self, pixel_values: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Runs the transformed image/pixel tensors through each vision backbone, returning concatenated patches."""
        dino_patches = self.dino_featurizer(pixel_values["dino"])
    clip_patches = self.clip_featurizer(pixel_values["clip"])

  1. Is this a mistake or intentional?
  2. If intentional, how should pixel_values for dino differ from pixel_values for other models?
siddk commented 5 months ago

Sorry about the confusion here @RylanSchaeffer -- I need to go and update some of the base class type signatures.

For the fused backbones (e.g., DINO + CLIP, DINO + SigLIP), because the underlying image transforms are different (e.g., different pixel normalization values for DINO vs. CLIP), we pass a dictionary with str keys and Tensor values (the corresponding pixel_values).
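
As a rough sketch of what this means in practice (the variable names and shapes below are illustrative, not the library's):

    import torch

    # For a fused DINO+CLIP backbone, forward() expects one pixel tensor per sub-backbone,
    # each produced by that sub-backbone's own image transform (different mean/std).
    # The random tensors are placeholders; in practice these come from the image transforms.
    pixel_values = {
        "dino": torch.randn(1, 3, 224, 224),  # DINO-normalized pixels
        "clip": torch.randn(1, 3, 224, 224),  # CLIP-normalized pixels
    }
    # patch_features = vision_backbone(pixel_values)  # returns concatenated patch features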

Hopefully this makes sense!

RylanSchaeffer commented 5 months ago

@siddk, sorry, I should have been clearer. Everything is conceptually clear. What I was trying to communicate is:

  1. Suggestion: From an API design perspective, I think the pixel normalization should be handled under the hood. The user should just need to provide the raw pixel values, and each model should apply its own appropriate transformations. This minimizes the user's chances of making mistakes and also makes the API consistent across VLMs, since only the image needs to be provided :slightly_smiling_face: (rough sketch after this list)

  2. Question: Independent of the above suggestion, can you please post code demonstrating correct image normalization for each VLM? I want to make sure I'm doing normalization correctly.
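
To make the suggestion concrete, here is a hypothetical sketch of the kind of interface I have in mind (this is not the current API; the class and method names are made up):

    import torch

    # Hypothetical unified interface: the caller always passes raw images in [0, 1],
    # and each backbone applies its own resizing/normalization internally.
    class UnifiedVisionBackbone:
        def forward(self, images: torch.Tensor) -> torch.Tensor:
            pixel_values = self.apply_backbone_specific_transform(images)  # mean/std handled here
            return self.featurizer(pixel_values)  # patch/grid features, per the base class contract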

Thank you!

siddk commented 5 months ago

Ah got it -- is the issue that the PrismaticVLM.generate() function isn't expressive enough for different use cases? This just takes a PIL.Image as input and handles all image-specific normalization under the hood.

For more control over the image transform, each VisionBackbone also exposes an image_transform field that is just a function that takes an Image and does the appropriate normalization / resizing / etc. -- see here
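
For example, something along these lines should work (a rough sketch; `vlm` here is assumed to be a loaded PrismaticVLM):

    from PIL import Image

    image = Image.open("example.jpg").convert("RGB")  # raw PIL image, no preprocessing
    pixel_values = vlm.vision_backbone.image_transform(image)
    # Single backbones should return a Tensor here; fused backbones (DINO+CLIP, DINO+SigLIP)
    # should return a Dict[str, Tensor], matching what their forward() expects.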

RylanSchaeffer commented 5 months ago

> is the issue that the PrismaticVLM.generate() function isn't expressive enough for different use cases?

Oh no, .generate() works well, but I'm working on an adversarial robustness project, meaning .generate() isn't the right method for my purposes. Rather, I want to use .forward(), which expects pixel_values to be a tensor: https://github.com/TRI-ML/prismatic-vlms/blob/main/prismatic/models/vlms/prismatic.py#L256

So when I do the following:

    def compute_loss(
        self,
        image: torch.Tensor,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        labels: torch.Tensor,
    ) -> torch.Tensor:
        image = self.resizer(image)
        images = image.to(self.device_str).repeat(len(input_ids), 1, 1, 1)
        images_pixel_values = normalize_images(images).to(self.device_str)

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
            pixel_values=images_pixel_values.to(self.device_str),
        )
        return outputs.loss

I receive the following error:

    Traceback (most recent call last):
      File "/lfs/ampere1/0/rschaef/PerezAstraFellowship-Universal-VLM-Jailbreak/evaluate_jailbreak_attacks_against_vlms.py", line 146, in <module>
        evaluate_vlm_adversarial_examples()
      File "/lfs/ampere1/0/rschaef/PerezAstraFellowship-Universal-VLM-Jailbreak/evaluate_jailbreak_attacks_against_vlms.py", line 109, in evaluate_vlm_adversarial_examples
        model_evaluation_results = attacker.evaluate_jailbreak_against_vlms_and_log(
                                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "/lfs/ampere1/0/rschaef/miniconda3/envs/universal_vlm_jailbreak_env/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
        return func(*args, **kwargs)
               ^^^^^^^^^^^^^^^^^^^^^
      File "/lfs/ampere1/0/rschaef/PerezAstraFellowship-Universal-VLM-Jailbreak/src/attacks/base.py", line 93, in evaluate_jailbreak_against_vlms_and_log
        batch_losses_per_model = vlm_ensemble.compute_loss(
                                 ^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "/lfs/ampere1/0/rschaef/PerezAstraFellowship-Universal-VLM-Jailbreak/src/models/ensemble.py", line 124, in compute_loss
        loss = model_wrapper.compute_loss(
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "/lfs/ampere1/0/rschaef/PerezAstraFellowship-Universal-VLM-Jailbreak/src/models/prismatic.py", line 107, in compute_loss
        outputs = self.model(
                  ^^^^^^^^^^^
      File "/lfs/ampere1/0/rschaef/miniconda3/envs/universal_vlm_jailbreak_env/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
        return self._call_impl(*args, **kwargs)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "/lfs/ampere1/0/rschaef/miniconda3/envs/universal_vlm_jailbreak_env/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
        return forward_call(*args, **kwargs)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "/lfs/ampere1/0/rschaef/PerezAstraFellowship-Universal-VLM-Jailbreak/submodules/prismatic-vlms/prismatic/models/vlms/prismatic.py", line 312, in forward
        patch_features = self.vision_backbone(pixel_values[multimodal_indices])
                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "/lfs/ampere1/0/rschaef/miniconda3/envs/universal_vlm_jailbreak_env/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
        return self._call_impl(*args, **kwargs)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "/lfs/ampere1/0/rschaef/miniconda3/envs/universal_vlm_jailbreak_env/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
        return forward_call(*args, **kwargs)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "/lfs/ampere1/0/rschaef/PerezAstraFellowship-Universal-VLM-Jailbreak/submodules/prismatic-vlms/prismatic/models/backbones/vision/dinosiglip_vit.py", line 139, in forward
        dino_patches = self.dino_featurizer(pixel_values["dino"])
                                            ~~~~~~~~~~~~^^^^^^^^
    TypeError: new(): invalid data type 'str'

RylanSchaeffer commented 5 months ago

I just tried using the vision backbone's .image_transform(), but it requires a NumPy array or PIL image as input, not a PyTorch tensor:

    self.model.vision_backbone.image_transform(images)

Error:

    TypeError: pic should be PIL Image or ndarray. Got <class 'torch.Tensor'>

Is the below code a viable workaround? Or would you recommend something different?

        # Remove "ToTensor()" from the default transforms.
        self.model.vision_backbone.image_transform = torchvision.transforms.Compose(
            [
                t
                for t in self.model.vision_backbone.image_transform.transforms
                if not isinstance(t, torchvision.transforms.ToTensor)
            ]
        )

And then:

    def compute_loss(
        self,
        image: torch.Tensor,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        labels: torch.Tensor,
    ) -> torch.Tensor:
        images = image.to(self.device_str).repeat(len(input_ids), 1, 1, 1)
        images_pixel_values = self.model.vision_backbone.image_transform(images).to(
            self.device_str
        )

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
            pixel_values=images_pixel_values.to(self.device_str),
        )
        return outputs.loss

RylanSchaeffer commented 5 months ago

The above code throws an error:

  File "/lfs/ampere1/0/rschaef/miniconda3/envs/universal_vlm_jailbreak_env/lib/python3.11/site-packages/torch/nn/modules/conv.py", line 456, in _conv_forward
    return F.conv2d(input, weight, bias, self.stride,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Input type (CUDABFloat16Type) and weight type (torch.cuda.FloatTensor) should be the same

What is the right way to use reduced precision, e.g., bfloat16?
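
For what it's worth, the error seems to say my pixel tensor is bfloat16 while the vision backbone's conv weights are float32, so one thing I could try (untested sketch) is casting the pixel values to whatever dtype the backbone's parameters actually use:

    # Match the pixel dtype to the vision backbone's parameter dtype before calling forward().
    backbone_dtype = next(self.model.vision_backbone.parameters()).dtype
    images_pixel_values = images_pixel_values.to(device=self.device_str, dtype=backbone_dtype)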

RylanSchaeffer commented 5 months ago

OK, here is my newest solution, which appears to work. Is there a simpler way of doing this?

    def create_images_transform_fn(self, model_str: str) -> Callable:
        if "dinosiglip" in model_str:
            # Convert to float32, then remove the ToTensor transform because that is applicable to PIL Images.
            dino_transforms = torchvision.transforms.Compose(
                [torchvision.transforms.ConvertImageDtype(torch.float32)]
                + [
                    t
                    for t in self.model.vision_backbone.image_transform.dino_image_transform.transforms
                    if not isinstance(t, torchvision.transforms.ToTensor)
                ]
            )

            siglip_transforms = torchvision.transforms.Compose(
                [torchvision.transforms.ConvertImageDtype(torch.float32)]
                + [
                    t
                    for t in self.model.vision_backbone.image_transform.siglip_image_transform.transforms
                    if not isinstance(t, torchvision.transforms.ToTensor)
                ]
            )

            def images_transform_fn(images: torch.Tensor) -> Dict[str, torch.Tensor]:
                transformed_images = {
                    "dino": dino_transforms(images).to(self.device_str),
                    "siglip": siglip_transforms(images).to(self.device_str),
                }
                return transformed_images

        else:
            # Convert to float32, then remove the ToTensor transform because that is applicable to PIL Images.
            transforms = torchvision.transforms.Compose(
                [torchvision.transforms.ConvertImageDtype(torch.float32)]
                + [
                    t
                    for t in self.model.vision_backbone.image_transform.transforms
                    if not isinstance(t, torchvision.transforms.ToTensor)
                ]
            )

            def images_transform_fn(images: torch.Tensor) -> torch.Tensor:
                transformed_images = transforms(images).to(self.device_str)
                return transformed_images

        return images_transform_fn

    def compute_loss(
        self,
        image: torch.Tensor,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        labels: torch.Tensor,
    ) -> torch.Tensor:
        images = image.to(self.device_str).repeat(len(input_ids), 1, 1, 1)
        transformed_images: Union[
            torch.Tensor, Dict[str, torch.Tensor]
        ] = self.images_transform_fn(images)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
            pixel_values=transformed_images,
        )
        return outputs.loss

siddk commented 5 months ago

> I just tried using the vision backbone's .image_transform(), but it requires a NumPy array or PIL image as input, not a PyTorch tensor

So one thing I'm missing here is exactly where your PyTorch Tensor is coming from; is image processed in any way, or is it just the raw pixel values (as uint8 in [0, 255])?

RylanSchaeffer commented 5 months ago

image is a randomly initialized tensor with shape (3, 512, 512) (I might have the dimensions transposed - would need to check) and values in [0, 1) that I then optimize to jailbreak the VLM.

RylanSchaeffer commented 5 months ago

My understanding is that torchvision.transforms.ToTensor converts a uint8 PIL Image with values in [0, 255] to a float tensor with values in [0, 1], so since I'm removing torchvision.transforms.ToTensor, I think I want my values to be in the range [0, 1). Is that correct?

siddk commented 5 months ago

Got it; in that case you can just cut off the ToTensor() transform like in your first example. You don't want to use ConvertImageDtype (it does some additional rescaling under the hood which you probably don't want).
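
To make the scaling behavior concrete, here's a quick sanity check (a sketch; worth re-running against your torchvision version):

    import torch
    import torchvision.transforms as T

    x_uint8 = torch.tensor([0, 128, 255], dtype=torch.uint8)
    print(T.ConvertImageDtype(torch.float32)(x_uint8))
    # tensor([0.0000, 0.5020, 1.0000]) -- integer inputs get rescaled by 1/255

    x_float = torch.rand(3, 2, 2)
    print(torch.equal(T.ConvertImageDtype(torch.float32)(x_float), x_float))
    # True -- float inputs of the same dtype pass through unchanged, so for your [0, 1) tensors
    # the transform is unnecessary at best, and a surprise if the input dtype ever changes.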

siddk commented 5 months ago

One weird thing -- I'm not sure torchvision transforms are differentiable (e.g., the normalize call may not propagate gradients).

Feels like the "right" thing to do is to skip the transform step completely: initialize images directly subject to the normalization parameters (sampling values per channel based on the mean/std of the transform), and feed the resulting Tensor to forward() directly.
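
A minimal sketch of that idea (the mean/std values below are generic placeholders -- read the real ones off the backbone's Normalize transform):

    import torch

    # Per-channel normalization statistics; in practice, take these from the backbone's image transform.
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

    # Initialize directly in normalized space (equivalent to normalizing a random [0, 1) image),
    # so no torchvision transform -- and no differentiability question -- is involved at all.
    adv_pixel_values = ((torch.rand(3, 224, 224) - mean) / std).requires_grad_(True)

    # This tensor can then be repeated along the batch dimension and passed straight to forward()
    # as pixel_values (or wrapped in {"dino": ..., "siglip": ...} for the fused backbones).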

RylanSchaeffer commented 5 months ago

Ok thank you!! I really appreciate the help :)