karpathy / nanoGPT

The simplest, fastest repository for training/finetuning medium-sized GPTs.
MIT License
37.55k stars 5.99k forks

How best to implement a differential transformer? #567

Open Wilsontomass opened 1 month ago

Wilsontomass commented 1 month ago

I'm not sure the issue tracker is the best place to post this, but I wanted to see if anyone else has been trying this idea:

A paper came out recently that proposes a new attention head architecture, and I wanted to see if I could replicate the results (according to the paper they are very promising). It didn't seem too hard given what I knew from messing around with this repo. The authors provide three versions of the code, and to keep things simple I based my attempt on one of them. I first added rotary positional encoding on its own and tested it; that worked well. Then I added the differential mechanism. My code looks like this:

class CausalSelfAttention(nn.Module):
    def __init__(self, config, depth):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        self.q_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)
        self.k_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)
        self.v_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)
        # output projection
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        # regularization
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)
        self.n_head = config.n_head // 2  # div by 2 because each head is larger, so we only have half as many
        self.n_embd = config.n_embd
        self.dropout = config.dropout

        self.head_dim = self.n_embd // self.n_head // 2 # div by 2 because double key and query
        self.rotary_emb = RotaryEmbedding(dim=self.head_dim, max_position_embeddings=config.block_size)  # Added line

        self.lambda_init = lambda_init_fn(depth)
        self.lambda_q1 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,std=0.1))
        self.lambda_k1 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,std=0.1))
        self.lambda_q2 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,std=0.1))
        self.lambda_k2 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,std=0.1))
        self.subln = RMSNorm(2 * self.head_dim, eps=1e-5, elementwise_affine=False)
        # flash attention make GPU go brrrrr but support is only in PyTorch >= 2.0
        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
        if not self.flash:
            print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
            # causal mask to ensure that attention is only applied to the left in the input sequence
            self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
                                        .view(1, 1, config.block_size, config.block_size))
    def forward(self, x):
        B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
        # calculate query, key, values for all heads; heads stay in dim 2, i.e. layout (B, T, heads, hs)
        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)
        k = k.view(B, T, self.n_head*2, self.head_dim)  # (B, T, 2*nh, hs)
        q = q.view(B, T, self.n_head*2, self.head_dim)  # (B, T, 2*nh, hs)
        v = v.view(B, T, self.n_head, 2, self.head_dim)  # (B, T, nh, 2, hs)
        # Apply rotary embeddings to q and k
        cos, sin = self.rotary_emb(q, seq_len=T)
        q = apply_rotary_pos_emb(q, cos, sin)
        k = apply_rotary_pos_emb(k, cos, sin)
        q = q.reshape(B, T, self.n_head, 2, self.head_dim)
        k = k.reshape(B, T, self.n_head, 2, self.head_dim)
        q1, q2 = q[:, :, :, 0], q[:, :, :, 1]
        k1, k2 = k[:, :, :, 0], k[:, :, :, 1]
        v1, v2 = v[:, :, :, 0], v[:, :, :, 1]
        attn11 = F.scaled_dot_product_attention(q1, k1, v1, attn_mask=None, is_causal=True)
        attn12 = F.scaled_dot_product_attention(q1, k1, v2, attn_mask=None, is_causal=True)
        attn1 = torch.cat([attn11, attn12], dim=-1)
        attn21 = F.scaled_dot_product_attention(q2, k2, v1, attn_mask=None, is_causal=True)
        attn22 = F.scaled_dot_product_attention(q2, k2, v2, attn_mask=None, is_causal=True)
        attn2 = torch.cat([attn21, attn22], dim=-1)

        lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1).float()).type_as(q)
        lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1).float()).type_as(q)
        lambda_full = lambda_1 - lambda_2 + self.lambda_init
        attn = attn1 - lambda_full * attn2
        attn = self.subln(attn)
        attn = attn * (1 - self.lambda_init)
        attn = attn.reshape(B, T, C)
        # output projection
        y = self.resid_dropout(self.c_proj(attn))
        return y

When I try to train this model it understandably runs at fewer iterations/sec, but if we look at the loss per iteration it seems to be getting stuck. (For each run I kept the total batch size the same as in the gpt2-124M-RoPE run.)

[image]

Any ideas on what I've gotten wrong? I'm no ML expert.

@karpathy on the off chance that you see this, have you read about the diff transformer paper and if so, what do you think about it?
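
For reference while reading the code above, this is the differential attention computation as I understand it from the paper. The symbols follow the paper rather than the variable names in the snippet; $d$ is the per-sub-head query/key width, $l$ is assumed to be the 1-based layer index, and $\mathrm{LN}$ stands for the per-head normalization (`subln` in the code).

```latex
\begin{aligned}
[Q_1; Q_2] &= X W^Q, \qquad [K_1; K_2] = X W^K, \qquad V = X W^V \\
\mathrm{DiffAttn}(X) &= \left( \mathrm{softmax}\!\left(\frac{Q_1 K_1^\top}{\sqrt{d}}\right)
    - \lambda \, \mathrm{softmax}\!\left(\frac{Q_2 K_2^\top}{\sqrt{d}}\right) \right) V \\
\lambda &= \exp(\lambda_{q_1} \cdot \lambda_{k_1}) - \exp(\lambda_{q_2} \cdot \lambda_{k_2}) + \lambda_{\text{init}} \\
\lambda_{\text{init}} &= 0.8 - 0.6 \, \exp\bigl(-0.3 \, (l - 1)\bigr) \\
\mathrm{head}_i &= (1 - \lambda_{\text{init}}) \cdot \mathrm{LN}\bigl(\mathrm{DiffAttn}_i(X)\bigr)
\end{aligned}
```

Both softmaxes run over the token (sequence) dimension with a causal mask, which is the property worth verifying in any re-implementation.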

notlober commented 2 weeks ago

I've done a similar experiment too. First of all, I recommend looking at the non-flash implementation from the diff attention code, "multihead_diffattn.py", rather than "multihead_flashdiff_1.py".
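
As far as I can tell, the main structural difference between the two styles is where the subtraction happens: the non-flash style subtracts the two softmax maps and then multiplies by V, while the flash style runs attention separately and subtracts the outputs. By linearity these should agree up to floating-point error. A minimal sketch, assuming PyTorch >= 2.0; the tensor sizes and the fixed `lam` are made up for illustration, and the sub-layer norm and learned lambda parameters are left out:

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, nh, T, hs = 2, 3, 5, 4   # illustrative sizes only
lam = 0.3                   # stand-in for the learned lambda_full

# layout (B, nh, T, hs), which is what F.scaled_dot_product_attention expects
q1, q2, k1, k2, v1, v2 = (torch.randn(B, nh, T, hs) for _ in range(6))
v = torch.cat([v1, v2], dim=-1)  # value heads are twice as wide: (B, nh, T, 2*hs)

def causal_softmax(q, k):
    """Causal attention weights of shape (B, nh, T, T)."""
    att = (q @ k.transpose(-2, -1)) / math.sqrt(k.size(-1))
    mask = torch.tril(torch.ones(T, T, dtype=torch.bool))
    att = att.masked_fill(~mask, float("-inf"))
    return F.softmax(att, dim=-1)

# (a) subtract the attention maps, then apply them to V
out_maps = (causal_softmax(q1, k1) - lam * causal_softmax(q2, k2)) @ v

# (b) run attention per value half, concatenate, then subtract the outputs
def sdpa(q, k, val):
    return F.scaled_dot_product_attention(q, k, val, is_causal=True)

attn1 = torch.cat([sdpa(q1, k1, v1), sdpa(q1, k1, v2)], dim=-1)
attn2 = torch.cat([sdpa(q2, k2, v1), sdpa(q2, k2, v2)], dim=-1)
out_sdpa = attn1 - lam * attn2

max_diff = (out_maps - out_sdpa).abs().max().item()
print(f"max abs difference: {max_diff:.2e}")
```

If the two paths disagree in a real model, the mismatch is more likely to come from tensor layout or masking than from the subtraction itself.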

Secondly, you're halving both n_head and head_dim, which is an issue; this does not appear in the original code. First here: self.n_head = config.n_head // 2 and then here: self.head_dim = self.n_embd // self.n_head // 2
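
To make the head bookkeeping concrete, here is a small sketch that works out the shapes for both conventions: halving `config.n_head` (as in the original post) and keeping it unchanged (as in the implementation posted further down). It assumes the GPT-2 small defaults `n_embd=768` and `n_head=12`; `diff_attn_shapes` is a hypothetical helper, not part of nanoGPT or the reference code.

```python
def diff_attn_shapes(n_embd, n_head_cfg, halve_heads):
    """Return (num_diff_heads, head_dim, qk_heads, v_head_width)."""
    n_head = n_head_cfg // 2 if halve_heads else n_head_cfg
    head_dim = n_embd // n_head // 2   # each differential head holds two q/k sub-heads
    qk_heads = 2 * n_head              # q and k are viewed as 2 * n_head heads of width head_dim
    v_width = 2 * head_dim             # v is viewed as n_head heads of width 2 * head_dim
    # both views must tile the full embedding width
    assert qk_heads * head_dim == n_embd
    assert n_head * v_width == n_embd
    return n_head, head_dim, qk_heads, v_width

halved = diff_attn_shapes(768, 12, halve_heads=True)     # convention in the original post
unhalved = diff_attn_shapes(768, 12, halve_heads=False)  # convention in the later implementation
print(halved)    # (6, 64, 12, 128)
print(unhalved)  # (12, 32, 24, 64)
```

Both conventions tile the embedding width, so neither fails on shapes alone; they differ in how many differential heads there are and how wide each sub-head is, which matters when comparing against a baseline with a fixed head size.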

Lastly, even though standard GPT-2 with RoPE went well, I recommend starting with the non-RoPE version, since it's easier to begin with.

Also, F.scaled_dot_product_attention is probably a bit different from flash attention internally.
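
One concrete difference worth checking, as far as I know: `flash_attn_func` from the flash-attn package takes tensors laid out as (B, T, nh, hs), whereas `F.scaled_dot_product_attention` treats the second-to-last dimension as the sequence and so expects (B, nh, T, hs). A minimal sketch, with made-up sizes, of what happens when a (B, T, nh, hs) tensor is passed to SDPA without a transpose: attention runs across heads instead of across tokens, so changing an earlier token has no effect on later positions.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, T, nh, hs = 1, 6, 4, 8   # illustrative sizes only

def attend(x, transpose):
    # x: (B, T, nh, hs); the same tensor serves as q, k and v to keep the sketch short
    if transpose:
        x = x.transpose(1, 2)   # (B, nh, T, hs): sequence in the second-to-last dim
    y = F.scaled_dot_product_attention(x, x, x, is_causal=True)
    if transpose:
        y = y.transpose(1, 2)   # back to (B, T, nh, hs)
    return y

x = torch.randn(B, T, nh, hs)
x_perturbed = x.clone()
x_perturbed[:, 0] += 1.0        # change only the first token

# How much does the output at the last token change?
delta_no_transpose = (attend(x, False) - attend(x_perturbed, False))[:, -1].abs().max().item()
delta_transposed = (attend(x, True) - attend(x_perturbed, True))[:, -1].abs().max().item()
print(f"no transpose: {delta_no_transpose:.2e}, transposed: {delta_transposed:.2e}")
```

Whether this is what holds the loss back in the run above I can't say for sure, but a model whose attention never mixes tokens would plateau early, so it seems worth ruling out.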

notlober commented 2 weeks ago

my implementation:

class CausalSelfAttention(nn.Module):

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # key, query, value projections for all heads, but in a batch
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        # output projection
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        # regularization
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.dropout = config.dropout
        self.head_dim = self.n_embd // self.n_head // 2
        # causal mask to ensure that attention is only applied to the left in the input sequence
        self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
                                    .view(1, 1, config.block_size, config.block_size))
        self.lambda_init = lambda_init_fn(config.n_layer)
        self.lambda_q1 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,std=0.1))
        self.lambda_k1 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,std=0.1))
        self.lambda_q2 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,std=0.1))
        self.lambda_k2 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,std=0.1))

        self.subln = nn.LayerNorm(2 * self.head_dim, eps=1e-5, elementwise_affine=True)

    def forward(self, x):
        B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)

        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        q, k, v  = self.c_attn(x).split(self.n_embd, dim=2)
        k = k.view(B, T, 2*self.n_head, self.head_dim).transpose(1, 2) # (B, 2*nh, T, hs)
        q = q.view(B, T, 2*self.n_head, self.head_dim).transpose(1, 2) # (B, 2*nh, T, hs)
        v = v.view(B, T, self.n_head, 2*self.head_dim).transpose(1, 2) # (B, nh, T, 2*hs)

        # manual implementation of attention
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
        att = F.softmax(att, dim=-1)

        lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1))
        lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1))
        lambda_full = lambda_1 - lambda_2 + self.lambda_init
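        # lambda_full is a 0-dim tensor (one lambda shared by all heads), so it broadcasts over
        # (B, nh, T, T) as is; .view(1, self.n_head, 1, 1) would fail on a single element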

        att = att.view(B, self.n_head, 2, T, T)
        attn1 = att[:, :, 0, :, :]  # (B, nh, T, T)
        attn2 = att[:, :, 1, :, :]  # (B, nh, T, T)

        attn_weights = attn1 - lambda_full * attn2  # (B, nh, T, T)

        att = self.attn_dropout(attn_weights)
        att = att @ v # (B, nh, T, T) x (B, nh, T, 2*hs) -> (B, nh, T, 2*hs)
        att = self.subln(att)
        att = att * (1 - self.lambda_init)
        y = att.permute(0, 2, 1, 3).contiguous().view(B, T, C)  # (B, T, n_embd)
        # output projection
        y = self.resid_dropout(self.c_proj(y))
        return y
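
One detail to double-check in the snippet above: as I understand the paper, `lambda_init` is a per-layer value that grows with depth, so it would be computed from the index of the current layer rather than from `config.n_layer`. A minimal sketch, assuming the schedule 0.8 - 0.6 * exp(-0.3 * depth) with a 0-based layer index; this re-implements `lambda_init_fn` from memory, so treat the exact form as an assumption:

```python
import math

def lambda_init_fn(depth):
    # assumed schedule from the Differential Transformer paper; depth is the 0-based layer index
    return 0.8 - 0.6 * math.exp(-0.3 * depth)

n_layer = 12  # GPT-2 small depth, used only for illustration
per_layer = [lambda_init_fn(i) for i in range(n_layer)]
for i, lam in enumerate(per_layer):
    print(f"layer {i:2d}: lambda_init = {lam:.3f}")

# Passing n_layer itself gives every layer the same value, close to the 0.8 ceiling
shared = lambda_init_fn(n_layer)
print(f"lambda_init_fn(n_layer) = {shared:.3f}")
```

Under this assumption, early layers start with a small lambda (about 0.2 at the first layer), and calling `lambda_init_fn(config.n_layer)` would instead start every layer near the top of the range.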