ymcui / Chinese-LLaMA-Alpaca

Chinese LLaMA & Alpaca large language models + local CPU/GPU training and deployment (Chinese LLaMA & Alpaca LLMs)
https://github.com/ymcui/Chinese-LLaMA-Alpaca/wiki
Apache License 2.0

Add patches for memory_efficient_attention and NTK scaling #743

Closed airaria closed 1 year ago

airaria commented 1 year ago

Description

What does this PR do?

This PR adds two patches that are applied before model initialization: `apply_attention_patch`, which enables memory-efficient attention (with an option to store keys/values in the KV cache before RoPE is applied), and `apply_ntk_scaling_patch`, which applies NTK scaling to the rotary position embeddings so the model can handle contexts longer than its original training length.

Usage

```python
alpha = 2.0  # a float, a string representing a float, or 'auto'
use_memory_efficient_attention = True  # True or False
store_kv_before_rope = False  # True or False

# The following code should be placed before model initialization
from patches import apply_attention_patch, apply_ntk_scaling_patch
apply_attention_patch(
    use_memory_efficient_attention=use_memory_efficient_attention,
    store_kv_before_rope=store_kv_before_rope
)
apply_ntk_scaling_patch(alpha=alpha)
```
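For reference, a minimal sketch of where these calls sit in an inference script; the model path is a placeholder, and the loading arguments are ordinary `transformers` usage rather than anything specific to this PR:

```python
import torch
from transformers import LlamaForCausalLM, LlamaTokenizer

from patches import apply_attention_patch, apply_ntk_scaling_patch

# Apply both patches before the model is instantiated so the patched
# attention / rotary-embedding code is in place when the weights are loaded.
apply_attention_patch(use_memory_efficient_attention=True, store_kv_before_rope=False)
apply_ntk_scaling_patch(alpha="auto")

# "path/to/chinese-alpaca" is a placeholder for a local or Hub model path.
tokenizer = LlamaTokenizer.from_pretrained("path/to/chinese-alpaca")
model = LlamaForCausalLM.from_pretrained(
    "path/to/chinese-alpaca",
    torch_dtype=torch.float16,
    device_map="auto",
)
```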

Parameters

- `alpha`: the NTK scaling factor; a float, a string representing a float, or `'auto'` to let the patch choose the factor automatically.
- `use_memory_efficient_attention`: whether to replace the standard attention computation with memory-efficient attention.
- `store_kv_before_rope`: whether to store keys/values in the KV cache before the rotary position embedding is applied.
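For context on `alpha`, the usual NTK-aware trick rescales the RoPE base frequency by the scaling factor; a short sketch (not necessarily the exact code in this PR):

```python
import torch

def ntk_scaled_inv_freq(dim: int, base: float = 10000.0, alpha: float = 2.0) -> torch.Tensor:
    # NTK-aware scaling enlarges the RoPE base so low-frequency dimensions are
    # stretched to cover longer contexts while high-frequency dimensions
    # (which encode local order) stay nearly unchanged.
    scaled_base = base * alpha ** (dim / (dim - 2))
    return 1.0 / (scaled_base ** (torch.arange(0, dim, 2).float() / dim))
```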

Advice

IT-five commented 10 months ago

I'd like to ask about the source code below: prompts longer than max_length are truncated, yet the NTK implementation only recomputes its cache when "if seq_len > self.max_seq_len_cached:". Doesn't that mean seq_len will never exceed self.max_seq_len_cached? If so, how does NTK extrapolate to longer contexts?

```python
if len(tokenized_prompt) > max_length:
    half = int(max_length / 2)
    # Keep the first and last `half` tokens and drop the middle of the prompt.
    prompt = (
        tokenizer.decode(tokenized_prompt[:half], skip_special_tokens=True)
        + tokenizer.decode(tokenized_prompt[-half:], skip_special_tokens=True)
    )
```
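For reference, the check quoted above comes from the rotary-embedding cos/sin cache; a transformers-style sketch of that pattern (class, attribute, and method names are assumptions, not the exact patched code):

```python
import torch

class RotaryEmbeddingSketch(torch.nn.Module):
    # Sketch of the cache pattern the question refers to, modeled on the
    # transformers LlamaRotaryEmbedding; details may differ from this PR.
    def __init__(self, dim, max_position_embeddings=2048, base=10000.0):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)
        self.max_seq_len_cached = 0
        self._set_cos_sin_cache(max_position_embeddings)

    def _set_cos_sin_cache(self, seq_len):
        # Rebuild the cos/sin tables up to seq_len and remember that length.
        self.max_seq_len_cached = seq_len
        t = torch.arange(seq_len, dtype=self.inv_freq.dtype)
        freqs = torch.outer(t, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos())
        self.register_buffer("sin_cached", emb.sin())

    def forward(self, seq_len):
        # The tables are rebuilt only when a sequence longer than anything seen
        # so far arrives; shorter sequences reuse the cached tables.
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len)
        return self.cos_cached[:seq_len], self.sin_cached[:seq_len]
```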