ztxz16 / fastllm

A pure C++ cross-platform LLM acceleration library, callable from Python; chatglm-6B-class models can reach 10000+ tokens/s on a single GPU; supports glm, llama, and moss base models and runs smoothly on mobile devices.
Apache License 2.0

The positional-encoding extension method of chatglm3-6b-32k seems to differ from chatglm2-6b-32k #374

Closed TylunasLi closed 9 months ago

TylunasLi commented 11 months ago

THUDM/chatglm2-6b-32k

class RotaryEmbedding(nn.Module):
    def forward_impl(
            self, seq_len: int, n_elem: int, dtype: torch.dtype, device: torch.device, base: int = 10000
    ):
        """Enhanced Transformer with Rotary Position Embedding.
        """
        # $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
        theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, dtype=dtype, device=device) / n_elem))

        # Create position indexes `[0, 1, ..., seq_len - 1]`
        seq_idx = torch.arange(seq_len, dtype=dtype, device=device) / self.rope_ratio

        # Calculate the product of position index and $\theta_i$
        idx_theta = torch.outer(seq_idx, theta).float()

        cache = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1)

        # this is to mimic the behaviour of complex32, else we will get different results
        if dtype in (torch.float16, torch.bfloat16, torch.int8):
            cache = cache.bfloat16() if dtype == torch.bfloat16 else cache.half()
        return cache

THUDM/chatglm3-6b-32k


class RotaryEmbedding(nn.Module):
    def forward_impl(
            self, seq_len: int, n_elem: int, dtype: torch.dtype, device: torch.device, base: int = 10000
    ):
        """Enhanced Transformer with Rotary Position Embedding.
        """
        # $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
        base = base * self.rope_ratio
        theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, dtype=torch.float, device=device) / n_elem))

        # Create position indexes `[0, 1, ..., seq_len - 1]`
        seq_idx = torch.arange(seq_len, dtype=torch.float, device=device)

        # Calculate the product of position index and $\theta_i$
        idx_theta = torch.outer(seq_idx, theta).float()

        cache = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1)

        # this is to mimic the behaviour of complex32, else we will get different results
        if dtype in (torch.float16, torch.bfloat16, torch.int8):
            cache = cache.bfloat16() if dtype == torch.bfloat16 else cache.half()
        return cache

In the positional-encoding extension, the place where rope_ratio is applied appears to differ between the two models (chatglm2-6b-32k divides the position indices by it, while chatglm3-6b-32k multiplies the RoPE base by it), so fastllm may need a corresponding modification. A sketch contrasting the two is shown below.
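For reference, a minimal sketch (not fastllm or THUDM code) contrasting the two placements of rope_ratio; the helper name and the rope_ratio value are illustrative assumptions, the real value comes from each model's config.json:

import torch

def rope_cache(seq_len: int, n_elem: int, rope_ratio: float, scale_base: bool) -> torch.Tensor:
    # Hypothetical helper: compute the RoPE angle table under the two schemes.
    base = 10000 * rope_ratio if scale_base else 10000        # chatglm3-6b-32k style: scale the base
    theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, dtype=torch.float) / n_elem))
    seq_idx = torch.arange(seq_len, dtype=torch.float)
    if not scale_base:
        seq_idx = seq_idx / rope_ratio                        # chatglm2-6b-32k style: scale the positions
    return torch.outer(seq_idx, theta)                        # angles later fed to cos/sin

glm2_style = rope_cache(seq_len=8, n_elem=64, rope_ratio=50.0, scale_base=False)
glm3_style = rope_cache(seq_len=8, n_elem=64, rope_ratio=50.0, scale_base=True)
print(torch.allclose(glm2_style, glm3_style))  # False: the caches differ

Because the two caches differ, loading chatglm3-6b-32k with the chatglm2-6b-32k logic would produce wrong rotary embeddings at every position.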

2496289471 commented 11 months ago

The chatglm3-6b-32k model can be converted to flm, but whether quantized or not, its output is incoherent. Below is a comparison between chatglm3 32k and 8k.

2496289471 commented 11 months ago

[screenshot: incoherent output from the converted chatglm3-6b-32k model]

2496289471 commented 11 months ago

The 8k model works correctly. [screenshot: normal output from chatglm3-6b (8k)]