Open Svpwm opened 3 years ago
我改了一下bert_seq2seq源码,你去看一下源码的输出就明白了
我改了一下bert_seq2seq源码,你去看一下源码的输出就明白了
方便看一下你改了bert_seq2seq的什么地方吗?以及怎么改的。 我的邮箱:496674663@qq.com
就是取了句子最后一个单词的hidden输出,很简单的,源码目前我这更新了一下,还得重新改
就是取了句子最后一个单词的hidden输出,很简单的,源码目前我这更新了一下,还得重新改
多谢,问题解决啦~
我也遇到了这样一个问题,输出的维度为[826,21128],不应该是[826,768]吗?
最后是对全部词做预测啦
所以我要把linear的输入大小改为(21128,3)吗?
所以我要把linear的输入大小改为(21128,3)吗?
不是的,你需要做的是修改 项目所依赖的bert_seq2seq 这个项目中 seq2seq 的输出,return predictions->return predictions, squence_out 你可以把当前model的输出打印出来,其实是只有一个[batch_size,seq_length,vocab_size]的变量。
大佬我数据对不齐,能看一下你的bert-seq2seq吗
方便看一下你改了bert_seq2seq的什么地方吗?以及怎么改的。 我的邮箱:1135383621@qq.com
晚点帮你看一下
发自我的iPhone
------------------ 原始邮件 ------------------ 发件人: mlm22222 @.> 发送时间: 2022年9月21日 16:26 收件人: Jeryi-Sun/SPACES-Pytorch @.> 抄送: Svpwm @.>, Author @.> 主题: Re: [Jeryi-Sun/SPACES-Pytorch] bert_seq2seq输出维度和变量不一致 (Issue #5)
方便看一下你改了bert_seq2seq的什么地方吗?以及怎么改的。 @.***
— Reply to this email directly, view it on GitHub, or unsubscribe. You are receiving this because you authored the thread.Message ID: @.***>
谢谢大佬,方便加大佬QQ吗 @Svpwm
想看看大佬bert-seq2seq怎么改的
晚点帮你看一下 发自我的iPhone … ------------------ 原始邮件 ------------------ 发件人: mlm22222 @.> 发送时间: 2022年9月21日 16:26 收件人: Jeryi-Sun/SPACES-Pytorch @.> 抄送: Svpwm @.>, Author @.> 主题: Re: [Jeryi-Sun/SPACES-Pytorch] bert_seq2seq输出维度和变量不一致 (Issue #5) 方便看一下你改了bert_seq2seq的什么地方吗?以及怎么改的。 @. — Reply to this email directly, view it on GitHub, or unsubscribe. You are receiving this because you authored the thread.Message ID: @.>
请问您是怎么改的bert_seq2seq,可以发我一份吗1741702875@qq.com,谢谢!
请问,关于修改后的bert_seq2seq,可以发我一份吗3261174686@qq.com,谢谢!
` def forward(self, input_tensor, token_type_id, random_id, position_enc=None, target_ids=None, labels_id=None): input_tensor = input_tensor.to(self.device) token_type_id = token_type_id.to(self.device) random_id = random_id.to(self.device)
if position_enc is not None:
position_enc = position_enc.to(self.device)
if target_ids is not None :
target_ids = target_ids.to(self.device)
if labels_id is not None :
labels_id = labels_id.to(self.device)
input_shape = input_tensor.shape
seq_len = input_shape[1]
## 构建特殊的mask
ones = torch.ones((1, 1, seq_len, seq_len), dtype=torch.float32, device=self.device)
a_mask = ones.tril()
s_ex12 = token_type_id.unsqueeze(1).unsqueeze(2).float()
s_ex13 = token_type_id.unsqueeze(1).unsqueeze(3).float()
a_mask = (1.0 - s_ex12) * (1.0 - s_ex13) + s_ex13 * a_mask
enc_layers, pooled = self.bert(input_tensor, position_ids=position_enc, token_type_ids=token_type_id, attention_mask=a_mask, output_all_encoded_layers=True)
squence_out = enc_layers[-1] ## 取出来最后一层输出 (batch, seq_len, 768)
tokens_hidden_state, predictions = self.cls(squence_out)
if target_ids is not None:
predictions = predictions[:, :-1].contiguous()
target_mask = token_type_id[:, 1:].contiguous()
else :
return predictions`
https://github.com/eryihaha/SPACES-Pytorch/blob/195453d2729d08f9c21afa26e81e67df9fe7e9e1/seq2seq_model.py#L214
model_name=‘nezha’,model_class="seq2seq“ 请问一下这里我这边self.bert_model的输出维度是(batch_sizem, seq_len, vob_size),但是我看您这边定义的是seq2seq_predictions, hidden_state,能帮忙看一下么?
以及我找了一下bert_seq2seq的关于这个函数的输出: