Closed nullgogo closed 2 years ago
不好意思, 没看到你的问题. interact.py的代码是针对vanilla模型设计的. 如果要和strategy模型交互, 需要改一下代码.
我做了如下的修改,麻烦帮忙看一下这样改是否合理呢,感谢
# ...
# ...
id2strategy = {
0: "Question",
1: "Restatement or Paraphrasing",
2: "Reflection of feelings",
3: "Self-disclosure",
4: "Affirmation and Reassurance",
5: "Providing Suggestions",
6: "Information",
7: "Others"
}
# 163 generate response
history['dialog'].append({ # dummy tgt
'text': 'n/a',
'speaker': 'sys',
"strategy": "Others" # 伪策略
})
inputs = inputter.convert_data_to_inputs(history, toker, **dataloader_kwargs)
inputs = inputs[-1:]
features = inputter.convert_inputs_to_features(inputs, toker, **dataloader_kwargs)
batch = inputter.prepare_infer_batch(features, toker)
batch = {k: v.to(device) if isinstance(v, Tensor) else v for k, v in batch.items()}
batch.update(generation_kwargs)
encoded_info, generations = model.generate(**batch)
out = generations[0].tolist()
out = cut_seq_to_eos(out, eos)
text = toker.decode(out).encode('ascii', 'ignore').decode('ascii').strip()
strat_id_out = encoded_info['pred_strat_id_top3'].tolist()[0][0] # 取top1 策略id
strategy = id2strategy[strat_id_out]
print(" AI: " + "[" + strategy + "]" + text)
history['dialog'].pop()
history['dialog'].append({
'text': text,
'speaker': 'sys',
'strategy': strategy
})
看起来合理. 能正常运行就可以
可以正常运行,有一个问题是当对话轮次较少,即inputs文本 tokenize之后的总长度小于max_input_length时,每一次AI回复的strategy不变。
把这行代码改一下:
batch = inputter.prepare_infer_batch(features, toker) -> batch = inputter.prepare_infer_batch(features, toker, interact=True)
好像还有这个问题。 我的理解是:pred_strat_id是根据strat_dialogpt.py中predict_strategy方法里的logits预测出来的,logits[:, 0, -8:]在前有限轮对话里都没变,导致pred_strat_id一直不变。
已修复. DialoGPT模型应该取logits[:, -1, -8:]
感谢
ESC在response阶段,是基于历史对话内容 预测当前阶段的response 应该使用哪种策略吧?数据集设计也是这样的。在实际应用场景中,如果对话轮次变多,感觉很容易造成response的内容和策略对应不上的问题。相比之下,如果是对response的内容单独做多分类模型,分类成某一个策略类别,这样或许比较容易对应上,但这似乎又与使用策略回复的初衷不太一样。不知道您是怎么理解的。
策略不一定是完全可控的。可以单独用分类模型识别回复的策略
在进行交互的 前有限轮对话中,当所有对话内容tokenize之后的总长度小于max_input_length时,AI回复的strategy一直保持不变,直到tokenize之后的总长度大于max_input_length后,因为这时候开始做max_input_length的截断了,strat_dialogpt.py中predict_strategy函数中切片的logits才会变化,AI回复的strategy才会变,想问一下,是就这么设计的,还是其他原因啊?