naver-ai / cl-vs-mim

(ICLR 2023) Official PyTorch implementation of "What Do Self-Supervised Vision Transformers Learn?"

Request for the normalized mutual information code #5

Closed · casiatao closed this issue 10 months ago

casiatao commented 10 months ago

Could you release the code that calculates the normalized mutual information? Thank you very much!

xxxnell commented 10 months ago

Hi @casiatao,

Thank you for reaching out. The following code calculates normalized mutual information:

import torch
import torch.nn.functional as F
from einops import repeat, reduce

def calculate_nmi(attn):
    """Normalized mutual information with a return type of (batch, head)."""
    b, h, q, k = attn.shape
    # Uniform distribution over queries: p(q) = 1/q.
    pq = torch.ones([b, h, q]).to(attn.device)
    pq = F.softmax(pq, dim=-1)
    pq_ext = repeat(pq, "b h q -> b h q k", k=k)
    # Marginal key distribution: p(k) = sum_q p(k|q) p(q).
    pk = reduce(attn * pq_ext, "b h q k -> b h k", "sum")
    pk_ext = repeat(pk, "b h k -> b h q k", q=q)

    # Mutual information I(q; k) = sum_{q,k} p(q) p(k|q) log(p(k|q) / p(k)).
    mi = reduce(attn * pq_ext * torch.log(attn / pk_ext), "b h q k -> b h", "sum")
    # Entropies H(q) and H(k), used for normalization.
    eq = -reduce(pq * torch.log(pq), "b h q -> b h", "sum")
    ek = -reduce(pk * torch.log(pk), "b h k -> b h", "sum")

    # NMI = I(q; k) / sqrt(H(q) * H(k)).
    nmiv = mi / torch.sqrt(eq * ek)

    return nmiv
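
For reference, a minimal usage sketch (the shapes below are arbitrary; any attention tensor of shape (batch, head, query, key) that is softmax-normalized over the key axis should work):

# Hypothetical example: random attention maps, normalized over the key axis.
attn = F.softmax(torch.rand(2, 12, 197, 197), dim=-1)  # rows sum to 1 over keys
nmi = calculate_nmi(attn)
print(nmi.shape)  # torch.Size([2, 12]): one NMI value per (batch, head)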

You can find further information at https://github.com/naver-ai/cl-vs-mim/blob/main/self_attention_analysis.ipynb.

casiatao commented 10 months ago

Thank you for your quick reply. I wrote the code myself yesterday, and its results are consistent with those in the paper.