sepcnt opened this issue 3 months ago
Here is a diff against https://github.com/huggingface/transformers/commit/fed27ffc7ec62837dca9bbfc83442eb3678ee026:
```diff
diff --git a/modeling/pythia-1B-dijiang/modeling_gpt_neox.py b/modeling/pythia-1B-dijiang/modeling_gpt_neox.py
--- a/modeling/pythia-1B-dijiang/modeling_gpt_neox.py
+++ b/modeling/pythia-1B-dijiang/modeling_gpt_neox.py
@@ -40,6 +40,9 @@
 from ...utils import is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10, logging
 from .configuration_gpt_neox import GPTNeoXConfig
 
+import numpy as np
+from scipy.fft import dct
+
 
 if is_flash_attn_2_available():
     from flash_attn import flash_attn_func, flash_attn_varlen_func
@@ -98,7 +101,7 @@
 class GPTNeoXAttention(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config, gamma):
         super().__init__()
         self.config = config
         self.num_attention_heads = config.num_attention_heads
@@ -118,6 +121,17 @@
         self.query_key_value = nn.Linear(config.hidden_size, 3 * config.hidden_size, bias=config.attention_bias)
         self.dense = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attention_bias)
         self.attention_dropout = nn.Dropout(config.attention_dropout)
+
+        self.proj_matrix = self._build_projection()
+        self.v_dim = config.hidden_size
+        self.W_G = nn.Parameter(torch.randn(self.hidden_size, self.v_dim) / self.hidden_size)
+        self.swish = nn.SiLU()
+        self.group_norm = nn.GroupNorm(self.head_size, self.v_dim)
+        nn.init.xavier_uniform_(self.W_G.data, gain=2 ** -2.5)
+        self.D1 = self._get_D1(self.config.max_position_embeddings)
+        self.D2 = self._get_D2(self.config.max_position_embeddings)
+        self.mask = self._get_mask(self.config.max_position_embeddings).unsqueeze(0)
+
         self.is_causal = True
 
     def _init_bias(self, max_positions, device=None):
@@ -155,7 +169,32 @@
             )
         else:
             raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
+
+    def _get_D1(self, sequence_length):
+        D = ((1 - torch.exp(torch.linspace(np.log(1 / 32), np.log(1 / 512), self.num_attention_heads))).view(self.num_attention_heads, 1, 1) ** (torch.arange(sequence_length).unsqueeze(1))).float().unsqueeze(0)
+        return nn.Parameter(D, requires_grad=False)
+
+    def _get_D2(self, sequence_length):
+        D = 1 / ((1 - torch.exp(torch.linspace(np.log(1 / 32), np.log(1 / 512), self.num_attention_heads))).view(self.num_attention_heads, 1, 1) ** (torch.arange(sequence_length).unsqueeze(1))).float().unsqueeze(0)
+
+        return nn.Parameter(D, requires_grad=False)
+
+    def _get_mask(self, sequence_length):
+        n = torch.arange(sequence_length).unsqueeze(1)
+        m = torch.arange(sequence_length).unsqueeze(0)
+
+        M = torch.ones(self.num_attention_heads).view(self.num_attention_heads, 1, 1) * (n >= m).float()
+
+        return M
+
+    def _build_projection(self):
+        icdf_w = torch.distributions.Normal(0, 1).icdf(torch.diag_embed(torch.diag(torch.rand(self.head_size, self.head_size))))
+        icdf_w = torch.where(torch.isinf(icdf_w), torch.full_like(icdf_w, 0), icdf_w)
+        C = dct(np.eye(self.head_size, self.head_size), axis=0, norm='ortho')
+        C = torch.from_numpy(C).type(torch.FloatTensor)
+        return nn.Parameter((C @ icdf_w).contiguous(), requires_grad=False)
+
     def forward(
         self,
         hidden_states: torch.FloatTensor,
@@ -212,7 +251,8 @@
         # Reshape outputs
         attn_output = self._merge_heads(attn_output, self.num_attention_heads, self.head_size)
-        attn_output = self.dense(attn_output)
+        attn_output = self.group_norm(attn_output.reshape(-1, self.v_dim)).reshape(attn_output.shape)
+        attn_output = self.dense(self.swish(hidden_states @ self.W_G) * attn_output)
 
         outputs = (attn_output, present)
         if output_attentions:
@@ -256,36 +296,16 @@
             self._init_bias(key_length, device=key.device)
         causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
 
-        query = query.view(batch_size * num_attention_heads, query_length, attn_head_size)
-        key = key.view(batch_size * num_attention_heads, key_length, attn_head_size)
-        attn_scores = torch.zeros(
-            batch_size * num_attention_heads,
-            query_length,
-            key_length,
-            dtype=query.dtype,
-            device=key.device,
-        )
-        attn_scores = torch.baddbmm(
-            attn_scores,
-            query,
-            key.transpose(1, 2),
-            beta=1.0,
-            alpha=self.norm_factor,
-        )
-        attn_scores = attn_scores.view(batch_size, num_attention_heads, query_length, key_length)
+        query = nn.functional.softmax(query @ self.proj_matrix, dim=-1)
+        key = nn.functional.softmax(key @ self.proj_matrix, dim=-1)
+        query = query * self.D1[:, :, :query_length, :]
+        key = key * self.D2[:, :, :key_length, :]
 
-        mask_value = torch.finfo(attn_scores.dtype).min
-        # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
-        # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
-        mask_value = torch.tensor(mask_value, dtype=attn_scores.dtype).to(attn_scores.device)
-        attn_scores = torch.where(causal_mask, attn_scores, mask_value)
+        attn_scores = torch.matmul(query, key.transpose(2, 3))
 
-        if attention_mask is not None:
-            # Apply the attention mask
-            attn_scores = attn_scores + attention_mask
+        attn_scores = attn_scores * self.mask[:, :, :query_length, :key_length].to(attn_scores.device, dtype=attn_scores.dtype)
 
-        attn_weights = nn.functional.softmax(attn_scores, dim=-1)
-        attn_weights = attn_weights.to(value.dtype)
+        attn_weights = attn_scores.to(value.dtype)
 
         # Mask heads if we want to
         if head_mask is not None:
@@ -667,14 +687,14 @@
 class GPTNeoXLayer(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config, gamma):
         super().__init__()
         self.use_parallel_residual = config.use_parallel_residual
         self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         self.post_attention_dropout = nn.Dropout(config.hidden_dropout)
         self.post_mlp_dropout = nn.Dropout(config.hidden_dropout)
-        self.attention = GPT_NEOX_ATTENTION_CLASSES[config._attn_implementation](config)
+        self.attention = GPT_NEOX_ATTENTION_CLASSES[config._attn_implementation](config, gamma)
         self.mlp = GPTNeoXMLP(config)
 
     def forward(
@@ -787,7 +807,8 @@
         self.embed_in = nn.Embedding(config.vocab_size, config.hidden_size)
         self.emb_dropout = nn.Dropout(config.hidden_dropout)
-        self.layers = nn.ModuleList([GPTNeoXLayer(config) for _ in range(config.num_hidden_layers)])
+        self.gammas = (1 - torch.exp(torch.linspace(np.log(1 / 32), np.log(1 / 512), config.num_hidden_layers))).detach().cpu().tolist()
+        self.layers = nn.ModuleList([GPTNeoXLayer(config, gamma) for gamma in self.gammas])
        self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
```
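For quick orientation, here is a minimal, self-contained sketch of the score path the patch introduces, with toy shapes: the DCT-weighted random projection from `_build_projection`, the softmax feature map applied to queries and keys, and the per-head `gamma**(n - m)` causal decay realized by `D1`, `D2`, and the multiplicative mask. It mirrors the names in the diff but is illustrative only, not the patched module; the shapes and seed are arbitrary, and it passes plain `np.log` floats as the `torch.linspace` endpoints.

```python
import numpy as np
import torch
from scipy.fft import dct

torch.manual_seed(0)
batch, heads, seq, head_size = 2, 4, 8, 16

# DCT-weighted random diagonal projection, as in _build_projection
icdf_w = torch.distributions.Normal(0, 1).icdf(torch.diag_embed(torch.diag(torch.rand(head_size, head_size))))
icdf_w = torch.where(torch.isinf(icdf_w), torch.zeros_like(icdf_w), icdf_w)
C = torch.from_numpy(dct(np.eye(head_size), axis=0, norm="ortho")).float()
proj = C @ icdf_w                          # (head_size, head_size)

# Per-head decay gamma in [1 - 1/32, 1 - 1/512], as in _get_D1 / _get_D2
gamma = 1 - torch.exp(torch.linspace(np.log(1 / 32), np.log(1 / 512), heads))
pos = torch.arange(seq).view(1, 1, seq, 1)
D1 = gamma.view(1, heads, 1, 1) ** pos     # gamma**n, scales queries
D2 = gamma.view(1, heads, 1, 1) ** (-pos)  # gamma**(-m), scales keys
mask = torch.tril(torch.ones(seq, seq))    # multiplicative causal mask

q, k, v = (torch.randn(batch, heads, seq, head_size) for _ in range(3))

# Kernelized scores: feature map + decay, then a plain matmul (no row-wise softmax over scores)
q = torch.softmax(q @ proj, dim=-1) * D1
k = torch.softmax(k @ proj, dim=-1) * D2
scores = (q @ k.transpose(2, 3)) * mask    # entry (n, m) carries gamma**(n - m) for m <= n
out = scores @ v
print(out.shape)                           # torch.Size([2, 4, 8, 16])
```

Because the softmax is applied to the projected queries and keys rather than to the score matrix, masking becomes multiplicative and the additive attention-mask path from the original `_attn` is dropped, which is what the hunk at `@@ -256,36 +296,16 @@` does.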
As for the DCT, one could refer to https://github.com/zh217/torch-dct/blob/master/torch_dct/_dct.py for a pure-PyTorch implementation.
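Alternatively, since `_build_projection` only needs the orthonormal DCT-II matrix that `scipy.fft.dct(np.eye(head_size), axis=0, norm='ortho')` produces, it is small enough to build directly in PyTorch and keep on-device. A minimal sketch (the helper name `dct_matrix` is just for illustration):

```python
import math
import torch

def dct_matrix(n: int, dtype: torch.dtype = torch.float32) -> torch.Tensor:
    """Orthonormal DCT-II matrix, equivalent to scipy.fft.dct(np.eye(n), axis=0, norm='ortho')."""
    k = torch.arange(n, dtype=dtype).unsqueeze(1)   # frequency index (rows)
    p = torch.arange(n, dtype=dtype).unsqueeze(0)   # position index (columns)
    C = torch.cos(math.pi * k * (2 * p + 1) / (2 * n))
    C[0] *= (1.0 / n) ** 0.5                        # orthonormal scaling for k = 0
    C[1:] *= (2.0 / n) ** 0.5                       # orthonormal scaling for k > 0
    return C

# Sanity check against the scipy version used in the patch:
# import numpy as np; from scipy.fft import dct
# assert np.allclose(dct_matrix(64).numpy(), dct(np.eye(64), axis=0, norm="ortho"), atol=1e-5)
```

Building the matrix this way would remove the scipy/numpy round trip from `_build_projection`.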