skyABB closed this issue 7 months ago.
YES!代码在这里:
def visualize_similarity(self, a_before_cls_token, a_after_cls_token,
                         b_before_patch_tokens, b_after_patch_tokens,
                         pattern=None):
    """Plot KDE curves of cls-to-patch cosine similarity before/after HMA.

    For every sample, the cosine similarity between the ``a`` CLS token and
    each ``b`` patch token is computed and averaged over dim 1. The resulting
    per-sample values are drawn as two KDE curves ("Before HMA" in green,
    "After HMA" in blue) on one new figure, which is then shown.

    Args:
        a_before_cls_token: CLS token of input ``a`` before HMA. Presumably
            shaped (batch, dim) -- TODO confirm against caller.
        a_after_cls_token: CLS token of input ``a`` after HMA; same layout as
            ``a_before_cls_token``.
        b_before_patch_tokens: Patch tokens of input ``b`` before HMA.
            Presumably shaped (batch, num_patches, dim) -- TODO confirm.
        b_after_patch_tokens: Patch tokens of input ``b`` after HMA; same
            layout as ``b_before_patch_tokens``.
        pattern: Optional direction tag such as 'r2t', 'r2n', 'n2t', 'n2r',
            't2r' or 't2n'. It is upper-cased into the plot title (e.g.
            'r2t' -> 'R2T'). When omitted, the title has no direction suffix.

    Note:
        Calls ``sns.set(style="whitegrid")``, which changes the global
        Seaborn/Matplotlib style for the rest of the process.
    """

    def _mean_cls2patch(cls_token, patch_tokens):
        # Insert a singleton axis at dim 1 so the CLS token broadcasts
        # against every patch token; the similarity is taken over the
        # last (feature) axis.
        sims = torch.nn.functional.cosine_similarity(
            cls_token.unsqueeze(1), patch_tokens, dim=-1)
        # Average over dim 1 (the patch axis under the assumed shapes).
        # reshape(-1) keeps the result 1-D: a bare .squeeze() would turn a
        # batch of one into a 0-d array, which kdeplot cannot plot as a
        # distribution.
        return torch.mean(sims, dim=1).detach().cpu().reshape(-1).numpy()

    similarities_ori = _mean_cls2patch(a_before_cls_token, b_before_patch_tokens)
    similarities = _mean_cls2patch(a_after_cls_token, b_after_patch_tokens)

    sns.set(style="whitegrid")
    _, ax = plt.subplots()

    # NOTE(review): per the seaborn docs, `multiple` only takes effect when a
    # hue mapping splits the data into subsets; each call here plots a single
    # series, so it appears to have no visible effect. Kept as-is.
    sns.kdeplot(similarities, color='b', label='After HMA', ax=ax, multiple="stack")
    sns.kdeplot(similarities_ori, color='g', label='Before HMA', ax=ax, multiple="stack")

    title = "Similarity Distribution (cls2patch)"
    if pattern:
        title += " of " + str(pattern).upper()

    # Style the axes created above explicitly rather than via pyplot's
    # implicit "current axes".
    ax.set_title(title, fontsize=18, fontweight='bold')
    ax.set_xlabel("Cosine Similarity", fontsize=16, fontweight='bold')
    ax.set_ylabel("Density", fontsize=16, fontweight='bold')
    ax.legend(loc='upper right', fontsize=17)
    plt.show()
这个工具是否有参考代码？