YuchuanTian / DiJiang

[ICML'24 Oral] The official code of "DiJiang: Efficient Large Language Models through Compact Kernelization", a novel DCT-based linear attention mechanism.
https://arxiv.org/abs/2403.19928
86 stars 5 forks source link

Merge to huggingface/transformers #1

Open sepcnt opened 3 months ago

sepcnt commented 3 months ago

Here is the change as a diff against the GPT-NeoX modeling file at https://github.com/huggingface/transformers/commit/fed27ffc7ec62837dca9bbfc83442eb3678ee026:

diff --git a/modeling/pythia-1B-dijiang/modeling_gpt_neox.py b/modeling/pythia-1B-dijiang/modeling_gpt_neox.py
--- a/modeling/pythia-1B-dijiang/modeling_gpt_neox.py
+++ b/modeling/pythia-1B-dijiang/modeling_gpt_neox.py
@@ -40,6 +40,9 @@
 from ...utils import is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10, logging
 from .configuration_gpt_neox import GPTNeoXConfig

+import numpy as np
+from scipy.fft import dct
+

 if is_flash_attn_2_available():
     from flash_attn import flash_attn_func, flash_attn_varlen_func
@@ -98,7 +101,7 @@

 class GPTNeoXAttention(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config, gamma=None):  # NOTE(review): gamma is not read by the attention code added in this diff (per-head decays come from _get_D1/_get_D2); defaulted so GPTNeoXAttention(config) still works
         super().__init__()
         self.config = config
         self.num_attention_heads = config.num_attention_heads
@@ -118,6 +121,17 @@
         self.query_key_value = nn.Linear(config.hidden_size, 3 * config.hidden_size, bias=config.attention_bias)
         self.dense = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attention_bias)
         self.attention_dropout = nn.Dropout(config.attention_dropout)
+
+        self.proj_matrix = self._build_projection()  # frozen (requires_grad=False) head_size x head_size DCT-based feature projection
+        self.v_dim = config.hidden_size
+        self.W_G = nn.Parameter(torch.randn(self.hidden_size, self.v_dim) / self.hidden_size)  # output-gate weight; this randn init is overwritten by xavier_uniform_ two lines below
+        self.swish = nn.SiLU()
+        self.group_norm = nn.GroupNorm(self.head_size, self.v_dim)  # NOTE(review): num_groups is head_size, not num_attention_heads -- confirm this grouping is intended
+        nn.init.xavier_uniform_(self.W_G.data, gain=2 ** -2.5)
+        self.D1 = self._get_D1(self.config.max_position_embeddings)  # per-head gamma**n, frozen Parameter (part of state_dict)
+        self.D2 = self._get_D2(self.config.max_position_embeddings)  # per-head gamma**-n; NOTE(review): grows past float16 max after a few hundred positions for gamma=1-1/32
+        self.mask = self._get_mask(self.config.max_position_embeddings).unsqueeze(0)  # causal (n >= m) mask, (1, heads, L, L); plain attribute, so module.to() does not move it -- forward copies a slice each call
+
         self.is_causal = True

     def _init_bias(self, max_positions, device=None):
@@ -155,7 +169,32 @@
                 )
             else:
                 raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
+            
+    def _get_D1(self, sequence_length):  # gamma_h**n for head h, position n; shape (1, heads, sequence_length, 1). np.log: torch.log rejects Python floats
+        D = ((1 - torch.exp(torch.linspace(np.log(1/32), np.log(1/512), self.num_attention_heads))).view(self.num_attention_heads, 1, 1) ** (torch.arange(sequence_length).unsqueeze(1))).float().unsqueeze(0)

+        return nn.Parameter(D, requires_grad=False)
+    
+    def _get_D2(self, sequence_length):  # reciprocal of _get_D1: gamma_h**-n, same shape; query*D1[n] and key*D2[m] give gamma**(n-m)
+        D = 1/((1 - torch.exp(torch.linspace(np.log(1/32), np.log(1/512), self.num_attention_heads))).view(self.num_attention_heads, 1, 1) ** (torch.arange(sequence_length).unsqueeze(1))).float().unsqueeze(0)
+
+        return nn.Parameter(D, requires_grad=False)
+
+    def _get_mask(self, sequence_length):  # lower-triangular (n >= m) causal mask repeated per head; shape (heads, L, L)
+        n = torch.arange(sequence_length).unsqueeze(1)
+        m = torch.arange(sequence_length).unsqueeze(0)
+
+        M = torch.ones(self.num_attention_heads).view(self.num_attention_heads, 1, 1)*(n >= m).float() 
+
+        return M
+
+    def _build_projection(self):  # frozen orthonormal DCT matrix times a diagonal of standard-normal samples (icdf of uniforms); (head_size, head_size)
+        icdf_w = torch.distributions.Normal(0, 1).icdf(torch.diag_embed(torch.diag(torch.rand(self.head_size, self.head_size))))
+        icdf_w = torch.where(torch.isinf(icdf_w), torch.full_like(icdf_w, 0), icdf_w)
+        C = dct(np.eye(self.head_size, self.head_size), axis=0,norm='ortho')
+        C = torch.from_numpy(C).type(torch.FloatTensor)
+        return nn.Parameter((C @ icdf_w).contiguous(), requires_grad=False)
+
     def forward(
         self,
         hidden_states: torch.FloatTensor,
@@ -212,7 +251,8 @@

         # Reshape outputs
         attn_output = self._merge_heads(attn_output, self.num_attention_heads, self.head_size)
-        attn_output = self.dense(attn_output)
+        attn_output = self.group_norm(attn_output.reshape(-1, self.v_dim)).reshape(attn_output.shape)  # flatten to (tokens, v_dim) so GroupNorm normalizes each token's channels, then restore the shape
+        attn_output = self.dense(self.swish(hidden_states @ self.W_G) * attn_output)  # SiLU gate computed from the layer input, applied elementwise before the output projection

         outputs = (attn_output, present)
         if output_attentions:
@@ -256,36 +296,16 @@
             self._init_bias(key_length, device=key.device)
         causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]

-        query = query.view(batch_size * num_attention_heads, query_length, attn_head_size)
-        key = key.view(batch_size * num_attention_heads, key_length, attn_head_size)
-        attn_scores = torch.zeros(
-            batch_size * num_attention_heads,
-            query_length,
-            key_length,
-            dtype=query.dtype,
-            device=key.device,
-        )
-        attn_scores = torch.baddbmm(
-            attn_scores,
-            query,
-            key.transpose(1, 2),
-            beta=1.0,
-            alpha=self.norm_factor,
-        )
-        attn_scores = attn_scores.view(batch_size, num_attention_heads, query_length, key_length)
+        query = nn.functional.softmax(query@self.proj_matrix, dim=-1)  # feature map: softmax over DCT-projected features (replaces softmax over keys)
+        key =  nn.functional.softmax(key@self.proj_matrix, dim=-1)
+        query = query*self.D1[:,:,key_length - query_length:key_length,:]  # queries sit at absolute positions key_length-query_length..key_length-1 (offset is non-zero with a KV cache)
+        key = key*self.D2[:,:,:key_length,:]  # keys at positions 0..key_length-1; D1[n]*D2[m] yields gamma**(n-m)

-        mask_value = torch.finfo(attn_scores.dtype).min
-        # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
-        # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
-        mask_value = torch.tensor(mask_value, dtype=attn_scores.dtype).to(attn_scores.device)
-        attn_scores = torch.where(causal_mask, attn_scores, mask_value)
+        attn_scores = torch.matmul(query, key.transpose(2, 3))  # linear attention: no norm_factor scaling, no softmax over keys

-        if attention_mask is not None:
-            # Apply the attention mask
-            attn_scores = attn_scores + attention_mask
+        attn_scores = attn_scores * self.mask[:,:,key_length - query_length:key_length,:key_length].to(attn_scores.device,dtype=attn_scores.dtype)  # same row offset as causal_mask above; NOTE(review): attention_mask (padding) is no longer applied -- confirm

-        attn_weights = nn.functional.softmax(attn_scores, dim=-1)
-        attn_weights = attn_weights.to(value.dtype)
+        attn_weights = attn_scores.to(value.dtype)  # scores are used directly as weights; no row normalization

         # Mask heads if we want to
         if head_mask is not None:
@@ -667,14 +687,14 @@

 class GPTNeoXLayer(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config, gamma=None):  # gamma: per-layer decay supplied by GPTNeoXModel and forwarded to the attention class; defaulted so GPTNeoXLayer(config) still works
         super().__init__()
         self.use_parallel_residual = config.use_parallel_residual
         self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         self.post_attention_dropout = nn.Dropout(config.hidden_dropout)
         self.post_mlp_dropout = nn.Dropout(config.hidden_dropout)
-        self.attention = GPT_NEOX_ATTENTION_CLASSES[config._attn_implementation](config)
+        self.attention = GPT_NEOX_ATTENTION_CLASSES[config._attn_implementation](config, gamma)
         self.mlp = GPTNeoXMLP(config)

     def forward(
@@ -787,7 +807,8 @@

         self.embed_in = nn.Embedding(config.vocab_size, config.hidden_size)
         self.emb_dropout = nn.Dropout(config.hidden_dropout)
-        self.layers = nn.ModuleList([GPTNeoXLayer(config) for _ in range(config.num_hidden_layers)])
+        self.gammas = (1 - torch.exp(torch.linspace(np.log(1/32), np.log(1/512), config.num_hidden_layers))).detach().cpu().tolist()  # per-layer decays from 1-1/32 to 1-1/512; np.log because torch.log rejects Python floats
+        self.layers = nn.ModuleList([GPTNeoXLayer(config, gamma) for gamma in self.gammas])
         self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"

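For the DCT, one could refer to the PyTorch implementation at https://github.com/zh217/torch-dct/blob/master/torch_dct/_dct.py, which would remove the need for the `scipy.fft.dct` dependency introduced above.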