ymcui / Chinese-LLaMA-Alpaca-2

中文LLaMA-2 & Alpaca-2大模型二期项目 + 64K超长上下文模型 (Chinese LLaMA-2 & Alpaca-2 LLMs with 64K long context models)
Apache License 2.0

Three length-extrapolation methods yield exactly identical answers? #455

Closed: IT-five closed this issue 10 months ago

IT-five commented 10 months ago

Pre-submission checklist

Issue type

Model inference

Base model

Others

Operating system

Linux

Detailed description of the issue

import torch


class RotaryEmbedding(torch.nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()
        # Standard RoPE: one inverse frequency per pair of head dimensions.
        self.inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))

        self.max_seq_len_cached = max_position_embeddings
        t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=torch.float32)
        freqs = torch.outer(t, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        self.cos_cached = emb.cos()[None, None, :, :].to(torch.float32)
        self.sin_cached = emb.sin()[None, None, :, :].to(torch.float32)
    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
        # print(f"当前句子长度:{seq_len},最大长度要求:{self.max_seq_len_cached}")
        if seq_len > self.max_seq_len_cached:

            print(f"seq_len:{seq_len}对比max_seq_len_cached:{self.max_seq_len_cached},此时直接外推")

            self.max_seq_len_cached = seq_len
            t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=torch.float32)
            freqs = torch.outer(t, self.inv_freq)
            emb = torch.cat((freqs, freqs), dim=-1)
            self.cos_cached = emb.cos()[None, None, :, :].to(torch.float32).to(x.device)
            self.sin_cached = emb.sin()[None, None, :, :].to(torch.float32).to(x.device)
        elif self.cos_cached.device != x.device:
            self.cos_cached = self.cos_cached.to(x.device)
            self.sin_cached = self.sin_cached.to(x.device)  
        return (
            self.cos_cached[:, :, :seq_len, ...],
            self.sin_cached[:, :, :seq_len, ...],
        )

class LinearScalingRotaryEmbedding(torch.nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()
        self.inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
        self.max_position_embeddings = max_position_embeddings  # original trained context length
        self.max_seq_len_cached = max_position_embeddings
        t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=torch.float32)
        freqs = torch.outer(t, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        self.cos_cached = emb.cos()[None, None, :, :].to(torch.float32)
        self.sin_cached = emb.sin()[None, None, :, :].to(torch.float32)
    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
        if seq_len > self.max_seq_len_cached:
            print(f"seq_len {seq_len} > max_seq_len_cached {self.max_seq_len_cached}: linear interpolation")
            self.max_seq_len_cached = seq_len
            t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=torch.float32)
            # Position interpolation: compress positions back into the trained range.
            # The scale must come from the original max_position_embeddings; deriving it
            # from the freshly updated max_seq_len_cached always yields 1.0 and silently
            # degenerates into direct extrapolation.
            t = t * (self.max_position_embeddings / seq_len)
            freqs = torch.outer(t, self.inv_freq)
            emb = torch.cat((freqs, freqs), dim=-1)
            self.cos_cached = emb.cos()[None, None, :, :].to(torch.float32).to(x.device)
            self.sin_cached = emb.sin()[None, None, :, :].to(torch.float32).to(x.device)
        elif self.cos_cached.device != x.device:
            self.cos_cached = self.cos_cached.to(x.device)
            self.sin_cached = self.sin_cached.to(x.device)  
        return (
            self.cos_cached[:, :, :seq_len, ...],
            self.sin_cached[:, :, :seq_len, ...],
        )

class NTKLinearScalingRotaryEmbedding(torch.nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=2.0):
        super().__init__()
        # Store the hyperparameters; forward() needs them to recompute the base.
        self.dim = dim
        self.base = base
        self.scaling_factor = scaling_factor
        self.max_position_embeddings = max_position_embeddings
        self.inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
        self.max_seq_len_cached = max_position_embeddings
        t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=torch.float32)
        freqs = torch.outer(t, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        self.cos_cached = emb.cos()[None, None, :, :].to(torch.float32)
        self.sin_cached = emb.sin()[None, None, :, :].to(torch.float32)
    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
        if seq_len > self.max_seq_len_cached:
            print(f"seq_len {seq_len} > max_seq_len_cached {self.max_seq_len_cached}: NTK scaling")
            self.max_seq_len_cached = seq_len
            # NTK-aware scaling rescales the RoPE base instead of the positions.
            base = self.base * (
                (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
            ) ** (self.dim / (self.dim - 2))
            self.inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(self.inv_freq.device) / self.dim))
            t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=torch.float32)
            freqs = torch.outer(t, self.inv_freq)
            emb = torch.cat((freqs, freqs), dim=-1)
            self.cos_cached = emb.cos()[None, None, :, :].to(torch.float32).to(x.device)
            self.sin_cached = emb.sin()[None, None, :, :].to(torch.float32).to(x.device)
        elif self.cos_cached.device != x.device:
            self.cos_cached = self.cos_cached.to(x.device)
            self.sin_cached = self.sin_cached.to(x.device)  
        return (
            self.cos_cached[:, :, :seq_len, ...],
            self.sin_cached[:, :, :seq_len, ...],
        )
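
For concreteness, the NTK branch above rescales the RoPE base rather than the positions. A small worked example with the default hyperparameters (seq_len=4096 here is a hypothetical value, not one of my actual runs):

dim, base, scaling_factor, max_pos, seq_len = 128, 10000, 2.0, 2048, 4096
new_base = base * ((scaling_factor * seq_len / max_pos) - (scaling_factor - 1)) ** (dim / (dim - 2))
print(new_base)  # ~30527; a larger base slows the rotation frequencies and stretches the usable context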

I implemented the three length-extrapolation methods above in modeling_baichuan and ran baichuan2-7b-chat with them, yet the answers on the longbench_e dataset came out exactly identical across all three. In pred.py I only removed the truncation logic; everything else, including the inference parameters, was left unchanged. What could be causing this?
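A minimal sanity check that should distinguish the three methods is to instantiate each class directly and compare the cos tables they return for a sequence longer than the trained context. A sketch (dim=128, max_pos=4096, and seq_len=8192 are placeholder values, not the actual experiment settings):

import torch

dim, max_pos, seq_len = 128, 4096, 8192   # placeholders; seq_len > max_pos so the scaling branches run
x = torch.zeros(1, 1, seq_len, dim)       # dummy tensor; forward() only uses it for device placement

rope = RotaryEmbedding(dim, max_position_embeddings=max_pos)
lin = LinearScalingRotaryEmbedding(dim, max_position_embeddings=max_pos)
ntk = NTKLinearScalingRotaryEmbedding(dim, max_position_embeddings=max_pos)

cos_rope, _ = rope(x, seq_len=seq_len)
cos_lin, _ = lin(x, seq_len=seq_len)
cos_ntk, _ = ntk(x, seq_len=seq_len)

# Any True here means that pair of methods feeds identical rotary tables
# into attention, so their generations cannot differ.
print(torch.allclose(cos_rope, cos_lin))
print(torch.allclose(cos_rope, cos_ntk))
print(torch.allclose(cos_lin, cos_ntk))

If all three comparisons come back True, the scaling branches are never being reached (for example because max_position_embeddings is already set to the evaluation length), and all three classes fall back to the same cached tables.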

Dependencies (required for code-related issues)

bitsandbytes                  0.41.1
open-clip-torch               2.20.0
peft                          0.5.0
pytorch-lightning             1.7.7
pytorch-metric-learning       2.3.0
pytorch-wavelets              1.3.0
pytorch-wpe                   0.0.1
pytorch3d                     0.7.4
rotary-embedding-torch        0.3.0
sentencepiece                 0.1.99
taming-transformers-rom1504   0.0.6
torch                         2.0.1+cu118
torch-complex                 0.4.3
torch-scatter                 2.1.1
torchaudio                    2.0.2+cu118
torchmetrics                  0.11.4
torchsummary                  1.5.1
torchvision                   0.15.2+cu118
transformers                  4.34.1
transformers-stream-generator 0.0.4

Logs or screenshots

(screenshot attached)

ymcui commented 10 months ago

For Baichuan-related usage questions, please ask in the corresponding project. Thanks.