AlibabaResearch / DAMO-ConvAI

DAMO-ConvAI: The official repository which contains the codebase for Alibaba DAMO Conversational AI.
MIT License

A follow-up question about Deep-Thinking. #126

Closed cabisarri closed 3 months ago

cabisarri commented 4 months ago

When combining the kv and mask of the examples, this only works if the batch_size of exemplar_attn_kv equals the batch size of batched_choices_input; otherwise merged_attn_mask and expand_exemplar_attn_kv cannot be built to run inference and obtain batched_logits (code below). In a realistic setting there might be 80 examples used for better ICL, while the actual inference batch_size is 8. How do you handle this?

import torch
import torch.nn.functional as F


@torch.no_grad()
def do_infer_probs(exemplar_attn_kv, exemplar_attn_mask, batched_choices_input):
    batched_choices_logprobs = []
    for batched_one_choice_input in batched_choices_input:
        batch_input_ids, batch_attention_mask, batch_choice_start, batch_choice_end = batched_one_choice_input
        bs = len(batch_input_ids)

        merged_attn_mask = torch.cat((exemplar_attn_mask.expand(bs, -1), batch_attention_mask), dim=1)
        if args.model_type == "bloom":
            # [B*#Heads, Length, Hidden]
            def _expand(t, target_size):
                _bs, _head, _len, _hidden = 1, *t.size()
                return t.reshape(_bs, _head, _len, _hidden).expand(target_size * _bs, -1, -1, -1).reshape(target_size * _bs * _head, _len, _hidden)

            expand_exemplar_attn_kv = [[_expand(layer_k, bs), _expand(layer_v, bs)] for layer_k, layer_v in exemplar_attn_kv]
        else:
            # [B, #Heads, Length, Hidden]
            expand_exemplar_attn_kv = [[layer_k.expand((bs, -1, -1, -1)), layer_v.expand((bs, -1, -1, -1))] for layer_k, layer_v in exemplar_attn_kv]

        batched_logits = model(
            input_ids=batch_input_ids,  # [B, L']
            attention_mask=merged_attn_mask,  # [B, L + L']
            past_key_values=expand_exemplar_attn_kv,  # num_layers * 2 * [B, num_heads, L, H]
        ).logits
        batched_output = F.log_softmax(batched_logits, dim=-1)  # [B, L', Vocab]

        batched_one_choice_logprobs = []
        for input_ids, choice_start, choice_end, lm_logprobs in zip(batch_input_ids, batch_choice_start, batch_choice_end, batched_output):
            choice_tokens = input_ids[choice_start:choice_end].unsqueeze(1)  # [L, 1]
            choice_logprobs = lm_logprobs[choice_start - 1 : choice_end - 1]  # [L, Vocab]

            extracted = torch.gather(choice_logprobs, -1, choice_tokens).squeeze(-1)

            choice_length = choice_end - choice_start
            lm_log_p = torch.sum(extracted).item()
            norm_lm_log_p = (lm_log_p / choice_length).item()

            choice_lm_info = {"lm_log_p": lm_log_p, "norm_lm_log_p": norm_lm_log_p}
            batched_one_choice_logprobs.append(choice_lm_info)
        batched_choices_logprobs.append(batched_one_choice_logprobs)
    return batched_choices_logprobs
Yangjiaxi commented 4 months ago

Hi, thanks for your interest in our work!

Actually, the exemplar_attn_kv passed in as an argument has no notion of batch > 1 yet. Let's walk through the order in which it is produced.

(1) L148-L155 produce the exemplar string; multiple exemplars are concatenated directly as strings.
(2) L172 tokenizes the exemplar string, yielding a LongTensor of shape [L_ex], where L_ex is the number of exemplar tokens.
(3) L184 computes the kv-matrices, with shape num_layers * 2 (K and V) * [1, num_heads, L_ex, num_hidden_per_head]. Note that dim0 = 1 here, i.e. batch = 1; no batch > 1 has been introduced.
(4) L48 expands the batch dimension of the kv-matrices to match the batch size of batch_input_ids (= bs, see L36). Concretely, layer_k.expand((bs, -1, -1, -1)) turns [1, num_heads, L_ex, num_hidden_per_head] into [bs, num_heads, L_ex, num_hidden_per_head], at which point the kv batch matches the input.
(5) L50-L54 feed the expanded kv into the model; a minimal sketch of this pipeline is given below.
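
To make this concrete, here is a minimal sketch of steps (1)-(5), assuming a Hugging Face transformers causal LM. The model name, exemplar_text, and the query batch size are illustrative placeholders, not the repository's actual code, and depending on the transformers version past_key_values may be returned as a Cache object rather than a tuple of tensors.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2").eval()

# (1)-(2) Concatenate all exemplars into a single string and tokenize once:
# the batch dimension is 1 no matter how many exemplars (e.g. 80) are joined.
exemplar_text = "example 1 ...\nexample 2 ...\n"
exemplar_ids = tokenizer(exemplar_text, return_tensors="pt").input_ids  # [1, L_ex]

# (3) One forward pass yields the exemplar kv cache with batch = 1:
# num_layers * 2 (K and V) * [1, num_heads, L_ex, num_hidden_per_head].
with torch.no_grad():
    exemplar_kv = model(exemplar_ids, use_cache=True).past_key_values

# (4) Expand (a view, not a copy) the batch dimension to the query batch size,
# so the exemplars are encoded only once, regardless of bs.
bs = 8
expanded_kv = [
    (layer_k.expand(bs, -1, -1, -1), layer_v.expand(bs, -1, -1, -1))
    for layer_k, layer_v in exemplar_kv
]
print(expanded_kv[0][0].shape)  # torch.Size([8, num_heads, L_ex, head_dim])

# (5) expanded_kv can now be passed as past_key_values together with a batch
# of bs queries, exactly as in do_infer_probs above.

Because expand only creates a view, the cost of encoding the 80 exemplars is paid once, independent of the inference batch size.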