AlibabaResearch / DAMO-ConvAI

DAMO-ConvAI: The official repository which contains the codebase for Alibaba DAMO Conversational AI.
MIT License

dataloader question. #44

Closed byan19 closed 1 year ago

byan19 commented 1 year ago

Hi, thanks for sharing the code for the brilliant work, Deep Thinking. However, the data loader loop nested inside the kv_iter loop, which I believe represents the total number of optimisation steps, really confuses me. In the few-shot learning setting, only k examples are given, which in your case should be exemplar_str. So why is there an extra data loader? Does it mean the proposed deep-thinking process requires extra data?

Yangjiaxi commented 1 year ago

Hi, byan! Thank you for your attention and interest in the code.

The outer loop, kv_iter, represents the number of forward tuning steps, while the inner loop, loader, corresponds to the evaluation (inference) process.

The exemplar_str variable indeed holds the k examples. The few-shot inference setting is then realized by using Key-Value matrices that have been optimized for idx steps (idx <= args.kv_iter). The inner loader loop takes these matrices as past_key_value and performs inference for each test input.

Here's a simplified implementation that separates the two stages:

# (1) Deep-Thinking for `num_steps` steps.
meta_optim.init()
for idx in range(num_steps):
    exemplar_kv = meta_optim.step(exemplar_input_ids)  # perform forward tuning

# (2) Inference using optimized Key-Value matrices.
generated_info = []  # question * [choice0_prob, choice1_prob]
for batch_input in tqdm(loader, desc=f"idx={idx}"):
    batch_input = [[e.cuda() for e in batch_choice] for batch_choice in batch_input]
    batch_output = do_infer_probs(
        exemplar_kv,
        exemplar_attn_mask.unsqueeze(0),
        batch_input,
    )  # [batch_of_choice0, batch_of_choice1, ...]
    zipped_logprobs = list(zip(*batch_output))  # batch * (choice0, choice1, ...)
    generated_info.extend(zipped_logprobs)

full_info, metric = task_agent.post_process(generated_info, metric_output=False)
metric_s = json.dumps(metric, indent=None)
logger.info(f"Iter={idx+1: <3} | {metric_s}")
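To make the past_key_value mechanism above concrete, here is a toy, self-contained sketch (plain NumPy, not the repository's code): the optimized exemplar Key/Value matrices are simply prepended to the keys/values computed from each test input, so the test tokens attend over the tuned exemplar representations without the exemplar tokens ever being re-encoded. All names here (attend_with_past_kv, the dimensions d, m, t) are illustrative assumptions, not identifiers from the repo.

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def attend_with_past_kv(q, k_new, v_new, past_k, past_v):
    """Single-head attention where tuned exemplar K/V act as a prefix.

    q:      (t, d)  queries for the test input
    k_new:  (t, d)  keys for the test input
    v_new:  (t, d)  values for the test input
    past_k: (m, d)  optimized exemplar keys (analogue of exemplar_kv)
    past_v: (m, d)  optimized exemplar values
    """
    k = np.concatenate([past_k, k_new], axis=0)  # (m + t, d)
    v = np.concatenate([past_v, v_new], axis=0)
    scores = q @ k.T / np.sqrt(q.shape[-1])      # (t, m + t)
    return softmax(scores) @ v                   # (t, d)

rng = np.random.default_rng(0)
d, m, t = 8, 4, 3                  # hidden size, exemplar length, test length
past_k = rng.normal(size=(m, d))
past_v = rng.normal(size=(m, d))
q, k_new, v_new = (rng.normal(size=(t, d)) for _ in range(3))

out = attend_with_past_kv(q, k_new, v_new, past_k, past_v)
print(out.shape)  # (3, 8): each test token attends over exemplar + test positions
```

In the actual code, the same idea is delegated to the model's forward pass by handing exemplar_kv in as past_key_value, which is why only the tuning loop ever sees the exemplars and the inference loop only sees test batches.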