
CS224N: Natural Language Processing with Deep Learning, Stanford / Winter 2023

Confusion regarding Causal Cross Attention #1

Open arnavc1712 opened 1 month ago

arnavc1712 commented 1 month ago

Hi, so in the Causal Cross Attention I see we register a causal mask that is a lower triangular matrix. However, we are trying to learn the latent parameters C of seq_len m such that m < block_size, and the way the mask is picked seems weird, because it just takes the first Tq rows and Tk columns based on the input sequence lengths.

For example, let's say the mask for a block size of 5 is:

False  True   True   True   True
False  False  True   True   True
False  False  False  True   True
False  False  False  False  True
False  False  False  False  False

Now let's say the C parameter we want to learn has seq_len 2.

Now, given the line `att = att.masked_fill(self.mask[:,:,:Tq,:Tk] == 0, -1e10)`, where Tq=2 in our case and Tk=5,

we get the mask as (the first two rows and all 5 columns):

False  True   True   True   True
False  False  True   True   True

This means we only take into account the first value vector to learn the first latent element and the first two value vectors to learn the second latent element, and we disregard all the other value vectors.

Unless I'm understanding this incorrectly?
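
To make it concrete, here is a minimal sketch of what I mean (assuming the mask buffer is registered the usual minGPT way; block_size and the lengths below are just the numbers from my example):

```python
import torch

# Illustrative only: reproduce the sliced mask described above.
block_size = 5
mask = torch.tril(torch.ones(block_size, block_size)).view(1, 1, block_size, block_size)

Tq, Tk = 2, 5  # query (latent C) length and key length from the example
print(mask[:, :, :Tq, :Tk] == 0)
# tensor([[[[False,  True,  True,  True,  True],
#           [False, False,  True,  True,  True]]]])
```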

suous commented 1 month ago

Thank you for asking such a detailed question. I haven't looked at this code for a long time, so I'll try my best to explain it. If my description isn't particularly clear, it probably means I haven't fully understood it myself; please share your thoughts, as that might help me grasp it completely.

  1. First, we register a mask matrix with a long enough sequence length (block_size) for efficiency; this mask is used to hide future tokens (the ones we are going to predict). A rough sketch follows the figure below.

[image: attention 001]
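
A sketch of that registration step (names are illustrative and this assumes the usual minGPT-style buffer, not necessarily the repo's exact code):

```python
import torch
import torch.nn as nn

class CausalAttention(nn.Module):
    def __init__(self, n_embd, n_head, block_size):
        super().__init__()
        self.n_head = n_head
        # Register a lower-triangular (block_size x block_size) buffer once;
        # shorter sequences slice into it instead of rebuilding the mask.
        self.register_buffer(
            "mask",
            torch.tril(torch.ones(block_size, block_size)).view(1, 1, block_size, block_size),
        )
```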

  2. For causal self-attention, we predict each subsequent output token based only on the previous input tokens; see the sketch after the figures below.

[images: attention 002, attention 003]

[images: attention 004, attention 005]
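
For the self-attention case, the mask is applied roughly like this (a sketch with illustrative shapes and names, not the assignment's exact code):

```python
import math
import torch
import torch.nn.functional as F

def causal_self_attention(q, k, v, mask):
    # q, k, v: (B, n_head, T, head_dim); mask: (1, 1, block_size, block_size)
    T = q.size(-2)
    att = (q @ k.transpose(-2, -1)) / math.sqrt(k.size(-1))
    # Row i may only attend to columns 0..i, i.e. no peeking at future tokens.
    att = att.masked_fill(mask[:, :, :T, :T] == 0, -1e10)
    att = F.softmax(att, dim=-1)
    return att @ v  # (B, n_head, T, head_dim)
```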

  3. During inference, the decoder predicts the next token based on the previously decoded sequence, while the encoder provides enough information about the encoded input sequence; a corresponding cross-attention sketch follows the figure below.

[image: attention 006]
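
A corresponding sketch for the cross-attention case, where the decoder-side query length Tq and the encoder-side key length Tk can differ, which is where the `mask[:,:,:Tq,:Tk]` slice from the question comes in (again illustrative, not the repo's exact code):

```python
import math
import torch
import torch.nn.functional as F

def causal_cross_attention(q, k, v, mask):
    # q: (B, n_head, Tq, head_dim) from the decoder (queries)
    # k, v: (B, n_head, Tk, head_dim) from the encoder (keys/values)
    Tq, Tk = q.size(-2), k.size(-2)
    att = (q @ k.transpose(-2, -1)) / math.sqrt(k.size(-1))
    # Slice the same registered buffer to (Tq, Tk): query position i only
    # sees the first i+1 key positions, which is the behaviour the question
    # points out.
    att = att.masked_fill(mask[:, :, :Tq, :Tk] == 0, -1e10)
    att = F.softmax(att, dim=-1)
    return att @ v  # (B, n_head, Tq, head_dim)
```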