ztxz16 / fastllm

A pure C++ LLM acceleration library for all platforms, callable from Python. ChatGLM-6B-class models can reach 10,000+ tokens/s on a single GPU; supports GLM, LLaMA, and MOSS base models and runs smoothly on mobile devices.
Apache License 2.0

Request for Grouped Query Attention support #416

Closed · TylunasLi closed this issue 7 months ago

TylunasLi commented 8 months ago

Currently, newer and larger LLaMA-architecture models all use Grouped Query Attention to speed up inference, but fastllm does not support it yet. I tried the simple changes below; the shapes seem right, but the inference results are wrong. I hope the author can build on this to complete Grouped Query Attention support.
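
For context: with Grouped Query Attention, the Q projection still produces num_attention_heads heads, while the K and V projections only produce num_key_value_heads heads, and each group of num_attention_heads / num_key_value_heads query heads shares a single K/V head. A minimal standalone sketch of the broadcasting this implies (plain C++ on raw buffers, not the fastllm Data API; RepeatKV is a hypothetical helper, analogous to Hugging Face's repeat_kv):

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// Expand K or V from [kvHeads, seqLen, headDim] to [qHeads, seqLen, headDim]
// by repeating each head group = qHeads / kvHeads times, so that ordinary
// multi-head attention can be applied afterwards.
std::vector<float> RepeatKV(const std::vector<float> &kv,
                            int kvHeads, int qHeads, int seqLen, int headDim) {
    int group = qHeads / kvHeads;  // query heads sharing one K/V head
    std::vector<float> out((size_t)qHeads * seqLen * headDim);
    for (int h = 0; h < qHeads; h++) {
        const float *src = kv.data() + (size_t)(h / group) * seqLen * headDim;
        std::copy(src, src + (size_t)seqLen * headDim,
                  out.begin() + (size_t)h * seqLen * headDim);
    }
    return out;
}
```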

diff --git a/include/models/llama.h b/include/models/llama.h
index d15c7f1..5cb18ea 100644
--- a/include/models/llama.h
+++ b/include/models/llama.h
@@ -15,6 +15,8 @@
     public:
         LlamaModel (); // constructor

+        virtual void InitParams(); // initialize parameter info
+
         // inference
         virtual int Forward(
                 const Data &inputIds,
@@ -65,6 +67,8 @@
         virtual std::string MakeInput(const std::string &history, int round, const std::string &input); // build the prompt from the history and the current input

         virtual std::string MakeHistory(const std::string &history, int round, const std::string &input, const std::string &output); // update the history with the current reply
+
+        int num_key_value_heads = num_attention_heads;
     };
 }

diff --git a/src/models/llama.cpp b/src/models/llama.cpp
index bf7f529..c35cc49 100644
--- a/src/models/llama.cpp
+++ b/src/models/llama.cpp
@@ -81,6 +81,15 @@
         weight.embeddingNames.insert("model.embed_tokens.weight");
     }

+    void LlamaModel::InitParams() {
+        basellm::InitParams();
+        num_key_value_heads = num_attention_heads;
+        if (this->weight.dicts.find("num_key_value_heads") != this->weight.dicts.end()) {
+            num_key_value_heads = atoi(this->weight.dicts["num_key_value_heads"].c_str());
+        }
+        head_dim = embed_dim / num_attention_heads;
+    }
+
     int LlamaModel::Forward(const fastllm::Data &inputIds, const fastllm::Data &attentionMask,
                             const fastllm::Data &positionIds, std::vector<std::pair<Data, Data>> &pastKeyValues,
                             const GenerationConfig &generationConfig, const LastTokensManager &lastTokens,
@@ -124,7 +133,7 @@
                 Linear(attenInput, weight[vWeightName], Data(), v);
             }

-            std::vector <int> qkvSize = {bsz, seqlen, num_attention_heads, -1};
+            std::vector <int> qkvSize = {bsz, seqlen, -1, head_dim};
             q.Reshape(qkvSize);
             k.Reshape(qkvSize);
             v.Reshape(qkvSize);
@@ -134,8 +143,9 @@
                 fastllm::LlamaRotatePosition2D(k, positionIds, sinData, cosData, rotary_dim);
             }

-            qkvSize = {bsz * seqlen, num_attention_heads, -1};
+            qkvSize = {-1, num_key_value_heads, head_dim};
             q.Reshape(qkvSize);
+            qkvSize = {bsz * seqlen, -1, head_dim};
             k.Reshape(qkvSize);
             v.Reshape(qkvSize);

@@ -286,7 +296,7 @@
                 Linear(attenInput, weight[vWeightName], Data(), v);
             }

-            std::vector <int> qkvSize = {bsz, seqlen, num_attention_heads, -1};
+            std::vector<int> qkvSize = {bsz, seqlen, -1, head_dim};
             q.Reshape(qkvSize);
             k.Reshape(qkvSize);
             v.Reshape(qkvSize);
@@ -300,8 +310,9 @@
             PermuteSelf(k, {0, 2, 1, 3});
             PermuteSelf(v, {0, 2, 1, 3});

-            qkvSize = {bsz * num_attention_heads, seqlen, -1};
+            qkvSize = {bsz * num_key_value_heads, -1, head_dim};
             q.Reshape(qkvSize);
+            qkvSize = {bsz * num_key_value_heads, seqlen, -1};
             k.Reshape(qkvSize);
             v.Reshape(qkvSize);

@@ -472,7 +483,7 @@
             for (int b = 0; b < batch; b++) {
                 auto &q = curQs[b], &k = curKs[b], &v = curVs[b];

-                std::vector<int> qkvSize = {bsz, seqLens[b], num_attention_heads, -1};
+                std::vector<int> qkvSize = {bsz, seqLens[b], -1, head_dim};
                 q.Reshape(qkvSize);
                 k.Reshape(qkvSize);
                 v.Reshape(qkvSize);
@@ -486,8 +497,9 @@
                 PermuteSelf(k, {0, 2, 1, 3});
                 PermuteSelf(v, {0, 2, 1, 3});

-                qkvSize = {bsz * num_attention_heads, seqLens[b], -1};
+                qkvSize = {bsz * num_key_value_heads, -1, head_dim};
                 q.Reshape(qkvSize);
+                qkvSize = {bsz * num_key_value_heads, seqLens[b], -1};
                 k.Reshape(qkvSize);
                 v.Reshape(qkvSize);
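
One likely reason the reshapes alone give wrong results: after this patch q still carries num_attention_heads heads while k and v carry num_key_value_heads, so any batched MatMul/Attention that pairs q head i with k/v head i mixes up the groups; query head h has to read K/V head h / group. A naive standalone reference for the intended pairing (assumed contiguous layout [heads, len, headDim]; not fastllm's attention kernels):

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Reference grouped-query attention: query head h attends over K/V head
// h / group, where group = qHeads / kvHeads. No masking, float32 only.
void GroupedQueryAttention(const float *q, const float *k, const float *v,
                           float *out, int qHeads, int kvHeads,
                           int qLen, int kvLen, int headDim) {
    int group = qHeads / kvHeads;
    float scale = 1.0f / std::sqrt((float)headDim);
    std::vector<float> scores(kvLen);
    for (int h = 0; h < qHeads; h++) {
        const float *kh = k + (size_t)(h / group) * kvLen * headDim;  // shared K head
        const float *vh = v + (size_t)(h / group) * kvLen * headDim;  // shared V head
        for (int i = 0; i < qLen; i++) {
            const float *qi = q + ((size_t)h * qLen + i) * headDim;
            // scaled dot-product scores, with max subtraction for a stable softmax
            float maxv = -1e30f, sum = 0.0f;
            for (int j = 0; j < kvLen; j++) {
                float s = 0.0f;
                for (int d = 0; d < headDim; d++)
                    s += qi[d] * kh[(size_t)j * headDim + d];
                scores[j] = s * scale;
                maxv = std::max(maxv, scores[j]);
            }
            for (int j = 0; j < kvLen; j++) {
                scores[j] = std::exp(scores[j] - maxv);
                sum += scores[j];
            }
            // weighted sum of the shared V head
            float *oi = out + ((size_t)h * qLen + i) * headDim;
            for (int d = 0; d < headDim; d++) oi[d] = 0.0f;
            for (int j = 0; j < kvLen; j++) {
                float w = scores[j] / sum;
                for (int d = 0; d < headDim; d++)
                    oi[d] += w * vh[(size_t)j * headDim + d];
            }
        }
    }
}
```

Inside fastllm this could be handled either by materializing the repeated K/V before the existing MatMul (simple, but inflates the KV cache by the group factor) or by a group-aware kernel that indexes k/v with h / group as above, which keeps the cache at num_key_value_heads.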