AlibabaResearch / DAMO-ConvAI

DAMO-ConvAI: The official repository which contains the codebase for Alibaba DAMO Conversational AI.
MIT License
1.21k stars 186 forks source link

请问提供的代码是否可以直接用于Chatglm-6b模型,以及是否支持自定义数据集 #50

Closed Siegfried-qgf closed 1 year ago

Yangjiaxi commented 1 year ago

你好,deep-thinking 依赖 past_key_values 与 use_cache 这两个 forward 参数,因此支持这两个参数的模型都可以轻松使用( BLOOM系列模型的 key / value 矩阵形状定义与 GPT-series 稍有不同,需要做点修改)。对于 GLM 系列模型,我们目前没有这方面的实验,不过据了解新公布的 ChatGLM-2 换用了 decoder-only 结构,理应可以支持。

对于数据集,如果是分类与多项选择,可以参考 tasks/ 文件夹里的实现;如果是生成式任务,还请你修改 dataset class 与主文件中的结果生成部分。

ke-01 commented 1 year ago

请问论文里面的元梯度,即论文公式(2)中的B部分,每一层的元梯度,是这里的past_key_values吗?比如GPT2中的12层,每一层的元梯度,是对应输入给GPT2后得到的past_key_values的12个元素吗?

Yangjiaxi commented 1 year ago

请问论文里面的元梯度,即论文公式(2)中的B部分,每一层的元梯度,是这里的past_key_values吗?比如GPT2中的12层,每一层的元梯度,是对应输入给GPT2后得到的past_key_values的12个元素吗?

@ke-01 你好,论文公式(2)中的元梯度与代码实现中的past_key_values 的关系我来解释一下:

ke-01 commented 1 year ago

请问论文里面的元梯度,即论文公式(2)中的B部分,每一层的元梯度,是这里的past_key_values吗?比如GPT2中的12层,每一层的元梯度,是对应输入给GPT2后得到的past_key_values的12个元素吗?

@ke-01 你好,论文公式(2)中的元梯度与代码实现中的past_key_values 的关系我来解释一下:

  • 论文公式(2)B部分描述的是示例样本对待测样本的贡献,这里的元梯度是一个抽象层面的概念,表示示例样本带来的信息量,我们把它看做元梯度
  • 在代码实现中,past_key_values出现在了 (1) 使用 use_cache=True将Attention模块的 K,V取回;(2) 在 deep-thinking 阶段将混合(update过程)后的 K~,V~ 以past_key_values参数的形式送入模型;(3) 在 inference 阶段,优化了 T 次的 K~T,V~T 同样以past_key_values参数的形式送入模型
  • 因此文中的元梯度并不是模型输出past_key_values,而是我们手动混合后,送入模型的past_key_values

@Yangjiaxi 非常感谢您的回答! 但我还想问您几个小问题 如您所说,"文中的元梯度并不是模型输出的past_key_values,而是我们手动混合后,送入模型的past_key_values" 我理解的在 deep-thinking 阶段的混合为cur_kv.append([layer_k[:, :, -L:, :], layer_v[:, :, -L:, :]]) # kv @ (new_ctx)https://github.com/AlibabaResearch/DAMO-ConvAI/blob/2a2830d91c9713979ec5fde9daa510ffc6c4b446/deep-thinking/models/meta_optimizer.py#L66元梯度即混合后的cur_kv,那么 1.在第一轮update中,输入给模型的past_key_values为None时,元梯度是否就是模型输出的past_key_values? 2.第二轮以及之后的update,为什么输入给模型的是2*len(ids)的mask,取后len(ids)的k,v作为cur_kv? 3.past_key_values是两部分组成,是否可以理解为K,V矩阵即代表了元梯度,并不需要像公式(2)B部分那样将两部分进行相乘?

Yangjiaxi commented 1 year ago

@ke-01 很好的问题!感谢您对我们工作的细致关注!

  1. 是的。第一轮输入给模型的 past_key_valuesNone,这时模型输出的 past_key_values 如 68-69行所示,不走update,直接返回,即元梯度就是输出的 past_key_values
  2. attention_mask的长度是 transformers 模型要求的,例如可见 OPT.OPTDecoder.forward源码;后 len(ids) 长度的 $K, V$ 对应的是 这一轮 的文本,因此可以写出 $\tilde{K} t= \texttt{upd}(\tilde{K}{t-1}, \cdots)$,这实际上 1) 模拟拼接多次 context 以达到效果提升 2) 模拟的拼接的同时规避了超出input window的问题 3) 可以看做一阶马尔可夫链,对简化问题有很大的帮助
  3. 是的。元梯度是一种抽象的描述方法,具体就是靠操纵 $K, V$ 矩阵实现的,正如文中 3.1 节描述
ke-01 commented 1 year ago

@Yangjiaxi 感谢您详细的回答!您的解释帮助我更好地理解了这些问题!🤗