Could you release the code that calculates the normalized mutual information? Thank you very much!

Hi @casiatao,
Thank you for reaching out. The following code calculates the normalized mutual information of an attention map.
import torch
import torch.nn.functional as F
from einops import reduce, repeat

def calculate_nmi(attn):
    """Normalized mutual information; returns a tensor of shape (batch, head)."""
    b, h, q, k = attn.shape
    # Uniform distribution p(q) over the query tokens.
    pq = torch.ones([b, h, q]).to(attn.device)
    pq = F.softmax(pq, dim=-1)
    pq_ext = repeat(pq, "b h q -> b h q k", k=k)
    # Marginal key distribution p(k) = sum_q p(q) p(k|q).
    pk = reduce(attn * pq_ext, "b h q k -> b h k", "sum")
    pk_ext = repeat(pk, "b h k -> b h q k", q=q)
    # Mutual information I(q; k) = sum_{q,k} p(q) p(k|q) log(p(k|q) / p(k)).
    mi = reduce(attn * pq_ext * torch.log(attn / pk_ext), "b h q k -> b h", "sum")
    # Entropies H(q) and H(k) used for normalization.
    eq = -reduce(pq * torch.log(pq), "b h q -> b h", "sum")
    ek = -reduce(pk * torch.log(pk), "b h k -> b h", "sum")
    nmiv = mi / torch.sqrt(eq * ek)
    return nmiv
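As a minimal usage sketch, assuming attn is a post-softmax attention map of shape (batch, heads, queries, keys) and reusing the imports above, with the tensor sizes chosen arbitrarily for illustration:

# Hypothetical attention map: random logits softmaxed over the key dimension.
attn = F.softmax(torch.randn(2, 12, 197, 197), dim=-1)  # (batch, heads, queries, keys)
nmiv = calculate_nmi(attn)
print(nmiv.shape)  # torch.Size([2, 12]), i.e. one NMI value per (batch, head)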
You can find further information at https://github.com/naver-ai/cl-vs-mim/blob/main/self_attention_analysis.ipynb.
Thank you for your quick reply. I tried writing the code myself yesterday, and the results are consistent with those in the paper.