Closed cabisarri closed 3 months ago
你好,感谢对我们工作的关注!
实际上作为入参的 exemplar_attn_kv
还没引入 batch>1 的概念,咱们来捋一下它产生的顺序。
(1) L148-L155产生了示例样本的字符串,多个示例样本会直接进行字符串拼接 (2) L172对示例样本字符串做tokenization,会得到形状 [L_ex] 的 LongTensor, L_ex为示例样本的token 数量 (3) L184,得到计算后的 kv-matrices,形状是 num_layers 2(K and V) [1, num_heads, L_ex, num_hidden_per_head],注意这里dim0=1,即batch=1,没有引入batch>1的概念 (4) L48,对 kv-matrices的batch维度进行expand,来匹配batch_input_ids的batchsize(=bs,见L36),具体来说 layer_k.expand((bs, -1, -1, -1)) 从 [1, num_heads, L_ex, num_hidden_per_head]变成了 [bs, num_heads, L_ex, num_hidden_per_head],此时 kv 的 batch 就匹配上了 input (5) L50-L54将 expanded kv 送进模型
在结合 example 对应的 kv 和 mask, 这里只能让
exemplar_attn_kv
的 batch_size 等同于batched_choices_input
的 batch size , 要不然没法生成merged_attn_mask
和expand_exemplar_attn_kv,
去及推理获得 batched_logits (代码如下)。基于实际场景,有可能是 80 个 example 用于做更好的 ICL, 而推理过程中 可能实际是 batch_size = 8,请问你们是怎么做的?