Jeryi-Sun / SPACES-Pytorch

苏神SPACE pytorch版本复现
MIT License
41 stars 4 forks source link

bert_seq2seq输出维度和变量不一致 #5

Open Svpwm opened 2 years ago

Svpwm commented 2 years ago

https://github.com/eryihaha/SPACES-Pytorch/blob/195453d2729d08f9c21afa26e81e67df9fe7e9e1/seq2seq_model.py#L214

model_name=‘nezha’,model_class="seq2seq“ 请问一下这里我这边self.bert_model的输出维度是(batch_sizem, seq_len, vob_size),但是我看您这边定义的是seq2seq_predictions, hidden_state,能帮忙看一下么?

以及我找了一下bert_seq2seq的关于这个函数的输出:

    if labels is not None:

        predictions = predictions[:, :-1].contiguous()
        target_mask = token_type_id[:, 1:].contiguous()
        loss = self.compute_loss(predictions, labels, target_mask)
        return predictions, loss 
    else :
        return predictions
Jeryi-Sun commented 2 years ago

我改了一下bert_seq2seq源码,你去看一下源码的输出就明白了

Svpwm commented 2 years ago

我改了一下bert_seq2seq源码,你去看一下源码的输出就明白了

方便看一下你改了bert_seq2seq的什么地方吗?以及怎么改的。 我的邮箱:496674663@qq.com

Jeryi-Sun commented 2 years ago

就是取了句子最后一个单词的hidden输出,很简单的,源码目前我这更新了一下,还得重新改

Svpwm commented 2 years ago

就是取了句子最后一个单词的hidden输出,很简单的,源码目前我这更新了一下,还得重新改

多谢,问题解决啦~

Wenpeng-huang commented 2 years ago

我也遇到了这样一个问题,输出的维度为[826,21128],不应该是[826,768]吗?

image image

Jeryi-Sun commented 2 years ago

最后是对全部词做预测啦

Wenpeng-huang commented 2 years ago

所以我要把linear的输入大小改为(21128,3)吗?

Svpwm commented 2 years ago

所以我要把linear的输入大小改为(21128,3)吗?

不是的,你需要做的是修改 项目所依赖的bert_seq2seq 这个项目中 seq2seq 的输出,return predictions->return predictions, squence_out 你可以把当前model的输出打印出来,其实是只有一个[batch_size,seq_length,vocab_size]的变量。

mlm22222 commented 2 years ago

image 大佬我数据对不齐,能看一下你的bert-seq2seq吗

mlm22222 commented 2 years ago

方便看一下你改了bert_seq2seq的什么地方吗?以及怎么改的。 我的邮箱:1135383621@qq.com

Svpwm commented 2 years ago

晚点帮你看一下

发自我的iPhone

------------------ 原始邮件 ------------------ 发件人: mlm22222 @.> 发送时间: 2022年9月21日 16:26 收件人: Jeryi-Sun/SPACES-Pytorch @.> 抄送: Svpwm @.>, Author @.> 主题: Re: [Jeryi-Sun/SPACES-Pytorch] bert_seq2seq输出维度和变量不一致 (Issue #5)

方便看一下你改了bert_seq2seq的什么地方吗?以及怎么改的。 @.***

— Reply to this email directly, view it on GitHub, or unsubscribe. You are receiving this because you authored the thread.Message ID: @.***>

mlm22222 commented 2 years ago

谢谢大佬,方便加大佬QQ吗 @Svpwm

mlm22222 commented 2 years ago

想看看大佬bert-seq2seq怎么改的

AlexNLP commented 1 year ago

晚点帮你看一下 发自我的iPhone ------------------ 原始邮件 ------------------ 发件人: mlm22222 @.> 发送时间: 2022年9月21日 16:26 收件人: Jeryi-Sun/SPACES-Pytorch @.> 抄送: Svpwm @.>, Author @.> 主题: Re: [Jeryi-Sun/SPACES-Pytorch] bert_seq2seq输出维度和变量不一致 (Issue #5) 方便看一下你改了bert_seq2seq的什么地方吗?以及怎么改的。 @. — Reply to this email directly, view it on GitHub, or unsubscribe. You are receiving this because you authored the thread.Message ID: @.>

请问您是怎么改的bert_seq2seq,可以发我一份吗1741702875@qq.com,谢谢!

haha123763 commented 1 year ago

请问,关于修改后的bert_seq2seq,可以发我一份吗3261174686@qq.com,谢谢!

mlm22222 commented 1 year ago

` def forward(self, input_tensor, token_type_id, random_id, position_enc=None, target_ids=None, labels_id=None): input_tensor = input_tensor.to(self.device) token_type_id = token_type_id.to(self.device) random_id = random_id.to(self.device)

    if position_enc is not None:
        position_enc = position_enc.to(self.device)
    if target_ids is not None :
        target_ids = target_ids.to(self.device)
    if labels_id is not None :
        labels_id = labels_id.to(self.device)

    input_shape = input_tensor.shape
    seq_len = input_shape[1]
    ## 构建特殊的mask
    ones = torch.ones((1, 1, seq_len, seq_len), dtype=torch.float32, device=self.device)
    a_mask = ones.tril()
    s_ex12 = token_type_id.unsqueeze(1).unsqueeze(2).float()
    s_ex13 = token_type_id.unsqueeze(1).unsqueeze(3).float()
    a_mask = (1.0 - s_ex12) * (1.0 - s_ex13) + s_ex13 * a_mask 

    enc_layers, pooled = self.bert(input_tensor, position_ids=position_enc, token_type_ids=token_type_id, attention_mask=a_mask, output_all_encoded_layers=True)

    squence_out = enc_layers[-1] ## 取出来最后一层输出 (batch, seq_len, 768)

    tokens_hidden_state, predictions = self.cls(squence_out)

    if target_ids is not None:

        predictions = predictions[:, :-1].contiguous()
        target_mask = token_type_id[:, 1:].contiguous()

    else :
        return predictions`