Open ywbgithub opened 1 month ago
logit数组 【batch_size, src_len, trg_vocab_size】 targets数组【batch_size, src_len】 两个数组维数不一致,而且对这两个数组使用loss_fn 没有实际意义
logit数组 【batch_size, src_len, trg_vocab_size】 targets数组【batch_size, src_len】 两个数组维数不一致,而且对这两个数组使用loss_fn 没有实际意义
def forward(enc_inputs, dec_inputs): """前向网络 enc_inputs: [batch_size, src_len] dec_inputs: [batch_size, trglen] """ logits, , , = model(enc_inputs, dec_inputs[:, :-1], src_pad_idx, trg_pad_idx)