kai-wen-yang / QVix

Good Questions Help Zero-Shot Image Reasoning
BSD 3-Clause "New" or "Revised" License

Visualization of attention_map in Paper #1

Open ShogoAkiyama opened 7 months ago

ShogoAkiyama commented 7 months ago

@kai-wen-yang Thank you for the excellent paper and code.

Could you please provide some information on how you visualized the attention_map mentioned in the paper? Your assistance would be greatly appreciated.

An image similar to this one: https://github.com/kai-wen-yang/QVix/blob/ed266a879e9e20e57c2971370b453d9f0ef5c60a/images/case1.png

Thank you.

kai-wen-yang commented 7 months ago

Hi, thanks for your interest. We follow this code to visualize the cross-attention map: https://github.com/salesforce/LAVIS/blob/main/examples/blip_text_localization.ipynb
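For reference, the core of that notebook's approach can be sketched as below: keep the cross-attention maps and their gradients from a chosen layer, weight the attention by its positive gradients, and average over heads. The function name, tensor shapes, and the 24×24 patch grid (BLIP at 384×384 with 16-pixel patches) are illustrative assumptions, not the exact LAVIS API:

```python
import torch

def gradcam_from_cross_attention(attn, grad, patch_grid=(24, 24)):
    """GradCAM-style relevance from saved cross-attention and its gradients.

    attn, grad: (bsz, num_heads, query_len, num_patches + 1) tensors, where
    index 0 along the last axis is the global [CLS] token.
    Returns a (bsz, query_len, H, W) relevance map over image patches.
    """
    h, w = patch_grid
    attn = attn[:, :, :, 1:]               # drop the global token
    grad = grad[:, :, :, 1:].clamp(min=0)  # only positive gradients contribute
    cam = attn * grad                      # weight attention by its gradient
    cam = cam.mean(dim=1)                  # average over attention heads
    return cam.reshape(cam.size(0), cam.size(1), h, w)

# Toy shapes: batch 1, 12 heads, 32 text tokens, 24x24 patches (+1 global token)
attn = torch.rand(1, 12, 32, 24 * 24 + 1)
grad = torch.randn(1, 12, 32, 24 * 24 + 1)
cam = gradcam_from_cross_attention(attn, grad)
print(cam.shape)  # torch.Size([1, 32, 24, 24])
```

Per-token maps (e.g. `cam[0, t]`) can then be upsampled to the image size and overlaid, as the notebook does.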

ShogoAkiyama commented 7 months ago

Thank you for your response.

The LAVIS code is for BLIP, so I tried to modify it to do GradCAM visualization for InstructBLIP. However, I'm having difficulty generating the attention map. Would it be possible for you to share your code?

Your assistance would be greatly appreciated.

My code:

    # Ask the cross-attention module to store its attention map and gradients
    model.qformer.encoder.layer[block_num].crossattention.attention.save_attention = True

    # Image-text matching forward pass; column 1 is the "match" logit
    output = model.image2text_match(**inputs)
    loss = output[:, 1].sum()

    model.zero_grad()
    loss.backward()
    with torch.no_grad():
        # attention_mask = inputs["qformer_attention_mask"]
        query_attention_mask = torch.ones((1, 32), dtype=torch.long, device=device)
        mask = query_attention_mask.view(query_attention_mask.size(0), 1, -1, 1, 1)  # (bsz, 1, query_len, 1, 1)
        token_length = query_attention_mask.sum(dim=-1) - 2  # exclude [CLS]/[SEP]
        token_length = token_length.cpu()

        # grads and cams: (bsz, num_heads, query_len, num_patches + 1)
        grads = model.qformer.encoder.layer[block_num].crossattention.attention.get_attn_gradients()
        cams = model.qformer.encoder.layer[block_num].crossattention.attention.get_attention_map()

        # Drop the global token and reshape the patch axis into a 16x16 grid
        cams = cams[:, :, :, 1:].reshape(torch_image.size(0), 12, -1, 16, 16) * mask
        grads = grads[:, :, :, 1:].clamp(0).reshape(torch_image.size(0), 12, -1, 16, 16) * mask
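If it helps, one common way to finish from here (following the BLIP notebook's recipe, not necessarily the authors' exact code) is to multiply the two tensors, average over heads and query tokens, and upsample to the input resolution for overlay. The snippet below is self-contained with dummy tensors shaped like `cams`/`grads` above; the 224×224 input size is an assumption:

```python
import torch
import torch.nn.functional as F

# Dummy stand-ins shaped like the cams/grads above:
# (bsz, num_heads, query_len, patch_h, patch_w)
cams = torch.rand(1, 12, 32, 16, 16)
grads = torch.rand(1, 12, 32, 16, 16)

# GradCAM-style relevance: gradient-weighted attention,
# averaged over the 12 heads and the 32 query tokens
gradcam = (cams * grads).mean(dim=(1, 2))  # -> (bsz, 16, 16)

# Upsample to the input resolution for overlay (assuming a 224x224 input)
heatmap = F.interpolate(
    gradcam.unsqueeze(1), size=(224, 224),
    mode="bilinear", align_corners=False,
).squeeze(1)
print(heatmap.shape)  # torch.Size([1, 224, 224])
```

The `heatmap` can then be normalized to [0, 1] and alpha-blended over the original image, as in the LAVIS example.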

Left: LAVIS BLIP; right: InstructBLIP (my modification)

[Screenshot 2024-01-16 18:22:34]