Open kakaxzc opened 1 year ago
您好,这个是之前改版时候,example没有更改过来,应该按照下述这样修改一下就可以了,也可以升级到最新的0.3.7,最新版本不需要convert权重,仅需使用bert4torch_config.json就可以加载了
def forward(self, outputs, y_true):
y_pred = outputs[-1]
y_pred = y_pred.reshape(-1, y_pred.shape[-1])
return super().forward(y_pred, y_true)
@AutoRegressiveDecoder.wraps(default_rtype='logits')
def predict(self, inputs, output_ids, states):
res = model.decoder.predict([output_ids] + inputs)
return res[-1][:, -1, :] if isinstance(res, list) else res[:, -1, :] # 保留最后一位
问题修复了,感谢~!
你好 我原本的bert4torch版本是0.2.8执行task_seq2seq_autotitle_csl_mt5等一些类似模型没有问题,但是版本升级到0.3.4发生问题 在下面这个方法中outputs值返回2个值 class CrossEntropyLoss(nn.CrossEntropyLoss): def init(self, kwargs): super().init(kwargs)
如果去掉一个的话 在下面这部分的return地方会报错。 请问要如何解决 class AutoTitle(AutoRegressiveDecoder): """seq2seq解码器 """ @AutoRegressiveDecoder.wraps(default_rtype='logits') def predict(self, inputs, output_ids, states):
inputs中包含了[decoder_ids, encoder_hidden_state, encoder_attention_mask]