Tongjilibo / bert4torch

An elegent pytorch implement of transformers
https://bert4torch.readthedocs.io/
MIT License
1.23k stars 153 forks source link

bert4torch版本0.2.8升级到0.3.4问题 #157

Open kakaxzc opened 1 year ago

kakaxzc commented 1 year ago

你好 我原本的bert4torch版本是0.2.8执行task_seq2seq_autotitle_csl_mt5等一些类似模型没有问题,但是版本升级到0.3.4发生问题 在下面这个方法中outputs值返回2个值 class CrossEntropyLoss(nn.CrossEntropyLoss): def init(self, kwargs): super().init(kwargs)

def forward(self, outputs, y_true):
    _, _, y_pred = outputs
    y_pred = y_pred.reshape(-1, y_pred.shape[-1])
    return super().forward(y_pred, y_true)

如果去掉一个的话 在下面这部分的return地方会报错。 请问要如何解决 class AutoTitle(AutoRegressiveDecoder): """seq2seq解码器 """ @AutoRegressiveDecoder.wraps(default_rtype='logits') def predict(self, inputs, output_ids, states):

inputs中包含了[decoder_ids, encoder_hidden_state, encoder_attention_mask]

    # 保留最后一位
    return model.decoder.predict([output_ids] + inputs)[-1][:, -1, :]
Tongjilibo commented 12 months ago

您好,这个是之前改版时候,example没有更改过来,应该按照下述这样修改一下就可以了,也可以升级到最新的0.3.7,最新版本不需要convert权重,仅需使用bert4torch_config.json就可以加载了

def forward(self, outputs, y_true):
    y_pred = outputs[-1]
    y_pred = y_pred.reshape(-1, y_pred.shape[-1])
    return super().forward(y_pred, y_true)

@AutoRegressiveDecoder.wraps(default_rtype='logits')
def predict(self, inputs, output_ids, states):
    res = model.decoder.predict([output_ids] + inputs)
    return res[-1][:, -1, :] if isinstance(res, list) else res[:, -1, :]  # 保留最后一位
kakaxzc commented 12 months ago

问题修复了,感谢~!