fang98525 / nlp-TextGeneration

A text-generation model built on BERT with a seq2seq architecture. The task paradigm: given a passage and an answer, generate the corresponding question.

About the model_NEZHA code #1

Open Biaocsu opened 2 years ago

Biaocsu commented 2 years ago

Hello, thank you very much for open-sourcing this; I'm glad to have found your project.

While reading your code, I noticed the following imports:

from NEZHA.model_NEZHA import NEZHAConfig
from NEZHA import NEZHA_utils

Since this code is not included in the repo, I assumed it came from the official NEZHA release, but I couldn't find it there either; in the end I found related example code elsewhere.

Problem: I don't know whether the code I found has been modified or something else is going on, but I get the following error:

0it [00:01, ?it/s]
Traceback (most recent call last):
  File "train_fine_tune.py", line 338, in <module>
    train(train_iter, dev_iter, config=config)
  File "train_fine_tune.py", line 211, in train
    _,loss = model(input_ids=input_ids_list, token_type_ids=segment_ids_list, labels=label_ids_list)
  File "/home/anaconda3/envs/liubiao/lib/python3.6/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/liubiao/zhishiku/QA-QG/QG/model.py", line 181, in forward
    output_hidden_states=True
  File "/home/anaconda3/envs/liubiao/lib/python3.6/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/liubiao/zhishiku/QA-QG/QG/NEZHA/model_NEZHA.py", line 549, in forward
    extended_attention_mask)
  File "/home/anaconda3/envs/liubiao/lib/python3.6/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/liubiao/zhishiku/QA-QG/QG/NEZHA/model_NEZHA.py", line 476, in forward
    hidden_states = layer_module(all_encoder_layers[i], attention_mask)
  File "/home/anaconda3/envs/liubiao/lib/python3.6/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/liubiao/zhishiku/QA-QG/QG/NEZHA/model_NEZHA.py", line 458, in forward
    attention_output = self.attention(hidden_states, attention_mask)
  File "/home/anaconda3/envs/liubiao/lib/python3.6/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/liubiao/zhishiku/QA-QG/QG/NEZHA/model_NEZHA.py", line 415, in forward
    self_output = self.self(input_tensor, attention_mask)
  File "/home/anaconda3/envs/liubiao/lib/python3.6/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/liubiao/zhishiku/QA-QG/QG/NEZHA/model_NEZHA.py", line 371, in forward
    attention_probs_t = attention_probs.permute(2, 0, 1, 3)
RuntimeError: number of dims don't match in permute

Could you help analyze the cause, or alternatively upload the relevant NEZHA code to the repo as well? Many thanks.
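
A note on the error itself: Tensor.permute requires exactly one index per tensor dimension, so permute(2, 0, 1, 3) failing means attention_probs is not 4-D at that point, rather than the expected (batch, heads, seq, seq). A minimal, repo-independent PyTorch sketch of the failure mode (the shapes below are illustrative assumptions, not taken from this project):

import torch

# permute() needs one target index per dimension; four indices on a
# 3-D tensor raise the same dimension-mismatch RuntimeError as above
# (exact message wording depends on the PyTorch version).
probs_3d = torch.randn(2, 8, 128)            # e.g. a missing head dimension
try:
    probs_3d.permute(2, 0, 1, 3)
except RuntimeError as err:
    print(err)

# With the expected 4-D shape the same call succeeds.
probs_4d = torch.randn(2, 8, 128, 128)       # (batch, heads, seq, seq)
print(probs_4d.permute(2, 0, 1, 3).shape)    # torch.Size([128, 2, 8, 128])

So a first thing to check would be the shapes of the attention mask and hidden states entering NEZHA's self-attention, e.g. whether the inputs lost their batch dimension before the forward call.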

Biaocsu commented 2 years ago

Following your other project, Joint_MRC, I found the NEZHA code and believe it is what you used here, but the project still doesn't run. Could you verify this? Thank you very much. I think your code has a lot worth learning from, so I plan to study it carefully, but it is frustrating that it keeps failing to run.

Biaocsu commented 2 years ago

Hello, if you have time, I would very much appreciate an answer. I haven't found any other code that covers the details of question generation this thoroughly, so yours is a very valuable learning reference. I modified your code and it now runs, but at prediction time no question text is generated at all, and I don't know why; I assume your code doesn't have a major logic problem. [screenshot: empty prediction output] Whether or not you know the cause, or are willing to help, I'd appreciate a reply. Thanks.
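
One hedged guess at "nothing is generated": in UniLM-style seq2seq decoding over BERT, generation stops at the first [SEP] token, so if the model emits [SEP] immediately (often because the inference-time token_type_ids or attention mask don't match training), the output is empty. Below is a minimal greedy-decoding sketch to isolate that case; model, tokenizer, and the forward signature are hypothetical placeholders, not this repo's actual API:

import torch

def greedy_decode(model, tokenizer, input_ids, token_type_ids, max_len=32):
    # Hypothetical UniLM-style loop: assumes model(...) returns
    # next-token logits of shape (batch, seq_len, vocab_size).
    sep_id = tokenizer.sep_token_id
    generated = []
    for step in range(max_len):
        with torch.no_grad():
            logits = model(input_ids=input_ids, token_type_ids=token_type_ids)
        next_id = int(logits[0, -1].argmax())
        if next_id == sep_id:
            if step == 0:
                # [SEP] on the very first step yields an empty question --
                # compare inference-time segment ids / mask against training.
                print("model emitted [SEP] immediately")
            break
        generated.append(next_id)
        # Append the new token; generated tokens use segment id 1.
        input_ids = torch.cat([input_ids, torch.tensor([[next_id]])], dim=-1)
        token_type_ids = torch.cat([token_type_ids, torch.tensor([[1]])], dim=-1)
    return tokenizer.decode(generated)

Printing the top few logits at step 0 would show whether [SEP] dominates from the start or the inputs are malformed.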