924973292 / EDITOR

【CVPR2024】Magic Tokens: Select Diverse Tokens for Multi-modal Object Re-Identification
MIT License

How is the Similarity visualization tool implemented? #1

Closed: skyABB closed this issue 7 months ago

skyABB commented 7 months ago

Is there reference code for this tool?

924973292 commented 7 months ago

Yes! Here is the code:

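import matplotlib.pyplot as plt
import seaborn as sns
import torch

# Written as a method (note `self`); paste it into a class or drop `self` to use it standalone.
def visualize_similarity(self, a_before_cls_token, a_after_cls_token, b_before_patch_tokens, b_after_patch_tokens, pattern=None):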
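        # Before HMA: cosine similarity between the class token of input a and every patch token of input b.
        # Shapes are assumed to be (B, D) for the class token and (B, N, D) for the patch tokens.
        a_before_cls_token = a_before_cls_token.unsqueeze(1)
        similarities_ori = torch.nn.functional.cosine_similarity(a_before_cls_token, b_before_patch_tokens, dim=-1)
        # Average over the patch dimension to get one similarity value per sample.
        similarities_ori = torch.mean(similarities_ori, dim=1).squeeze().cpu().detach().numpy()

        # After HMA: the same computation on the fused tokens.
        a_after_cls_token = a_after_cls_token.unsqueeze(1)
        similarities = torch.nn.functional.cosine_similarity(a_after_cls_token, b_after_patch_tokens, dim=-1)
        similarities = torch.mean(similarities, dim=1).squeeze().cpu().detach().numpy()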
        # Set Seaborn style
        sns.set(style="whitegrid")

        # Create a figure and axis
        fig, ax = plt.subplots()

        # Plot KDE curves for "before" and "after" fusion
        sns.kdeplot(similarities, color='b', label='After HMA', ax=ax, multiple="stack")
        sns.kdeplot(similarities_ori, color='g', label='Before HMA', ax=ax, multiple="stack")
        # Title suffix: upper-case the pattern key (e.g. 'r2t' -> 'R2T').
        # Any other value, including the default None, gets no suffix, so `sign` is always defined.
        valid_patterns = ('r2t', 'r2n', 'n2t', 'n2r', 't2r', 't2n')
        sign = " of " + pattern.upper() if pattern in valid_patterns else ""
        plt.title("Similarity Distribution (cls2patch)" + sign, fontsize=18, fontweight='bold')
        plt.xlabel("Cosine Similarity", fontsize=16, fontweight='bold')
        plt.ylabel("Density", fontsize=16, fontweight='bold')

        # Add a legend to distinguish "before" and "after" fusion
        plt.legend(loc='upper right', fontsize=17)
        plt.show()
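
For anyone who wants to try the numeric part without the model, here is a minimal sketch of the cls-to-patch similarity step on random tensors. The helper name `cls_to_patch_similarity`, the variable names, and the tensor sizes are illustrative assumptions and do not come from the repository; the shapes (B, D) and (B, N, D) are what the snippet above appears to expect.

```python
import torch
import torch.nn.functional as F


def cls_to_patch_similarity(cls_token: torch.Tensor, patch_tokens: torch.Tensor) -> torch.Tensor:
    """Mean cosine similarity between a class token and each patch token, per sample.

    Assumed shapes (hypothetical): cls_token (B, D), patch_tokens (B, N, D).
    """
    # (B, D) -> (B, 1, D) so the class token broadcasts against the N patch tokens
    sims = F.cosine_similarity(cls_token.unsqueeze(1), patch_tokens, dim=-1)  # (B, N)
    # Average over patches: one value per sample
    return sims.mean(dim=1)  # (B,)


torch.manual_seed(0)
B, N, D = 8, 128, 768  # illustrative sizes only
cls_a = torch.randn(B, D)          # stand-in for the class token of input a
patches_b = torch.randn(B, N, D)   # stand-in for the patch tokens of input b
scores = cls_to_patch_similarity(cls_a, patches_b)
print(scores.shape)  # torch.Size([8])
```

Since the function above plots one value per sample, the KDE curve is estimated from only B points, so passing features gathered over many batches should give a smoother and more representative distribution.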