facebookresearch / convit

Code for the Convolutional Vision Transformer (ConViT)
Apache License 2.0

About Nonlocality #23

Open yangbang18 opened 2 years ago

Thanks for your great work and code.

I am a little confused about your implementation of the nonlocality computation in main.py (L346-351).

Here is the code:

batch = next(iter(data_loader_val))[0]
batch = batch.to(device)
batch = model_without_ddp.patch_embed(batch)
for l in range(len(model_without_ddp.blocks)):
    attn =  model_without_ddp.blocks[l].attn
    nonlocality[l] = attn.get_attention_map(batch).detach().cpu().numpy().tolist()

It seems that the original patch embeddings are always fed to all 12 blocks. Shouldn't the input to attn.get_attention_map for each block be that block's actual input, i.e. [original patch embeddings, output of block 1, ..., output of block 11]?
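To make the suggestion concrete, here is a minimal sketch of the alternative: the hidden state is passed through each block before the next block's attention is queried. It assumes that calling a block returns the next hidden state. The helper name `layerwise_attention_stats` and the `Toy*` classes are hypothetical stand-ins, not ConViT code, and the real model may need extra steps between patch embedding and the blocks (positional embeddings, class token, pre-attention normalization) that this sketch does not handle.

```python
def layerwise_attention_stats(model, batch):
    """Query each block's attention with the input that block actually receives."""
    x = model.patch_embed(batch)
    stats = {}
    for l, block in enumerate(model.blocks):
        # Use the current hidden state, not the original patch embeddings.
        stats[l] = block.attn.get_attention_map(x)
        # Assumption: calling the block returns the hidden state for the next block.
        x = block(x)
    return stats


# Hypothetical stand-ins so the sketch runs without the real model.
class ToyAttn:
    def get_attention_map(self, x):
        return x  # echo the input to show what each layer received


class ToyBlock:
    def __init__(self):
        self.attn = ToyAttn()

    def __call__(self, x):
        return x + 1  # pretend every block shifts the hidden state by 1


class ToyModel:
    def __init__(self, depth):
        self.blocks = [ToyBlock() for _ in range(depth)]

    def patch_embed(self, batch):
        return batch * 10


stats = layerwise_attention_stats(ToyModel(depth=3), 1)
print(stats)  # {0: 10, 1: 11, 2: 12}
```

With the toy model, every layer sees a different input (10, 11, 12), whereas the loop quoted above would pass the same embedding (10) to all three.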

If I have misunderstood, please correct me.

Looking forward to your reply.