Open Wilsontomass opened 1 month ago
I've done a similar experiment too. First of all, I recommend looking at the non-flash implementation from diff attn, "multihead_diffattn.py", and not "multihead_flashdiff_1.py".
Secondly, you're dividing by 2 twice, once for n_head and once for head_dim; that's an issue, and it does not appear in the original code. First here: self.n_head = config.n_head // 2, and then here: self.head_dim = self.n_embd // self.n_head // 2.
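For concreteness, the bookkeeping works out differently depending on which factor you halve; here is a quick arithmetic sanity check with illustrative GPT-2 124M numbers (not part of either implementation):

    # illustrative sanity check of the head bookkeeping (n_embd = 768, n_head = 12)
    n_embd, n_head = 768, 12

    # convention A: keep all 12 heads and halve the head dim (what my implementation below does)
    head_dim_a = n_embd // n_head // 2            # 32
    assert 2 * n_head * head_dim_a == n_embd      # q/k reshape to (B, T, 2*12, 32)

    # convention B: halve the head count, so each half-head keeps the full baseline width
    n_head_b = n_head // 2                        # 6
    head_dim_b = n_embd // n_head_b // 2          # 64
    assert 2 * n_head_b * head_dim_b == n_embd    # q/k reshape to (B, T, 2*6, 64)

    # either way, the q/k projection (width n_embd) must tile exactly into 2 * heads * head_dim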
Lastly, even though standard GPT-2 with RoPE went well, I recommend starting with the non-RoPE version, since it's easier to begin with.
And F.scaled_dot_product_attention is probably a bit different from flash attention internally.
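If you want to check that, SDPA in PyTorch 2.x dispatches between a flash kernel, a memory-efficient kernel, and a plain math fallback. A rough sketch like the following forces the flash backend only, so you can see whether your shapes/dtypes are even eligible for it (the sdp_kernel context manager is the PyTorch 2.0-2.2 API; newer releases deprecate it in favor of torch.nn.attention.sdpa_kernel):

    # rough check: force the flash backend for F.scaled_dot_product_attention (PyTorch 2.x, CUDA)
    import torch
    import torch.nn.functional as F

    q = k = v = torch.randn(1, 12, 128, 64, device="cuda", dtype=torch.float16)
    with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)  # raises if flash can't handle these inputs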
My implementation:
import math
import torch
import torch.nn as nn
from torch.nn import functional as F

def lambda_init_fn(depth):
    # depth-dependent lambda initialization from the Diff Transformer paper: 0.8 - 0.6 * exp(-0.3 * depth)
    return 0.8 - 0.6 * math.exp(-0.3 * depth)

class CausalSelfAttention(nn.Module):

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # key, query, value projections for all heads, but in a batch
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        # output projection
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        # regularization
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.dropout = config.dropout
        # each differential head uses two half-width q/k heads, so the head dim is halved once here
        self.head_dim = self.n_embd // self.n_head // 2
        # causal mask for the manual (non-flash) attention implementation
        self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
                                    .view(1, 1, config.block_size, config.block_size))
        self.lambda_init = lambda_init_fn(config.n_layer)
        self.lambda_q1 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0, std=0.1))
        self.lambda_k1 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0, std=0.1))
        self.lambda_q2 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0, std=0.1))
        self.lambda_k2 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0, std=0.1))
        self.subln = nn.LayerNorm(2 * self.head_dim, eps=1e-5, elementwise_affine=True)

    def forward(self, x):
        B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        k = k.view(B, T, 2 * self.n_head, self.head_dim).transpose(1, 2)   # (B, 2*nh, T, hs)
        q = q.view(B, T, 2 * self.n_head, self.head_dim).transpose(1, 2)   # (B, 2*nh, T, hs)
        v = v.view(B, T, self.n_head, 2 * self.head_dim).transpose(1, 2)   # (B, nh, T, 2*hs)
        # manual implementation of attention
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float('-inf'))
        att = F.softmax(att, dim=-1)
        # reparameterized lambda: exp(lambda_q1 . lambda_k1) - exp(lambda_q2 . lambda_k2) + lambda_init
        lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1))
        lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1))
        lambda_full = lambda_1 - lambda_2 + self.lambda_init # scalar, shared across heads; broadcasts over (B, nh, T, T)
        # split the 2*nh softmax maps into per-head pairs and take the difference
        att = att.view(B, self.n_head, 2, T, T)
        attn1 = att[:, :, 0, :, :] # (B, nh, T, T)
        attn2 = att[:, :, 1, :, :] # (B, nh, T, T)
        attn_weights = attn1 - lambda_full * attn2 # (B, nh, T, T)
        att = self.attn_dropout(attn_weights)
        att = att @ v # (B, nh, T, T) x (B, nh, T, 2*hs) -> (B, nh, T, 2*hs)
        # per-head sub-layer norm and rescale, as in the paper
        att = self.subln(att)
        att = att * (1 - self.lambda_init)
        y = att.permute(0, 2, 1, 3).contiguous().view(B, T, C) # re-assemble all head outputs side by side (B, T, n_embd)
        # output projection
        y = self.resid_dropout(self.c_proj(y))
        return y
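A minimal smoke test I'd run on the block above; the Config dataclass here is just a stand-in for the GPTConfig fields the module actually touches, with illustrative sizes:

    # quick shape/smoke test for the CausalSelfAttention above
    import torch
    from dataclasses import dataclass

    @dataclass
    class Config:          # stand-in for nanoGPT's GPTConfig, only the fields used above
        n_embd: int = 768
        n_head: int = 12
        n_layer: int = 12
        block_size: int = 64
        dropout: float = 0.0
        bias: bool = False

    cfg = Config()
    attn = CausalSelfAttention(cfg)
    x = torch.randn(2, cfg.block_size, cfg.n_embd)
    y = attn(x)
    print(y.shape)  # expect torch.Size([2, 64, 768])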
I'm not sure Issues is the greatest place to post this, but I just wanted to see if anyone else has been trying this idea:
There was a paper that came out recently (the Diff Transformer paper) proposing a new attention head architecture, and I wanted to see if I could replicate the results (according to the paper they are very promising). It didn't seem too hard given what I knew from messing around with this repo. The authors provided three versions of the code here, and to keep things simple I tried to use this implementation here. I added rotary positional encoding separately and tested that; it worked well. Then I added the differential mechanism; my code looks like this:
When I try to train this model it understandably trains at a lower iterations/sec, but if we look at the loss per iteration it seems to be getting stuck. (In each iteration I have kept the total batch size the same as in the gpt2-124M-RoPE run.) Any ideas on what I've gotten wrong? I'm no ML expert.
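One thing I might try is logging the learned lambda per layer as training runs, to see whether the differential term is doing something odd. A rough sketch, assuming the standard nanoGPT module layout (model.transformer.h[i].attn) and the lambda parameter names from the reference implementation:

    # rough diagnostic: print the effective lambda of each layer
    import torch

    @torch.no_grad()
    def log_lambdas(model):
        for i, block in enumerate(model.transformer.h):
            a = block.attn
            lam = (torch.exp(torch.sum(a.lambda_q1 * a.lambda_k1))
                   - torch.exp(torch.sum(a.lambda_q2 * a.lambda_k2))
                   + a.lambda_init)
            print(f"layer {i}: lambda_full = {lam.item():.3f}")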
@karpathy on the off chance that you see this, have you read about the diff transformer paper and if so, what do you think about it?