lucidrains / CoCa-pytorch

Implementation of CoCa, Contrastive Captioners are Image-Text Foundation Models, in Pytorch

attn_mask #8

Open pldlgb opened 2 years ago

pldlgb commented 2 years ago
```python
cls_mask = rearrange(text != self.pad_id, 'b j -> b 1 j')       # (b, 1, j): True for non-PAD text positions
attn_mask = F.pad(cls_mask, (0, 1, seq, 0), value=True)         # append a CLS key column, prepend `seq` all-True rows

attn_mask = rearrange(attn_mask, 'b i j -> b 1 i j')            # broadcast over attention heads
sim = sim.masked_fill(~attn_mask, -torch.finfo(sim.dtype).max)  # suppress masked positions before softmax
```

Hello, I am confused by the implementation of `attn_mask`. I think this padding function can only mask the last row of `sim`. Could you please explain it? Perhaps it's a very foolish question. Thank you so much.
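
For reference, here is a minimal standalone sketch (toy tensor; `pad_id` and the lengths are hypothetical) of what that `F.pad` call produces. The pad spec `(0, 1, seq, 0)` adds one all-True column on the right of the last dim (the CLS key) and `seq` all-True rows on top of the second-to-last dim:

```python
import torch
import torch.nn.functional as F
from einops import rearrange

pad_id = 0                                   # hypothetical pad id
text = torch.tensor([[5, 9, 7, 0]])          # (b=1, j=4); last position is PAD
seq = text.shape[1]

cls_mask = rearrange(text != pad_id, 'b j -> b 1 j')     # (1, 1, 4)
attn_mask = F.pad(cls_mask, (0, 1, seq, 0), value=True)  # (1, 5, 5)

print(attn_mask[0].int())
# tensor([[1, 1, 1, 1, 1],
#         [1, 1, 1, 1, 1],
#         [1, 1, 1, 1, 1],
#         [1, 1, 1, 1, 1],
#         [1, 1, 1, 0, 1]])  <- only the last row (the CLS query) masks the PAD key
```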

skyerhxx commented 1 year ago

I have the same question. It seems like `attn_mask = F.pad(cls_mask, (0, 1, seq, 0), value=True)` is not right. Based on the original paper, the attn_mask here should be lower-triangular (a causal mask), to prevent the feature at the current timestep from seeing features at future timesteps.

Discussion is welcome.
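
For comparison, a causal mask of the kind described here would look like the sketch below (True marks future positions to suppress). As the replies below explain, the snippet in question is not this mask; causality is handled separately in the attention. A minimal sketch, not the repo's exact code:

```python
import torch

n = 5  # hypothetical sequence length
# True above the diagonal marks the future positions each query must not see
causal_mask = torch.ones((n, n), dtype=torch.bool).triu(1)
print(causal_mask.int())
# tensor([[0, 1, 1, 1, 1],
#         [0, 0, 1, 1, 1],
#         [0, 0, 0, 1, 1],
#         [0, 0, 0, 0, 1],
#         [0, 0, 0, 0, 0]])
# applied as: sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)
```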

gshaikov-paige commented 12 months ago

@skyerhxx This is not the causal mask; this is a mask that prevents the CLS token from attending to PAD tokens in the batch.

We add PAD tokens to the text batch because text examples have different lengths while the tensor has a fixed dimension, so to concatenate them into a batch tensor one must pad the end of each sequence with a dummy token, i.e. a PAD token. However, since we append the CLS token to the very end, it would attend to the entire sequence, including the PAD tokens, which we don't want. So we mask them out.
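
A small sketch of the batching step being described (hypothetical captions and `pad_id`):

```python
import torch

pad_id = 0                                              # hypothetical pad id
caps = [torch.tensor([5, 9, 7]), torch.tensor([4, 2])]  # two captions, different lengths

# pad every caption to the max length so they stack into one (b, n) tensor
n = max(len(c) for c in caps)
text = torch.stack([torch.cat((c, c.new_full((n - len(c),), pad_id))) for c in caps])
# tensor([[5, 9, 7],
#         [4, 2, 0]])  <- second caption ends in a PAD token

# the CLS token is appended after these positions, so without the mask above
# its query would also attend to the PAD key in the second row
```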

gshaikov-paige commented 12 months ago

@pldlgb we only mask the last row of `sim` because this row corresponds to the CLS token query. Without this mask it would attend to all the keys before it, including the PAD keys.

We don't need to mask the other queries because we don't care what the PAD queries attend to: they will be masked out when we compute the CE loss. We also don't need to mask the text queries, since they are already constrained by the causal mask and can only look backwards at other text positions.
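
A minimal sketch of that last point (hypothetical shapes): `ignore_index` drops the PAD positions from the captioning loss, so whatever the PAD queries attended to never reaches the gradient:

```python
import torch
import torch.nn.functional as F

pad_id = 0                          # hypothetical pad id
logits = torch.randn(2, 3, 11)      # (batch, seq, vocab) decoder outputs
labels = torch.tensor([[9, 7, 1],
                       [2, 0, 0]])  # 0 = PAD target positions

# cross_entropy expects (batch, vocab, seq); PAD targets contribute no loss
loss = F.cross_entropy(logits.transpose(1, 2), labels, ignore_index=pad_id)
```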