salesforce / LAVIS

LAVIS - A One-stop Library for Language-Vision Intelligence
BSD 3-Clause "New" or "Revised" License

Text Localization with Blip2 #549

Open dip9811111 opened 9 months ago

dip9811111 commented 9 months ago

Starting from the tutorial (link), and considering the function compute_gradcam in BlipITM (link), I'm trying to obtain the same result but using Blip2ITM. The function getAttMap is at link.

This is my code:

import cv2
import numpy as np
import torch
import matplotlib.pyplot as plt
from PIL import Image
from transformers import BertTokenizerFast
from lavis.models import load_model_and_preprocess
from lavis.common.gradcam import getAttMap

def compute_gradcam_new(model, visual_input, text_input, tokenized_text, block_num=None):
    # Hook the cross-attention of the chosen Q-Former layer so that its
    # attention map and gradients are saved during the forward/backward pass.
    # Note: image_path and tokenizer are globals defined below.
    target_layer = model.Qformer.bert.encoder.layer[block_num].crossattention
    target_layer.self.save_attention = True

    output = model({"image": visual_input, "text_input": text_input}, match_head="itm")
    loss = output[:, 1].sum()  # sum of the "match" logits

    model.zero_grad()
    loss.backward()

    with torch.no_grad():
        mask = tokenized_text.attention_mask.view(
            tokenized_text.attention_mask.size(0), 1, -1, 1, 1
        )

        token_length = tokenized_text.attention_mask.sum(dim=-1) - 2  # exclude [CLS]/[SEP] (unused below)
        token_length = token_length.cpu()

        grads = target_layer.self.get_attn_gradients()
        cams = target_layer.self.get_attention_map()

        # Keep only the first N rows (N = mask.shape[2] text tokens), drop the
        # [CLS] image token, and reshape the 256 patch tokens to a 16x16 grid.
        cams = cams[:, :, :mask.shape[2], 1:].reshape(visual_input.size(0), 12, -1, 16, 16) * mask
        grads = grads[:, :, :mask.shape[2], 1:].clamp(0).reshape(visual_input.size(0), 12, -1, 16, 16) * mask

        gradcams = cams * grads
        gradcam = gradcams[0].mean(0).cpu().detach()  # average over the 12 heads

        text_input = tokenized_text
        rgb_image = cv2.imread(image_path)[:, :, ::-1]  # BGR -> RGB
        rgb_image = np.float32(rgb_image) / 255

        folder_path_images = ".../folderImages"

        # Save one Grad-CAM visualization per text token.
        for i, token_id in enumerate(text_input.input_ids[0]):
            word = tokenizer.decode([token_id])
            word = word.replace("##", "")
            gradcam_image = getAttMap(rgb_image, gradcam[i])
            fig_, ax_ = plt.subplots(1, 1, figsize=(15, 5))
            ax_.imshow(gradcam_image)
            ax_.set_yticks([])
            ax_.set_xticks([])
            ax_.set_xlabel(word)
            path_save_image = f"{folder_path_images}/{i}.png"
            fig_.savefig(path_save_image, bbox_inches='tight')
            plt.close(fig_)

device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = BertTokenizerFast.from_pretrained('bert-base-cased')
model, image_processor, text_processor = load_model_and_preprocess("blip2_image_text_matching", "pretrain",
                                                                   device=device,
                                                                   is_eval=True,
                                                                   )

image_path = ".../tryImage.jpg"
image = Image.open(image_path).convert('RGB') 

visual_input = torch.stack([image_processor['eval'](image)]).to(device)
text_input = "The sun shines on the colosseum in rome"
text_input = text_processor["eval"](text_input)

tokenized_text = tokenizer(text_input, return_tensors="pt").to(device)
compute_gradcam_new(model, visual_input, text_input, tokenized_text, block_num=10)

Here I used model.Qformer.bert.encoder.layer[10] as the target layer. What differs from BlipITM is that there cams and grads have a dynamic shape [1, 12, N, 577] (presumably 24x24 patches plus the [CLS] token from the 384x384 input), where N is the number of tokens in the input text.

Instead, in Blip2ITM the Q-Former appears to be instantiated with num_query_token=32, so grads and cams always have the shape [1, 12, 32, 257].
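That 32 is the size of the Q-Former's fixed bank of learned query embeddings, independent of the caption length. A quick sanity check (query_tokens is the attribute name in LAVIS' BLIP-2 models; the 768 assumes the bert-base-sized Q-Former):

# The Q-Former cross-attends with a fixed set of learned queries,
# so the attention rows never depend on the text length.
print(model.query_tokens.shape)  # expected: torch.Size([1, 32, 768])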

For example, using that input text, I got:

cams.shape  = torch.Size([1, 12, 32, 257])
grads.shape = torch.Size([1, 12, 32, 257])
mask.shape = torch.Size([1, 1, 8, 1, 1])
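
For reference, the 257 and the 16x16 reshape are consistent with a 224x224 input and 14x14 patches (an assumption about the vision backbone used here):

# 224 / 14 = 16 patches per side -> 16 * 16 = 256 patch tokens,
# plus one [CLS] token = 257. Dropping column 0 (the [CLS]) leaves
# exactly a 16x16 grid, which is what the reshape relies on.
num_image_tokens = 257
grid = int((num_image_tokens - 1) ** 0.5)
assert grid * grid == num_image_tokens - 1  # grid == 16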

So, to be able to multiply grads, cams, and the mask, I tried keeping only the first N (mask.shape[2]) query rows:

cams = cams[:, :, :mask.shape[2], 1:].reshape(visual_input.size(0), 12, -1, 16, 16) * mask
grads = grads[:, :, :mask.shape[2], 1:].clamp(0).reshape(visual_input.size(0), 12, -1, 16, 16) * mask

Doing this I got no error, but the resulting Grad-CAM is awful and doesn't make sense at all. What's wrong with this approach?

PPPP-kaqiu commented 4 months ago

This cross-attention only relates the learned queries to the image, so its rows correspond to the 32 query tokens, not to the text tokens.
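
So slicing the first N rows just picks N arbitrary queries, none of which corresponds to a word. A minimal sketch of aggregating over the whole query dimension instead, to get a single map for the full caption (my own illustration, reusing cams and grads as computed inside compute_gradcam_new above):

with torch.no_grad():
    # cams/grads: [batch, heads, 32 queries, 257 image tokens].
    # Average over heads and queries instead of indexing by text position,
    # since the 32 rows are learned queries, not text tokens.
    cam = cams[:, :, :, 1:].reshape(1, 12, 32, 16, 16)
    grad = grads[:, :, :, 1:].clamp(0).reshape(1, 12, 32, 16, 16)
    gradcam = (cam * grad).mean(dim=(1, 2))  # -> [1, 16, 16], one map per image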

changbaozhou commented 1 month ago

Have you ever solved this problem? I tried padding the mask dimension up to 32, but the result still seems to make no sense.
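
Roughly, the padding I tried looks like this (zero-padding the token axis of the mask up to the 32 query rows):

import torch.nn.functional as F

# Zero-pad mask [1, 1, N, 1, 1] along the token axis to [1, 1, 32, 1, 1]
# so it broadcasts against cams/grads of shape [1, 12, 32, 16, 16];
# the padded rows simply zero out the extra query maps.
mask32 = F.pad(mask, (0, 0, 0, 0, 0, 32 - mask.shape[2]))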