dhansmair / flamingo-mini

Implementation of the DeepMind Flamingo vision-language model, based on Hugging Face language models and ready for training
MIT License

Doubt about MaskedCrossAttention #18

Open eileenforwhat opened 1 year ago

eileenforwhat commented 1 year ago

Hi, I'm unsure about this piece of code in MaskedCrossAttention inside gated_cross_attention.py:

```python
media_time = torch.arange(n_media, device=y.device) + 1
# >> David:
# side note: here, text tokens attend to ALL previous visual tokens. If we only want to attend to the
# one image coming before in the text (like in the Flamingo paper),
# we need to change >= to == at the line where 'text_to_media_mask' is created.
text_to_media_mask = rearrange(text_time, 'b i -> b 1 i 1') == repeat(media_time, 'j -> 1 1 1 (j m)', m=self.n_visual)
sim = sim.masked_fill(~text_to_media_mask, -torch.finfo(sim.dtype).max)

sim = sim - sim.amax(dim=-1, keepdim=True).detach()
```
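As a minimal pure-Python illustration of the side note in the snippet, the sketch below builds the text-to-media mask with both operators. The values of text_time (assumed here to hold, for each text token, the 1-based index of the last image before it, with 0 meaning no image yet), n_media and n_visual are made up for illustration; build_mask is a hypothetical helper, and einops' rearrange/repeat are replaced by list comprehensions.

```python
# Hypothetical sizes: 2 images in the sequence, each contributing n_visual=2 visual tokens.
n_media, n_visual = 2, 2

# Assumed meaning: text_time[i] = 1-based index of the last image before text token i,
# 0 = no image seen yet. These values are invented for the example.
text_time = [0, 1, 1, 2, 2]

# media_time = arange(n_media) + 1, each entry repeated n_visual times,
# mirroring repeat(media_time, 'j -> 1 1 1 (j m)', m=n_visual)
media_time = [j + 1 for j in range(n_media) for _ in range(n_visual)]


def build_mask(op):
    # mask[i][k] is True if text token i may attend to visual token k
    return [[op(t, m) for m in media_time] for t in text_time]


only_immediate = build_mask(lambda t, m: t == m)  # '==': only the most recent image
all_previous = build_mask(lambda t, m: t >= m)    # '>=': every image seen so far

for t, a, b in zip(text_time, only_immediate, all_previous):
    print(t, [int(x) for x in a], [int(x) for x in b])
```

With ==, the token with text_time 2 only sees the visual tokens of image 2; with >=, it sees those of images 1 and 2. A token with text_time 0 is masked out entirely under either operator.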

It seems you are setting the positions you want to mask out to -torch.finfo(sim.dtype).max (a large negative number), but then using the largest value, sim.amax, to normalize by?

I would think it should be:

```python
sim = sim.masked_fill(~text_to_media_mask, torch.finfo(sim.dtype).max)
sim = sim - sim.amax(dim=-1, keepdim=True).detach()
```

Any clarification on this logic is appreciated. Thanks!

dhansmair commented 1 year ago

Hi @eileenforwhat, it's been a while since I have worked on this, so I needed to think about it myself. This snippet is taken from lucidrains' code, which does the same: https://github.com/lucidrains/flamingo-pytorch/blob/10913abbc8b2ceabb2320560d7d9b85fcb85eee3/flamingo_pytorch/flamingo_pytorch.py#L170

Consider this toy example:

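```python
import torch

mask = torch.tensor([0,0,1,1], dtype=bool)  # True = position may be attended to
x = torch.tensor([1,2,3,4], dtype=float)    # dtype=float gives torch.float64
print("mask:", mask)
print("inverted mask:", ~mask)
# fill the masked-out positions with the most negative finite float64
x = x.masked_fill(~mask, -torch.finfo(x.dtype).max)
print("x:", x)
# shift so that the largest score becomes 0
x = x - x.amax(dim=-1, keepdim=True).detach()
print("x:", x)
alphas = x.softmax(dim=-1)
print("alphas: ", alphas)
```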

which gives this result:

```
$ python test.py
mask: tensor([False, False,  True,  True])
inverted mask: tensor([ True,  True, False, False])
x: tensor([-1.7977e+308, -1.7977e+308,   3.0000e+00,   4.0000e+00],
       dtype=torch.float64)
x: tensor([-1.7977e+308, -1.7977e+308,  -1.0000e+00,   0.0000e+00],
       dtype=torch.float64)
alphas:  tensor([0.0000, 0.0000, 0.2689, 0.7311], dtype=torch.float64)
```

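Here, subtracting the maximum value of 4 is not "normalizing"; it shifts the largest value to zero. It does not change the result of the softmax, because softmax is invariant to subtracting the same constant from every input. This is the standard trick for numerical stability: after the shift, the largest term in the softmax is exp(0) = 1, so exp() cannot overflow.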
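A minimal pure-Python sketch of both points (softmax_naive and softmax_shifted are hypothetical names, and PyTorch is not needed): subtracting the maximum leaves the softmax unchanged on ordinary inputs, but only the shifted version survives large scores.

```python
import math


def softmax_naive(xs):
    # exp() of each raw score; overflows once a score is large enough
    exps = [math.exp(x) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def softmax_shifted(xs):
    # subtract the max first, like sim - sim.amax(...): the largest exponent becomes exp(0) = 1
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


# Same weights on small inputs: the shift does not change the softmax
print(softmax_naive([1.0, 2.0, 3.0, 4.0]))
print(softmax_shifted([1.0, 2.0, 3.0, 4.0]))

# On large inputs only the shifted version works
print(softmax_shifted([1000.0, 1001.0]))
try:
    softmax_naive([1000.0, 1001.0])
except OverflowError as err:
    print("naive softmax failed:", err)
```

Python's math.exp raises OverflowError here; torch.exp would instead return inf, and inf / inf makes the naive version return nan.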

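Note that setting the values we want to mask to -infinity (here, the most negative finite value, -torch.finfo(sim.dtype).max) results in 0 after the softmax operation, which is exactly what we want. Filling them with +torch.finfo(sim.dtype).max instead would make the masked positions the maximum, so after the shift they would receive all of the attention.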
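To make the comparison with the +max variant from the question concrete, here is a minimal pure-Python sketch (no PyTorch; masked_softmax is a hypothetical helper) that replays the masked_fill, subtract-max and softmax steps with different fill values. BIG stands in for torch.finfo(torch.float64).max.

```python
import math
import sys

# Largest finite float64; -BIG plays the role of -torch.finfo(torch.float64).max
BIG = sys.float_info.max


def masked_softmax(scores, keep, fill):
    # 1) masked_fill(~mask, fill): replace positions where keep is False
    filled = [s if k else fill for s, k in zip(scores, keep)]
    # 2) subtract the row maximum (the sim.amax(...) step)
    row_max = max(filled)
    shifted = [v - row_max for v in filled]
    # 3) softmax over the shifted scores
    exps = [math.exp(v) for v in shifted]
    total = sum(exps)
    return [e / total for e in exps]


scores = [1.0, 2.0, 3.0, 4.0]
keep = [False, False, True, True]  # same mask as the toy example

print("fill -max:", masked_softmax(scores, keep, -BIG))  # masked positions get 0 weight
print("fill +max:", masked_softmax(scores, keep, BIG))   # masked positions take all the weight

# A row in which every position is masked
none = [False] * 4
print("all masked, -max:", masked_softmax(scores, none, -BIG))      # uniform weights
print("all masked, -inf:", masked_softmax(scores, none, -math.inf))  # -inf - (-inf) = nan
```

A plausible (unconfirmed) reason to prefer the finite value over a true -inf: for a row where every position is masked, such as a text token that no image precedes, the finite fill yields uniform weights instead of nan.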

Hope this helps!

eileenforwhat commented 1 year ago

I see, this makes sense. Thank you!

dhansmair commented 1 year ago

Sure, feel free to ask if you have any more doubts :)