HaoZhongkai / GNOT


Does padding in GNOT contaminate the attention matrix? #12

Open BraveDrXuTF opened 1 month ago

BraveDrXuTF commented 1 month ago

In NLP, we have a masking mechanism to help prevent this. But in GNOT, the following code in https://github.com/HaoZhongkai/GNOT/blob/master/models/cgpt.py seems to have no masking procedure:

class LinearAttention(nn.Module):
    """
    A vanilla multi-head masked self-attention layer with a projection at the end.
    It is possible to use torch.nn.MultiheadAttention here but I am including an
    explicit implementation here to show that there is nothing too scary here.
    """

    def __init__(self, config):
        super(LinearAttention, self).__init__()
        assert config.n_embd % config.n_head == 0
        # key, query, value projections for all heads
        self.key = nn.Linear(config.n_embd, config.n_embd)
        self.query = nn.Linear(config.n_embd, config.n_embd)
        self.value = nn.Linear(config.n_embd, config.n_embd)
        # regularization
        self.attn_drop = nn.Dropout(config.attn_pdrop)
        # output projection
        self.proj = nn.Linear(config.n_embd, config.n_embd)

        self.n_head = config.n_head

        self.attn_type = 'l1'

    '''
        Linear Attention and Linear Cross Attention (if y is provided)
    '''
    def forward(self, x, y=None, layer_past=None):
        y = x if y is None else y
        B, T1, C = x.size()
        _, T2, _ = y.size()
        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        q = self.query(x).view(B, T1, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
        k = self.key(y).view(B, T2, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
        v = self.value(y).view(B, T2, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)

        if self.attn_type == 'l1':
            q = q.softmax(dim=-1)
            k = k.softmax(dim=-1)   #
            k_cumsum = k.sum(dim=-2, keepdim=True)
            D_inv = 1. / (q * k_cumsum).sum(dim=-1, keepdim=True)       # normalized
        elif self.attn_type == "galerkin":
            q = q.softmax(dim=-1)
            k = k.softmax(dim=-1)  #
            D_inv = 1. / T2                                           # galerkin
        elif self.attn_type == "l2":                                   # still use l1 normalization
            q = q / q.norm(dim=-1,keepdim=True, p=1)
            k = k / k.norm(dim=-1,keepdim=True, p=1)
            k_cumsum = k.sum(dim=-2, keepdim=True)
            D_inv = 1. / (q * k_cumsum).abs().sum(dim=-1, keepdim=True)  # normalized
        else:
            raise NotImplementedError

        context = k.transpose(-2, -1) @ v
        y = self.attn_drop((q @ context) * D_inv + q)

        # output projection
        y = rearrange(y, 'b h n d -> b n (h d)')
        y = self.proj(y)
        return y

class LinearCrossAttention(nn.Module):
    """
    A vanilla multi-head masked self-attention layer with a projection at the end.
    It is possible to use torch.nn.MultiheadAttention here but I am including an
    explicit implementation here to show that there is nothing too scary here.
    """

    def __init__(self, config):
        super(LinearCrossAttention, self).__init__()
        assert config.n_embd % config.n_head == 0
        # key, query, value projections for all heads
        self.query = nn.Linear(config.n_embd, config.n_embd)
        self.keys = nn.ModuleList([nn.Linear(config.n_embd, config.n_embd) for _ in range(config.n_inputs)])
        self.values = nn.ModuleList([nn.Linear(config.n_embd, config.n_embd) for _ in range(config.n_inputs)])
        # regularization
        self.attn_drop = nn.Dropout(config.attn_pdrop)
        # output projection
        self.proj = nn.Linear(config.n_embd, config.n_embd)

        self.n_head = config.n_head
        self.n_inputs = config.n_inputs

        self.attn_type = 'l1'

    '''
        Linear Attention and Linear Cross Attention (if y is provided)
    '''
    def forward(self, x, y=None, layer_past=None):
        y = x if y is None else y
        B, T1, C = x.size()
        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        q = self.query(x).view(B, T1, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
        q = q.softmax(dim=-1)
        out = q
        for i in range(self.n_inputs):
            _, T2, _ = y[i].size()
            k = self.keys[i](y[i]).view(B, T2, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
            v = self.values[i](y[i]).view(B, T2, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
            k = k.softmax(dim=-1)  #
            k_cumsum = k.sum(dim=-2, keepdim=True)
            D_inv = 1. / (q * k_cumsum).sum(dim=-1, keepdim=True)  # normalized
            out = out +  1 * (q @ (k.transpose(-2, -1) @ v)) * D_inv

        # output projection
        out = rearrange(out, 'b h n d -> b n (h d)')
        out = self.proj(out)
        return out

So it seems the elements of the attention matrix are contaminated by the padded part. Is that true? Thanks.
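For reference, the NLP-style masking I have in mind looks roughly like this in ordinary softmax attention: logits at padded key positions are set to -inf before the softmax, so those keys receive exactly zero attention weight. (A minimal illustration; the shapes and the mask layout are made up, this is not GNOT code.)

import torch

# Minimal illustration of a key padding mask in ordinary softmax attention.
B, T, d = 2, 5, 8
q = torch.randn(B, T, d)
k = torch.randn(B, T, d)
v = torch.randn(B, T, d)
# True = real token, False = padded position
key_padding_mask = torch.tensor([[True, True, True, False, False],
                                 [True, True, True, True,  False]])
logits = q @ k.transpose(-2, -1) / d ** 0.5                         # (B, T, T)
logits = logits.masked_fill(~key_padding_mask[:, None, :], float('-inf'))
attn = logits.softmax(dim=-1)    # padded keys get exactly zero weight
out = attn @ v                   # padded positions no longer contaminate the output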

HaoZhongkai commented 1 month ago

Hi, I'm not sure which padding you are referring to. We pad sequences of different lengths to the same length with zero padding, so this padding does not influence the attention computation. And we do not need causal masking for attention over spatial data.
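For concreteness, the zero padding described here can be pictured like this (sizes are made up; GNOT's actual data pipeline may differ):

import torch
from torch.nn.utils.rnn import pad_sequence

# Zero-pad sequences of different lengths to a common length.
a = torch.randn(100, 16)   # e.g. 100 points with 16 features
b = torch.randn(73, 16)    # e.g. 73 points
batch = pad_sequence([a, b], batch_first=True)     # shape (2, 100, 16)
print(batch.shape)                                 # torch.Size([2, 100, 16])
print(batch[1, 73:].abs().sum())                   # tensor(0.) -- padded rows are all zeros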

BraveDrXuTF commented 1 month ago

Hi @HaoZhongkai, thank you for your response. The padding I mean is exactly that: the padding applied to sequences of different lengths.

After the K and V mappings, because the bias of the K/V linear layers is not set to zero, the values at the padded positions of the tensors after the K/V linear transforms are not zero either.

Of course, being nonzero might not be a big problem...
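A tiny check of this point (the layer size is arbitrary, not the GNOT config): an all-zero padded position passed through a Linear layer with a bias does not stay zero, it comes out equal to the bias vector.

import torch
import torch.nn as nn

torch.manual_seed(0)
key_proj = nn.Linear(8, 8)            # stands in for self.key / self.value
padded_token = torch.zeros(1, 8)      # a padded (all-zero) input position
out = key_proj(padded_token)
print(out)                                              # nonzero in general
print(torch.allclose(out, key_proj.bias.unsqueeze(0)))  # True: the output is just the bias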

BraveDrXuTF commented 4 weeks ago

But I suggest it is still better to add a mask for Q, K, and V so that the attention computation is not affected by these unreal padded positions.
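Something like the following sketch is what I mean, for the 'l1' branch; the key_padding_mask argument and the clamp are my own additions, not part of GNOT's code.

import torch

def masked_l1_linear_attention(q, k, v, key_padding_mask):
    """Sketch of the 'l1' linear attention with a key/value padding mask.

    q: (B, n_head, T1, head_dim); k, v: (B, n_head, T2, head_dim);
    key_padding_mask: (B, T2) bool, True at real tokens, False at padding.
    The mask argument is hypothetical; GNOT's forward has no such input.
    """
    mask = key_padding_mask[:, None, :, None].to(q.dtype)   # (B, 1, T2, 1)
    q = q.softmax(dim=-1)
    k = k.softmax(dim=-1) * mask     # padded key rows become exactly zero
    v = v * mask                     # padded value rows contribute nothing
    k_cumsum = k.sum(dim=-2, keepdim=True)
    D_inv = 1. / (q * k_cumsum).sum(dim=-1, keepdim=True).clamp_min(1e-8)
    context = k.transpose(-2, -1) @ v
    return (q @ context) * D_inv + q   # same residual-style output as the repo code

Outputs at padded query positions would still need to be ignored or sliced off downstream, but the real positions would then be unaffected by the padding.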

In my experiments, as you can see below, even in the first attention block of GNOT, the last few elements along the total length of k have already become slightly different. This might be caused by numerical error or something:

(Pdb) k[:,10,-2:,10:14]
tensor([[[ 3.0466e-04, -3.8661e-04, -1.5450e-04,  5.3415e-06],
         [ 3.0456e-04, -3.8309e-04, -1.5212e-04,  3.2105e-06]],

        [[ 3.0484e-04, -3.8515e-04, -1.5213e-04,  7.4536e-06],
         [ 3.0420e-04, -3.8171e-04, -1.5517e-04,  6.6272e-06]],

        [[ 2.9967e-04, -3.8171e-04, -1.4730e-04,  6.0417e-06],
         [ 2.9956e-04, -3.8265e-04, -1.4569e-04,  6.8420e-06]],

        [[ 3.0649e-04, -3.8577e-04, -1.5380e-04,  5.6187e-06],
         [ 3.0878e-04, -3.8615e-04, -1.5406e-04,  5.9028e-06]],

        [[ 2.9486e-04, -3.8043e-04, -1.4438e-04,  9.2388e-06],
         [ 2.9771e-04, -3.8179e-04, -1.4543e-04,  7.7717e-06]],

        [[ 3.0080e-04, -3.8102e-04, -1.4764e-04,  9.0702e-06],
         [ 2.9898e-04, -3.8154e-04, -1.4521e-04,  9.6725e-06]],

        [[ 2.9786e-04, -3.7959e-04, -1.4511e-04,  9.8170e-06],
         [ 2.9833e-04, -3.7960e-04, -1.4366e-04,  7.0692e-06]],

        [[ 3.0522e-04, -3.8424e-04, -1.5075e-04,  4.8854e-06],
         [ 3.0750e-04, -3.8452e-04, -1.5250e-04,  7.1328e-06]]],