arnavc1712 opened 1 month ago
Thank you for asking such a detailed question. I haven’t looked at this code for a long time. I’ll try my best to explain it. If you find that my description is not particularly clear, it means that I haven’t fully understood it. You can share your thoughts, and that might help me fully grasp it.
Hi, in the Causal Cross Attention, I see we register a causal mask that is a lower-triangular matrix. However, we are trying to learn latent parameters C with seq_len m, where m < block_size. The way the mask is picked seems odd to me, because it takes the first m rows and as many columns as the input sequence length.
For example, let's say the mask for a block size of 5 is:
False True True True True
False False True True True
False False False True True
False False False False True
False False False False False
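The grid above matches `mask == 0` (True marks a position that `masked_fill` overwrites with -1e10), assuming the registered buffer is the usual `torch.tril(torch.ones(block_size, block_size))`. That construction is an assumption, since the issue only says "lower triangular". The following minimal pure-Python sketch reproduces the grid. The helper names `causal_mask` and `masked_positions` are hypothetical.

```python
# Hypothetical stand-in for the registered buffer, assumed to be
# torch.tril(torch.ones(block_size, block_size)): 1 on/below the diagonal, 0 above.
def causal_mask(block_size):
    return [[1 if j <= i else 0 for j in range(block_size)] for i in range(block_size)]

def masked_positions(mask):
    # True where mask == 0, i.e. where masked_fill would write -1e10
    return [[v == 0 for v in row] for row in mask]

block_size = 5
grid = masked_positions(causal_mask(block_size))
for row in grid:
    print(" ".join(str(v) for v in row))
```

Each printed row lists one query position. Row i leaves keys 0..i visible and masks everything after them.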
Now let's say the C parameter we want to learn has seq_len 2.
Given the line `att = att.masked_fill(self.mask[:,:,:Tq,:Tk] == 0, -1e10)`, where Tq = 2 in our case and Tk = 5,
we get the mask as (first two rows and 5 columns):
False True True True True
False False True True True
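This slicing can be checked with a small pure-Python sketch. It assumes the same lower-triangular buffer as above and mimics only the last two dimensions of `self.mask[:,:,:Tq,:Tk]`. The helpers `causal_mask` and `slice_mask` are hypothetical.

```python
# Hypothetical stand-in for the registered lower-triangular buffer
def causal_mask(block_size):
    return [[1 if j <= i else 0 for j in range(block_size)] for i in range(block_size)]

def slice_mask(mask, tq, tk):
    # Mirrors self.mask[:, :, :Tq, :Tk] on the (query, key) dimensions
    return [row[:tk] for row in mask[:tq]]

block_size, Tq, Tk = 5, 2, 5
sliced = slice_mask(causal_mask(block_size), Tq, Tk)
for row in sliced:
    print(" ".join(str(v == 0) for v in row))
```

The slice keeps the top-left corner of the buffer. The two latent queries therefore inherit the masking pattern of query positions 0 and 1, regardless of how long the key sequence is.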
This means we only take into account the first value vector to learn the first latent element, and the first two value vectors to learn the second latent element, disregarding all the other value vectors.
Unless I'm understanding this incorrectly?
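The claim about which value vectors are used can be checked numerically. This is a minimal sketch, not the repository's code. It emulates `masked_fill(mask == 0, -1e10)` followed by a softmax over keys, using uniform raw scores (a hypothetical choice, so that only the mask decides the weights). The helper names are also hypothetical.

```python
import math

# Hypothetical stand-in for the registered lower-triangular buffer
def causal_mask(block_size):
    return [[1 if j <= i else 0 for j in range(block_size)] for i in range(block_size)]

def masked_softmax(scores, mask, fill=-1e10):
    # Emulates att.masked_fill(mask == 0, -1e10) followed by softmax over the key axis
    out = []
    for s_row, m_row in zip(scores, mask):
        filled = [s if m != 0 else fill for s, m in zip(s_row, m_row)]
        peak = max(filled)  # subtract the max for numerical stability
        exps = [math.exp(v - peak) for v in filled]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out

Tq, Tk = 2, 5
mask = [row[:Tk] for row in causal_mask(5)[:Tq]]  # same as self.mask[..., :Tq, :Tk]
scores = [[0.0] * Tk for _ in range(Tq)]          # uniform scores (hypothetical)
weights = masked_softmax(scores, mask)
for row in weights:
    print([round(w, 3) for w in row])
```

The first latent query puts all of its weight on key 0. The second splits its weight between keys 0 and 1. Because the attention output is a weighted sum of the value vectors, value vectors 2 to 4 get zero weight for both queries, which is consistent with the reading above.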