OpenBMB / MiniCPM

MiniCPM-2B: An end-side LLM outperforming Llama2-13B.

[Bug]: is_causal in the MiniCPMAttention class has no effect #90

Closed bokesyo closed 1 month ago

bokesyo commented 5 months ago

Is there an existing issue for this?

Describe the bug

In https://huggingface.co/openbmb/MiniCPM-2B-sft-bf16/blob/main/modeling_minicpm.py there is:

```python
class MiniCPMAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: MiniCPMConfig, layer_idx: Optional[int] = None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        if layer_idx is None:
            logger.warning_once(
                f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
                "lead to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )

        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.is_causal = True

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )

        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
        self._init_rope()

    def _init_rope(self):
        if self.config.rope_scaling is None:
            self.rotary_emb = MiniCPMRotaryEmbedding(
                self.head_dim,
                max_position_embeddings=self.max_position_embeddings,
                base=self.rope_theta,
            )
        else:
            scaling_type = self.config.rope_scaling["type"]
            scaling_factor = self.config.rope_scaling["factor"]
            if scaling_type == "linear":
                self.rotary_emb = MiniCPMLinearScalingRotaryEmbedding(
                    self.head_dim,
                    max_position_embeddings=self.max_position_embeddings,
                    scaling_factor=scaling_factor,
                    base=self.rope_theta,
                )
            elif scaling_type == "dynamic":
                self.rotary_emb = MiniCPMDynamicNTKScalingRotaryEmbedding(
                    self.head_dim,
                    max_position_embeddings=self.max_position_embeddings,
                    scaling_factor=scaling_factor,
                    base=self.rope_theta,
                )
            else:
                raise ValueError(f"Unknown RoPE scaling type {scaling_type}")

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        if "padding_mask" in kwargs:
            warnings.warn(
                "Passing `padding_mask` is deprecated and will be removed in v4.37. "
                "Please make sure use `attention_mask` instead.`"
            )

        bsz, q_len, _ = hidden_states.size()

        if self.config.pretraining_tp > 1:
            key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
            query_slices = self.q_proj.weight.split(
                (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
            )
            key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
            value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)

            query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
            query_states = torch.cat(query_states, dim=-1)

            key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
            key_states = torch.cat(key_states, dim=-1)

            value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
            value_states = torch.cat(value_states, dim=-1)

        else:
            query_states = self.q_proj(hidden_states)
            key_states = self.k_proj(hidden_states)
            value_states = self.v_proj(hidden_states)

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            if self.layer_idx is None:
                raise ValueError(
                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                    "with a layer index."
                )
            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
        cos, sin = self.rotary_emb(value_states.to(torch.float32), seq_len=kv_seq_len)

        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights + attention_mask

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2).contiguous()

        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

        if self.config.pretraining_tp > 1:
            attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
            o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
            attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
        else:
            attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value
```

As you can see, although `is_causal` is defined in `__init__`, the eager `forward` never reads it; the flag is only taken into account when flash_attention_2 is enabled. So this is a bug: without flash attention, there is no way to get bidirectional attention when it is wanted.
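To make the expectation concrete, here is a small self-contained toy of the eager attention math (not repository code) in which an `is_causal` flag actually gates the additive causal mask. This is the behavior one would expect the flag to control in `MiniCPMAttention`, whereas the quoted `forward` only ever adds whatever `attention_mask` it is handed:

```python
# Toy illustration only, not code from modeling_minicpm.py.
import math
import torch

def toy_eager_attention(q, k, v, is_causal: bool):
    # q, k, v: (bsz, num_heads, seq_len, head_dim), matching the shapes in the forward above
    attn_weights = q @ k.transpose(2, 3) / math.sqrt(q.size(-1))
    if is_causal:
        seq_len = q.size(-2)
        # additive causal mask: -inf strictly above the diagonal, 0 elsewhere
        causal_mask = torch.full((seq_len, seq_len), float("-inf")).triu(1)
        attn_weights = attn_weights + causal_mask
    attn_weights = torch.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)
    return attn_weights @ v

q = k = v = torch.randn(1, 2, 4, 8)
causal_out = toy_eager_attention(q, k, v, is_causal=True)
bidir_out = toy_eager_attention(q, k, v, is_causal=False)
print(torch.allclose(causal_out, bidir_out))  # False: here the flag really changes the result
```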

To Reproduce

Use eager attention and set `is_causal=False`; a reproduction sketch follows.
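A minimal reproduction sketch, assuming the eager path can be selected with `attn_implementation="eager"` and that the decoder layers are reachable as `model.model.layers`, as in the Llama-style code this file follows:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "openbmb/MiniCPM-2B-sft-bf16"
tok = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
    attn_implementation="eager",  # force the eager MiniCPMAttention path
).eval()

inputs = tok("check whether is_causal has any effect", return_tensors="pt")

with torch.no_grad():
    baseline = model(**inputs).logits

# Flip the flag on every attention module; if the flag worked in the eager path,
# this would switch the model to bidirectional attention and change the outputs.
for layer in model.model.layers:
    layer.self_attn.is_causal = False

with torch.no_grad():
    flipped = model(**inputs).logits

print(torch.allclose(baseline, flipped))  # prints True: the flag is ignored in eager mode
```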

Expected behavior

Bidirectional attention should be produced, but in practice only causal attention is applied.

Screenshots

No response

Environment

- OS: [e.g. Ubuntu 18.04]
- Pytorch: [e.g. torch 2.0.0]
- CUDA: [e.g. CUDA 11.6]
- Device: [e.g. A100-SXM-80G]

Additional context

No response

LDLINGLINGLING commented 2 months ago

Thanks for your contribution. We just checked the code, and indeed `is_causal` in MiniCPM only takes effect with flash_attention.
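For contrast with the eager path quoted above, here is a paraphrased sketch of how the flash_attention_2 path consumes the flag; `flash_forward_sketch` is an illustrative stand-in, not a function from `modeling_minicpm.py`:

```python
from flash_attn import flash_attn_func  # requires the flash-attn package and a GPU

def flash_forward_sketch(query_states, key_states, value_states, is_causal, dropout=0.0):
    # query/key/value: (bsz, seq_len, num_heads, head_dim), as expected by flash_attn_func.
    # The flag is forwarded to the kernel; nothing analogous happens in the eager forward.
    causal = is_causal and query_states.shape[1] != 1  # single-token decode needs no mask
    return flash_attn_func(query_states, key_states, value_states, dropout, causal=causal)
```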