murray-z / OneStop_QAMaker

A single model that performs both question generation and answer generation
26 stars 3 forks

Error: RuntimeError: shape '[64, 12, 64]' is invalid for input of size 6291456 #2

Closed: Ulov888 closed this issue 1 year ago

Ulov888 commented 1 year ago

Running the code fails immediately with this error, and I cannot figure out why the dimensions do not match (screenshot of the error attached). After debugging, I inspected the shapes of q, k, and v: q has shape (64, 1, 768) while k and v have shape (64, 128, 768). Can you explain whether this is supposed to work? (screenshot attached)

Ulov888 commented 1 year ago

Here is the text version:

Traceback (most recent call last):
  File "/home/liulin/qamaker/train_model.py", line 62, in <module>
    train(model, Lambda)
  File "/home/liulin/qamaker/train_model.py", line 34, in train
    start_logits, end_logits, decoder_out = model(*batch[:4])
  File "/home/liulin/miniconda3/envs/rasa3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/liulin/qamaker/onestop_qamaker.py", line 41, in forward
    attention_out, attention_weight = self.attention(q, k, v)
  File "/home/liulin/miniconda3/envs/rasa3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/liulin/miniconda3/envs/rasa3/lib/python3.7/site-packages/torch/nn/modules/activation.py", line 1160, in forward
    attn_mask=attn_mask, average_attn_weights=average_attn_weights)
  File "/home/liulin/miniconda3/envs/rasa3/lib/python3.7/site-packages/torch/nn/functional.py", line 5122, in multi_head_attention_forward
    k = k.contiguous().view(k.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
RuntimeError: shape '[64, 12, 64]' is invalid for input of size 6291456
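A note on what the numbers in this traceback mean: 64 * 128 * 768 = 6291456, which is exactly the element count of k, and `nn.MultiheadAttention` by default expects inputs shaped (seq_len, batch, embed_dim), not (batch, seq_len, embed_dim). If the model feeds batch-first tensors, attention reads the batch size from q and then cannot reshape k. This minimal sketch (my reconstruction, not the repository's actual code) reproduces the mismatch and shows the `batch_first=True` flag (available in PyTorch >= 1.9) that makes these shapes valid:

```python
import torch
import torch.nn as nn

embed_dim, num_heads, bsz, seq_len = 768, 12, 64, 128

# Batch-first tensors, matching the shapes reported in the issue:
# q is (64, 1, 768); k and v are (64, 128, 768).
q = torch.randn(bsz, 1, embed_dim)
k = torch.randn(bsz, seq_len, embed_dim)
v = torch.randn(bsz, seq_len, embed_dim)

# By default nn.MultiheadAttention interprets dim 0 as seq_len and dim 1 as
# batch, so it reads bsz=1 from q and tries to view k (6291456 elements) as
# [64, 12, 64] = 49152 elements, which raises the RuntimeError above.
attn = nn.MultiheadAttention(embed_dim, num_heads)
try:
    attn(q, k, v)
except RuntimeError as e:
    print(e)  # on the torch version in the traceback: shape '[64, 12, 64]' is invalid ...

# With batch_first=True the same tensors are interpreted correctly.
attn_bf = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
out, weights = attn_bf(q, k, v)
print(out.shape)  # torch.Size([64, 1, 768])
```

Alternatively, on older PyTorch versions without `batch_first`, transposing q, k, and v with `.transpose(0, 1)` before the attention call achieves the same thing.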

murray-z commented 1 year ago

I did not hit this error when running it myself, but the q/k/v shapes were indeed off, and I have updated the code. On PyTorch 1.8.1 I trained for 5 epochs, the loss kept decreasing, and I tested one sentence at the end:

greedy res: {'question': '中 国 的 首 都 是 什 么 ?', 'answer': '北京'}
beam_search res: {'question': '中 国 的 首 都 是 什 么 ?', 'answer': '北京'}
random_sample res: [{'question': '中 国 的 首 都 是 什 么 ?', 'answer': '北京'}, {'question': '中 国 的 首 都 是 哪 里 ?', 'answer': '北京'}]

murray-z commented 1 year ago

That is a limitation of the paper: the model only generates one QA pair per passage. It also depends on the training set. If a passage is annotated with multiple pairs, you can use random sampling at decoding time and decode the same passage several times to obtain multiple pairs.
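The "decode several times with random sampling" idea can be sketched as a simple token-by-token sampling loop. Everything here is hypothetical scaffolding (the function name, the `step_logits_fn` callback standing in for the decoder, and the token ids are not from this repository); it only illustrates the technique:

```python
import torch

def random_sample_decode(step_logits_fn, bos_id, eos_id, max_len=32, temperature=1.0):
    """Sample one output sequence token by token.

    step_logits_fn(prefix_ids) -> 1-D tensor of next-token logits over the
    vocabulary; it stands in for the model's decoder (hypothetical API).
    Higher temperature gives more diverse samples.
    """
    ids = [bos_id]
    for _ in range(max_len):
        logits = step_logits_fn(torch.tensor(ids))
        probs = torch.softmax(logits / temperature, dim=-1)
        # Sample instead of argmax: repeated calls on the same passage can
        # then yield different questions, which can be deduplicated afterwards.
        next_id = torch.multinomial(probs, num_samples=1).item()
        ids.append(next_id)
        if next_id == eos_id:
            break
    return ids
```

Calling this several times on the same passage and keeping the distinct outputs is one way to get more than one QA pair from a model trained to emit a single pair.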

Ulov888 commented 1 year ago

The code works now. If I want to further improve the quality of the generated QA pairs, what should I focus on?

CaptainDP commented 1 year ago

> The code works now. If I want to further improve the quality of the generated QA pairs, what should I focus on?

You could try pretraining the model on datasets such as SQuAD, NewsQA, and DuReader, then fine-tuning on your business data; that should give better results. We are currently experimenting with this kind of pretraining.
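The pretrain-then-fine-tune recipe above can be sketched as two training stages that differ only in data and learning rate. This is a minimal illustration, not this project's training script: `staged_training`, the loaders, and the hyperparameters are all hypothetical, and the model is assumed to return a scalar loss.

```python
import torch

def staged_training(model, pretrain_loader, finetune_loader,
                    pretrain_lr=3e-5, finetune_lr=1e-5,
                    pretrain_epochs=1, finetune_epochs=1):
    """Stage 1: pretrain on public QA data (e.g. SQuAD / NewsQA / DuReader).
    Stage 2: fine-tune on in-domain business data with a smaller learning rate."""
    stages = [(pretrain_loader, pretrain_lr, pretrain_epochs),
              (finetune_loader, finetune_lr, finetune_epochs)]
    for loader, lr, epochs in stages:
        optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
        for _ in range(epochs):
            for batch in loader:
                loss = model(*batch)  # assumption: forward() returns the loss
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
```

The lower learning rate in the second stage is the usual precaution against wiping out what was learned during pretraining.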