zjysteven / VLM-Visualizer

Visualizing the attention of vision-language models

Will you do this for SigLIP (the new LLaVA) as well? #5

Open trinhvg opened 2 weeks ago

trinhvg commented 2 weeks ago

Will you do this for SigLIP (the new LLaVA) as well? Thank you for the nice work.

zjysteven commented 2 weeks ago

Hi, thank you for your interest. Unfortunately, since I'm about to graduate very soon and will start my full-time job, it's quite unlikely that I will have the time to do it.

That said, I would imagine extending it to the new LLaVA series models wouldn't be difficult. I'm happy to take questions and share my thoughts on GitHub if any difficulties come up.
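For reference, the HF-hosted LLaVA variants already expose attention weights through the standard `generate` API, which is the main thing this kind of visualization needs. Below is a minimal, untested sketch; the checkpoint name, prompt format, and image path are placeholders, and you would substitute whichever SigLIP-based variant you actually use.

```python
import torch
from PIL import Image
from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration

# hypothetical checkpoint; substitute the newer LLaVA variant you are using
model_id = "llava-hf/llava-v1.6-mistral-7b-hf"
processor = LlavaNextProcessor.from_pretrained(model_id)
model = LlavaNextForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    attn_implementation="eager",  # eager attention so attention weights can be returned
).to("cuda")

image = Image.open("example.jpg")  # placeholder image path
prompt = "[INST] <image>\nWhat is shown in this image? [/INST]"  # prompt format depends on the checkpoint
inputs = processor(images=image, text=prompt, return_tensors="pt").to("cuda")

outputs = model.generate(
    **inputs,
    max_new_tokens=64,
    output_attentions=True,        # return attention weights for each generated token
    return_dict_in_generate=True,
)
# outputs.attentions is a tuple over generation steps, each a tuple over layers of
# tensors shaped (batch, num_heads, query_len, key_len); aggregating these gives
# the same kind of matrix the snippets below work with.
```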

trinhvg commented 2 weeks ago

Hi, for videos we input 32 frames, and I want to know which frames are more important. I think I should modify this part of your code for that purpose. If you have any suggestions, please help.

```python
# identify the length or index of tokens
input_token_len = model.get_vision_tower().num_patches + len(input_ids[0]) - 1  # -1 for the <image> token
vision_token_start = len(tokenizer(prompt.split("<image>")[0], return_tensors='pt')["input_ids"][0])
vision_token_end = vision_token_start + model.get_vision_tower().num_patches
output_token_len = len(outputs["sequences"][0])
output_token_start = input_token_len
output_token_end = input_token_len + output_token_len
```

```python
# look at the attention weights over the vision tokens
overall_attn_weights_over_vis_tokens = []
for i, (row, token) in enumerate(
    zip(llm_attn_matrix[input_token_len:], outputs["sequences"][0].tolist())
):
    # print(
    #     i + input_token_len,
    #     f"{tokenizer.decode(token, add_special_tokens=False).strip():<15}",
    #     f"{row[vision_token_start:vision_token_end].sum().item():.4f}"
    # )
    overall_attn_weights_over_vis_tokens.append(
        row[vision_token_start:vision_token_end].sum().item()
    )
```

```python
# plot the trend of attention weights over the vision tokens
fig, ax = plt.subplots(figsize=(20, 5))
ax.plot(overall_attn_weights_over_vis_tokens)
ax.set_xticks(range(len(overall_attn_weights_over_vis_tokens)))
ax.set_xticklabels(
    [tokenizer.decode(token, add_special_tokens=False).strip()
     for token in outputs["sequences"][0].tolist()],
    rotation=75,
)
ax.set_title("at each token, the sum of attention weights over all the vision tokens");
```
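A minimal sketch of what I have in mind for per-frame importance, reusing `llm_attn_matrix` and the index variables from the snippets above. It assumes the vision tokens are ordered frame by frame (frame 0's patches first, then frame 1's, and so on) and that they split evenly across `num_frames = 32`; both are assumptions about the video model, not something verified here.

```python
import torch
import matplotlib.pyplot as plt

num_frames = 32  # assumed number of input video frames
num_vis_tokens = vision_token_end - vision_token_start
assert num_vis_tokens % num_frames == 0, "vision tokens should split evenly across frames"
tokens_per_frame = num_vis_tokens // num_frames

# stack each generated token's attention over the vision tokens:
# shape (num_generated_tokens, num_vis_tokens)
vis_attn = torch.stack(
    [row[vision_token_start:vision_token_end] for row in llm_attn_matrix[input_token_len:]]
)

# sum over generated tokens, then fold the vision tokens into (frames, patches per frame)
frame_importance = vis_attn.sum(dim=0).reshape(num_frames, tokens_per_frame).sum(dim=1)

fig, ax = plt.subplots(figsize=(10, 4))
ax.bar(range(num_frames), frame_importance.float().cpu().numpy())
ax.set_xlabel("frame index")
ax.set_ylabel("summed attention over the frame's vision tokens")
ax.set_title("which frames the generated tokens attend to (sketch)");
```

The same reshape could also be applied per generated token instead of after the sum, if a per-token, per-frame breakdown is more useful.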

zjysteven commented 1 week ago

Hi, would you mind structuring the code in the comment a bit more clearly? Also, what is the exact question you want to ask here?