ymcui / Chinese-LLaMA-Alpaca

Chinese LLaMA & Alpaca large language models + local CPU/GPU training and deployment (Chinese LLaMA & Alpaca LLMs)
https://github.com/ymcui/Chinese-LLaMA-Alpaca/wiki
Apache License 2.0

Extend context size without fine-tuning #705

Closed airaria closed 1 year ago

airaria commented 1 year ago

Description

Update: We find that the NTK method mentioned in this Reddit post outperforms Position Interpolation up to a context size of at least 6K. Thus, we replace the implementation of PI with the NTK method.

In addition, we use an empirical formula to set $\alpha$ adaptively based on the input size, so that hyperparameter tuning is avoided and the method can be applied to different context sizes.
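Concretely, writing $L$, $b$, and $d$ for the code's seq_len, base, and dim, the patch below sets $\alpha = L / 1024 - 1$ and rescales the RoPE base as $b' = b \cdot \alpha^{d/(d-2)}$, where $b = 10000$ is the original base and $d$ is the rotary dimension. For example, $L = 4096$ gives $\alpha = 3$.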

The following is the perplexity of Chinese-LLaMA-Plus-7B on a test set:

| Context size | 512 | 1024 | 2048 | 3072 | 4096 | 5120 | 6144 |
|---|---|---|---|---|---|---|---|
| baseline | 11.4 | 10.98 | 10.98 | 173.5 | - | - | - |
| Position Interpolation | 11.4 | 10.98 | 10.98 | 11.47 | 12.42 | 14.44 | 17.86 |
| Adaptive NTK (this PR) | 11.4 | 10.98 | 10.98 | 11.05 | 11.05 | 11.40 | 12.57 |

Even though Chinese-LLaMA-Plus-7B was trained with an input_length of 512, its context size can be extended to 5K~6K without significantly increasing the perplexity.

Users only need to add the following lines to the beginning of their Python code:

import torch
import transformers

# Keep a reference to the original __init__ so it can still be called.
old_init = transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.__init__

def adaptive_ntk_init(self, dim, max_position_embeddings=2048, base=10000, device=None):
    # Store dim and base so that forward() can recompute inv_freq with a scaled base.
    self.dim = dim
    self.base = base
    old_init(self, dim, max_position_embeddings, base, device)

def adaptive_ntk_forward(self, x, seq_len=None):
    if seq_len > self.max_seq_len_cached:
        # Input is longer than the cached length: recompute the rotary
        # embeddings with an adaptively scaled base (NTK-aware scaling).
        t = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype)
        dim = self.dim
        # Empirical formula: alpha grows with the input length.
        alpha = seq_len / 1024 - 1
        base = self.base * alpha ** (dim / (dim - 2))
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(x.device) / dim))

        freqs = torch.einsum("i,j->ij", t, inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
        cos_cached = emb.cos()[None, None, :, :]
        sin_cached = emb.sin()[None, None, :, :]
        return (
            cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
            sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype)
        )
    # Within the cached length: fall back to the precomputed cos/sin tables.
    return (
        self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
        self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype)
    )

transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.forward = adaptive_ntk_forward
transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.__init__ = adaptive_ntk_init
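
For reference, a minimal usage sketch: apply the patch above before constructing the model, then load and generate as usual with Transformers. The model path and prompt below are placeholders, and the example assumes a CUDA GPU and fp16 loading; adjust to your setup.

import torch
from transformers import LlamaForCausalLM, LlamaTokenizer

# The patch must already be applied at this point, so that
# LlamaRotaryEmbedding picks up adaptive_ntk_init / adaptive_ntk_forward.
model_path = "path/to/chinese-llama-plus-7b"  # placeholder path
tokenizer = LlamaTokenizer.from_pretrained(model_path)
model = LlamaForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16).cuda()
model.eval()

long_prompt = "..."  # a prompt longer than the training length of 512 tokens
inputs = tokenizer(long_prompt, return_tensors="pt").to(model.device)
with torch.no_grad():
    output_ids = model.generate(**inputs, max_new_tokens=128)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))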

We keep the old implementation below for others' reference.


implementation of Position Interpolation (deprecated)

Description

We implement Position Interpolation (proposed in the paper Extending Context Window of Large Language Models via Position Interpolation and in a blog post) for using LLaMA with Transformers.

We find that the method can be used out of the box even without training the model with a long context size. The following is the perplexity of Chinese-LLaMA-Plus-7B on a test set:

| Context size | 512 | 1024 | 2048 | 3072 | 4096 | 5120 |
|---|---|---|---|---|---|---|
| Perplexity | 11.4 | 11.0 | 11.0 | 11.5 | 12.4 | 15.6 |

Note that even though Chinese-LLaMA-Plus-7B was trained with an input_length of 512, its context window size can be extended to 4096 without significantly increasing the perplexity.

Users only need to add the following lines to the beginning of their Python code:

import torch
import transformers

def pi_forward(self, x, seq_len=None):
    if seq_len > self.max_seq_len_cached:  # seq_len > 2048
        print(f"Perform position interpolation for length {seq_len}")
        t = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype)
        # Linearly rescale positions so they fit into the original window.
        scale = self.max_seq_len_cached / seq_len
        t *= scale
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
        cos_cached = emb.cos()[None, None, :, :]
        sin_cached = emb.sin()[None, None, :, :]
        return (
            cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
            sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype)
        )
    # Within the original window: use the precomputed cos/sin tables.
    return (
        self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
        self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype)
    )

transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.forward = pi_forward

If seq_len <= 2048, the behavior is unchanged; if seq_len > 2048, Position Interpolation is performed and the context size is extended to seq_len.
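
As a quick illustration (not part of the patch), here is how pi_forward rescales positions for a 4096-token input when max_seq_len_cached is 2048:

seq_len = 4096
max_seq_len_cached = 2048              # the pretraining context window
scale = max_seq_len_cached / seq_len   # 0.5
# Position 4095 is mapped to 4095 * 0.5 = 2047.5, so every rotary angle
# stays within the range the model saw during pretraining.
print(scale, 4095 * scale)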

tkone2018 commented 1 year ago

@ymcui @airaria May I ask whether this can be used directly as-is?

xyfZzz commented 1 year ago

Can this only be used at the inference stage? Is NTK applied during fine-tuning?

airaria commented 1 year ago

> Can this only be used at the inference stage? Is NTK applied during fine-tuning?

In the currently released models, NTK was not applied during fine-tuning.

xyfZzz commented 1 year ago

> Can this only be used at the inference stage? Is NTK applied during fine-tuning?
>
> In the currently released models, NTK was not applied during fine-tuning.

Has it been added to the fine-tuning code yet? That way we could use NTK to fine-tune on our own data.