NVlabs / RADIO

Official repository for "AM-RADIO: Reduce All Domains Into One"

LLaVA questions #69

Open DiWHNJ opened 2 months ago

DiWHNJ commented 2 months ago

Thank you for your interesting and important work. I noticed that Figure 8 of the paper visualizes LLaVA attention maps. Is the code for visualizing these attention maps, given a prompt and an image as input, open source? I couldn't find it. Thank you.

gheinrich commented 2 months ago

Hello, sorry for the late response. For this figure of the paper, I dumped the attention scores, along with the positions of the image tokens within the sequence, from inside the LLaVA framework. For every generated token and every layer, this gives me a tensor of shape (batch_size, num_heads, sequence_len, sequence_len). From this tensor, I extract the attention scores associated with the image tokens, then reshape them to 2D so that the attention heatmap can be blended with the image. The code looks like:

        import torch
        from einops import rearrange
        from matplotlib.cm import ScalarMappable
        from matplotlib.colors import Normalize

        # x is a tensor of attention scores for the last token in the generated sequence.
        # Collapse all leading dimensions (e.g. batch and heads) into one, then average over it.
        x = torch.mean(x.flatten(start_dim=0, end_dim=-2), dim=0)
        # num_preamble_tokens is the number of tokens before the vision tokens in the sequence.
        image_attentions = x[num_preamble_tokens:num_preamble_tokens+num_image_tokens]
        # Reshape the flat per-token scores into the 24x24 grid of vision tokens.
        image_attentions = rearrange(image_attentions, '(h w) -> h w', h=24, w=24)
        # Upsample the grid by 16x (24 * 16 = 384 pixels per side); interpolate expects (N, C, H, W).
        image_attentions = torch.nn.functional.interpolate(image_attentions.unsqueeze(0).unsqueeze(0), scale_factor=16, mode='bilinear')
        # Convert to numpy (cast to float32 first, since numpy has no bfloat16 type).
        image_attentions = image_attentions[0][0].float().cpu().numpy()
        # Choose a colormap (e.g., 'viridis', 'plasma', 'cividis', 'inferno', 'magma', etc.)
        cmap_name = 'inferno'
        # Normalize the 2D array data to the range [0, 1]
        norm = Normalize(vmin=image_attentions.min(), vmax=image_attentions.max())
        # Create a scalar mappable for mapping the normalized values to colors in the colormap
        scalar_mappable = ScalarMappable(norm=norm, cmap=cmap_name)
        # Map the 2D array data to RGBA values using the chosen colormap
        image_attentions = scalar_mappable.to_rgba(image_attentions)
        # Blend the RGB channels of the heatmap with the RGB image.
        # Assumption: input_image is a float array in [0, 1] with the same height and width as the heatmap.
        w = 0.5
        image_attentions = w * image_attentions[:, :, :3] + (1-w) * input_image
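
The snippet starts from `x`, the attention scores for the last token, and does not show how `x` is taken from the full (batch_size, num_heads, sequence_len, sequence_len) tensor. Below is a minimal sketch of one way that step could look, not necessarily what was done for the paper. The function name `last_token_image_attention` and the `grid_size` parameter are hypothetical. It assumes the last row along the query axis corresponds to the most recently generated token.

```python
import torch


def last_token_image_attention(attn, num_preamble_tokens, num_image_tokens, grid_size):
    """Return a (grid_size, grid_size) map of attention from the last token to the image tokens.

    attn: attention scores for one layer, shape (batch_size, num_heads, sequence_len, sequence_len).
    """
    # Keep only the last query position: what the newest token attends to.
    x = attn[:, :, -1, :]  # (batch_size, num_heads, sequence_len)
    # Collapse batch and heads into one dimension, then average over it.
    x = x.flatten(start_dim=0, end_dim=-2).mean(dim=0)  # (sequence_len,)
    # Slice out the scores that fall on the image tokens.
    image_attentions = x[num_preamble_tokens:num_preamble_tokens + num_image_tokens]
    # Lay the flat scores out on the 2D grid of vision tokens (row-major).
    return image_attentions.reshape(grid_size, grid_size)


# Usage on a synthetic tensor; the token counts here are made up for illustration.
seq_len = 35 + 576 + 10
attn = torch.softmax(torch.randn(1, 4, seq_len, seq_len), dim=-1)
heatmap = last_token_image_attention(attn, num_preamble_tokens=35, num_image_tokens=576, grid_size=24)
print(heatmap.shape)  # torch.Size([24, 24])
```

The resulting 24x24 map can then be upsampled, colormapped, and blended with the image as in the snippet above. To aggregate over layers, one option is to average the per-layer maps.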