SunzeY / AlphaCLIP

[CVPR 2024] Alpha-CLIP: A CLIP Model Focusing on Wherever You Want
https://aleafy.github.io/alpha-clip
Apache License 2.0
640 stars · 38 forks

Could you release the code for integrating BLIP-2 with Alpha-CLIP? #27

Closed Akshay1-6180 closed 6 months ago

Akshay1-6180 commented 7 months ago
[Screenshot 2024-02-02 at 7 00 07 PM: ablation figure on alternative masking approaches]

Amazing paper, I had a pleasant experience reading it. I have a few doubts about using Alpha-CLIP with BLIP-2: are you using the frozen Alpha-CLIP model as the image encoder, sending in the mask and the image, and then doing BLIP-2 training? Or, as seen in the image, are you somehow integrating the mask with the Q-Former and the image with the Q-Former? I didn't quite understand it. Moreover, did you compare zero-shot image captioning abilities, similar to BLIP-2, on the LLaVA bench? Thanks

SunzeY commented 7 months ago

Thank you for your attention! The image you show is from the ablation study of other simple possible masking approaches. Our Alpha-CLIP with BLIP-2 only involves feeding the mask as an alpha map into BLIP-2, with no training process. BLIP-2 uses EVA-CLIP as its mainstream model, with only a side version using CLIP-L/14 and Flan-T5-XL, so we didn't use BLIP-2 to compare zero-shot image captioning ability, as that would require retraining Alpha-CLIP with the EVA-CLIP text encoder. Here is the code for using BLIP-2 (a bit messy, as we use it to test on different images):

import collections
import json
import os
import types

import cv2
import numpy as np
import torch
from PIL import Image
from torchvision import transforms
from tqdm import tqdm

from lavis.models import load_model_and_preprocess

# globals used by the patched forward pass
alpha = None    # current alpha map, set by save_alpha()
alpha_v = None  # grayscale value used to fill the focused region (see alpha_bet below)

PIXEL_MEAN = (0.48145466, 0.4578275, 0.40821073)
MASK_FILL = [int(255 * c) for c in PIXEL_MEAN]

# alpha maps are normalized with mean 0.5 / std 0.26, matching Alpha-CLIP's alpha preprocessing
mask_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((224, 224)),
    transforms.Normalize(0.5, 0.26)
])

def crop_center(img, croph, cropw):
    h, w = img.shape[:2]
    starth = h//2 - (croph//2)
    startw = w//2 - (cropw//2)    
    return img[starth:starth+croph, startw:startw+cropw, :]

wb_mask = np.array(Image.open('ellite_test/test_datasets/3_bg.png'))
binary_mask = (wb_mask[:, :, 0] == 255)

alpha = mask_transform((binary_mask * 255).astype(np.uint8))
alpha = alpha.cuda().unsqueeze(dim=0)
# alpha_bet maps a grayscale fill value v to its normalized alpha, (v / 255 - 0.5) / 0.26:
# 255 -> 1.923 (full focus), 210 -> 1.2443, 80 -> -0.7164, 0 -> -1.923 (no focus)
alpha_bet = [[255, 1.923], [210, 1.2443], [80, -0.7164], [0, -1.923]]

def save_alpha(wb_mask):
    """Rebuild the global alpha map from a white-on-black mask, filling the white region with alpha_v."""
    global alpha, alpha_v
    binary_mask = (wb_mask[:, :, 0] == 255)

    alpha = mask_transform((binary_mask * alpha_v).astype(np.uint8))
    alpha = alpha.cuda().unsqueeze(dim=0)

def rewrited_forward(self, x: torch.Tensor):
    global alpha
    if alpha is None:
        # fall back to an all-background alpha map: grayscale 0 -> (0 - 0.5) / 0.26 ≈ -1.923
        print(f"[Warning] in {type(self)} forward: no alpha map provided; Alpha-CLIP expects one!")
        alpha = torch.ones_like((x[:, [0], :, :])) * -1.9231

    x = self.conv1(x)  # shape = [*, width, grid, grid]
    x = x + self.conv1_alpha(alpha)  # inject the alpha map through its own patch-embedding conv
    x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
    x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
    x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1)  # shape = [*, grid ** 2 + 1, width]
    x = x + self.positional_embedding.to(x.dtype)
    x = self.ln_pre(x)

    x = x.permute(1, 0, 2)  # NLD -> LND
    x = self.transformer(x)
    x = x.permute(1, 0, 2)  # LND -> NLD

    return x

if __name__ == "__main__":
    # setup device to use
    device = torch.device("cuda:2") if torch.cuda.is_available() else "cpu"
    torch.cuda.set_device("cuda:2")
    # load BLIP-2 (Flan-T5-XL with ViT-L/14) from LAVIS
    model, vis_processors, _ = load_model_and_preprocess(name="blip2_t5", model_type="pretrain_flant5xl_vitL", is_eval=True, device=device)

    # load Alpha-CLIP (ViT-L/14) weights and rename the attention keys to match the
    # visual encoder's parameter names (in_proj.weight -> in_proj_weight, etc.)
    state_dict = torch.load('clip_l14_grit+mim_fultune_6xe.pth')
    converted_dict = collections.OrderedDict()
    for k, v in state_dict.items():
        if 'in_proj.weight' in k:
            converted_dict[k.replace('in_proj.weight', 'in_proj_weight')] = v
        elif 'in_proj.bias' in k:
            converted_dict[k.replace('in_proj.bias', 'in_proj_bias')] = v
        else:
            converted_dict[k] = v
    # TODO: check why only 23 transformer layers are used
    # add the alpha-channel patch-embedding conv; its weights come from the Alpha-CLIP checkpoint above
    model.visual_encoder.conv1_alpha = torch.nn.Conv2d(in_channels=1,
                                                       out_channels=model.visual_encoder.conv1.out_channels,
                                                       kernel_size=model.visual_encoder.conv1.kernel_size,
                                                       stride=model.visual_encoder.conv1.stride,
                                                       bias=False)
    # swap in the alpha-aware forward pass and load the converted weights (strict=False tolerates mismatched keys)
    model.visual_encoder.forward = types.MethodType(rewrited_forward, model.visual_encoder)
    model.visual_encoder.load_state_dict(converted_dict, strict=False)
    model.visual_encoder = model.visual_encoder.half().cuda()
    case_caption = ["00003859", "00004485", "00004527", "00010894", "00012813", "00015404", "00019723", "00023905", "00024348", "00024866",
                    "00031427", "00031695", "00036038", "00036786", "00044033"]
    case_diff = ["000088630", "00010926", "00011337", "00022030", "00024385", "00026003", "00027957", "00031394", "00031533", "00047550"]
    todo = case_caption + case_diff
    res = []
    for alpha_value in alpha_bet:
        alpha_v = alpha_value[0]         # grayscale fill value for the focused region
        alpha_str = str(alpha_value[1])  # its normalized alpha, used in the output folder name
        for label_class in tqdm(os.listdir('ImageNetS919/multi_mask')):
            single_res = dict()
            for whole_image in os.listdir('ImageNetS919/multi_mask' + '/' + label_class):
                pth = 'ImageNetS919/multi_mask' + '/' + label_class + '/' + whole_image
                # only process the hand-picked test images listed in `todo`
                if not any(hash_id in pth for hash_id in todo):
                    continue
                for class_name in os.listdir(pth):
                    class_name = class_name[:-4]  # strip the '.png' extension
                    if class_name == 'image':
                        # whole-image caption: an all-white mask means focus everywhere
                        wb_mask = np.ones((224, 224, 3), dtype=np.uint8) * 255
                    else:
                        # region caption: load the per-class mask, resized by the BLIP-2 eval transform
                        wb_mask_image = vis_processors["eval"].transform.transforms[0](Image.open(pth + "/" + class_name + '.png'))
                        wb_mask = np.array(wb_mask_image)
                    raw_image = Image.open(pth + "/" + "image.png").convert("RGB")        
                    h, w = wb_mask.shape[:2]
                    image = vis_processors["eval"](raw_image).unsqueeze(0).to(device)
                    save_alpha(wb_mask=wb_mask)
                    single_res[class_name] = model.generate({"image": image})[0]
                    save_pth = pth.replace('multi_mask', f'multi_mask_res_blip_flant5xl_l14_0.5_alpha={alpha_str}')
                    os.makedirs(save_pth, exist_ok=True)
                    if class_name == 'image':
                        raw_image.save(save_pth + "/" + f'whole_{single_res[class_name].replace(" ", "_")}.png')
                    else:
                        mask = np.array(Image.open(pth + "/" + class_name + '.png'))
                        mask[:, :, 1:] = 0
                        focal_image = (np.array(raw_image) * 0.5 + mask * 0.5).astype(np.uint8)
                        cv2.imwrite(save_pth + "/" + f'local_{single_res[class_name].replace(" ", "_")}.png',
                                    cv2.cvtColor(focal_image, cv2.COLOR_RGB2BGR))
                single_res['image'] = label_class + "_" + whole_image
                print(single_res)
                res.append(single_res)
    json.dump(res, open('alpha_clip_l_blip_flant5xl_imags.json', 'w'), indent=4)
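
If you just want a caption for a single image and mask after the setup above has run, the core call sequence boils down to the following minimal sketch (the paths here are hypothetical):

# minimal single-image usage distilled from the script above (hypothetical paths)
raw_image = Image.open("example/image.png").convert("RGB")
wb_mask = np.array(Image.open("example/mask.png").convert("RGB"))  # white pixels = region to focus on

alpha_v = 255                # full focus on the white region (normalizes to ~1.923)
save_alpha(wb_mask)          # rebuilds the global `alpha` consumed by rewrited_forward
image = vis_processors["eval"](raw_image).unsqueeze(0).to(device)
print(model.generate({"image": image})[0])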
Akshay1-6180 commented 7 months ago

Thanks for the code and the quick response. I assumed the model also used BLIP-2 with Alpha-CLIP for VQA (https://arxiv.org/abs/2305.06500), but going through the code, it has only been used to generate captions. Since it has been integrated with LLaVA, integrating with BLIP-2 for VQA should also be possible; the input would then be the mask and the image to BLIP-2, similar to LLaVA. I also wanted to know your opinion on whether your choice of LLaVA over BLIP-2 for VQA was based on internally noticing that LLaVA gave better results than BLIP-2.

[Screenshot 2024-02-02 at 10 31 08 PM: results comparing image-level masking, feature-level masking, and Alpha-CLIP]

Based on the above picture, I assume that for the image-level masking and feature-level masking results you use a normal pretrained BLIP-2 model, and that for Alpha-CLIP, as mentioned, you use the alpha map with Alpha-CLIP and integrate that into BLIP-2. Based on the code above, I think this is what you are trying to achieve internally: the alpha map is not sent as a separate input alongside the image in the conventional sense (like two images being fed into the model simultaneously). Instead, it is used within the model's internal processing, specifically within a modified version of the forward pass (rewrited_forward). There, the alpha map modulates the processing of the image by influencing certain layers (e.g., conv1_alpha), adjusting their behavior based on the information in the alpha map. This lets the model "pay more attention" to, or "focus on", the areas highlighted by the alpha map.
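
To check my understanding, here is a simplified sketch of how I read the alpha injection (not your exact code, just the core idea):

# my reading of rewrited_forward: RGB patches and the alpha map go through two
# parallel patch-embedding convs, and their outputs are summed before the transformer
def patch_embed_with_alpha(visual_encoder, image, alpha_map):
    feat = visual_encoder.conv1(image)                    # [B, width, grid, grid]
    feat = feat + visual_encoder.conv1_alpha(alpha_map)   # alpha-conditioned features added element-wise
    return feat  # afterwards: flatten, prepend CLS token, add positional embedding, run the transformer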

Instead of this, since the alpha map is trained together with the mask and the image during Alpha-CLIP training anyway, could they not also be sent together and then integrated with BLIP-2? But then, I assume, BLIP-2 would not be tuned with the Alpha-CLIP features, so you would have to train them end to end again. So that's why the conv1_alpha features and the input image features are added.

Thank you

SunzeY commented 7 months ago

Hi! The inference code for BLIP-2 captioning and for VQA is similar, so I don't think there is a need to also release code for VQA. As for choosing LLaVA as the LLM: as the maximum training context length becomes longer (128 for Flan-T5, 2048 for LLaMA-2), it becomes possible to feed all image token features into the LLM without a Q-Former to compress (or extract) the information. Following this trend, we believe LLaVA is a more advanced VLM than BLIP-2 and InstructBLIP, so we only test the main results with LLaVA. I also agree with you that feeding two images into the LLM is another good way to achieve region focus.
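
Roughly, VQA with the same patched BLIP-2 would only differ in the prompt passed to generate; an untested sketch (the question text is just an example):

# same alpha-map injection as for captioning, plus a text prompt for VQA
question = "What color is the highlighted object?"   # example question
save_alpha(wb_mask)                                  # set the alpha map for the region of interest
image = vis_processors["eval"](raw_image).unsqueeze(0).to(device)
answer = model.generate({"image": image, "prompt": f"Question: {question} Answer:"})[0]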

Akshay1-6180 commented 7 months ago

Thank you for the clarification! Looking forward to seeing your future papers on how these advancements unfold and their impact on the field :).