mindspore-courses / step_into_llm

MindSpore online courses: Step into LLM
Apache License 2.0

Transformer loss function is computed incorrectly #61

Open ywbgithub opened 1 month ago

ywbgithub commented 1 month ago

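```python
def forward(enc_inputs, dec_inputs):
    """Forward network.
    enc_inputs: [batch_size, src_len]
    dec_inputs: [batch_size, trg_len]
    """
    # Decoder input: every target token except the last one
    logits, _, _, _ = model(enc_inputs, dec_inputs[:, :-1], src_pad_idx, trg_pad_idx)

    # Labels: every target token except the first one, flattened to 1-D
    targets = dec_inputs[:, 1:].view(-1)
    loss = loss_fn(logits, targets)

    return loss
```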
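For context on the slicing in `forward`: the decoder is fed every token except the last, and the labels are every token except the first, so position t is trained to predict token t+1 (teacher forcing). The NumPy sketch below uses hypothetical toy token IDs and assumes `dec_inputs` is an integer array of shape [batch_size, trg_len] with 0 as the padding ID; it only illustrates the shapes, not the course's code.

```python
import numpy as np

# Hypothetical toy batch: 2 target sequences of length 5 (token IDs, 0 = <pad>).
dec_inputs = np.array([
    [1, 7, 8, 9, 2],   # <bos> w1 w2 w3 <eos>
    [1, 5, 6, 2, 0],   # <bos> w1 w2 <eos> <pad>
])

# Teacher forcing: the decoder sees every token except the last ...
decoder_in = dec_inputs[:, :-1]        # shape [batch_size, trg_len - 1]
# ... and is trained to predict every token except the first.
targets = dec_inputs[:, 1:]            # shape [batch_size, trg_len - 1]
flat_targets = targets.reshape(-1)     # shape [batch_size * (trg_len - 1)]

print(decoder_in.shape, targets.shape, flat_targets.shape)
```

Note that after `.view(-1)` the targets in `forward` are one-dimensional, of length batch_size * (trg_len - 1).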
ywbgithub commented 1 month ago

The logits array has shape [batch_size, src_len, trg_vocab_size] and the targets array has shape [batch_size, src_len]. The two arrays have inconsistent dimensions, and applying loss_fn to them is meaningless.
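The shape contract at stake can be checked numerically. Token-level cross-entropy expects one row of logits per target token, so 3-D logits of shape [batch_size, seq_len, vocab_size] are normally reshaped to [batch_size * seq_len, vocab_size] to line up with 1-D targets. Whether the course's `model` already performs this reshape internally is not visible in the snippet above. The NumPy sketch below uses a hypothetical helper `token_cross_entropy`, random logits, and an assumed padding ID of 0; it illustrates the shape requirement, not the course's or MindSpore's implementation.

```python
import numpy as np

def token_cross_entropy(logits, targets, pad_idx=0):
    """Mean cross-entropy over non-padding tokens (hypothetical helper).

    logits:  [batch_size, seq_len, vocab_size] float array
    targets: [batch_size, seq_len] int array
    """
    vocab_size = logits.shape[-1]
    flat_logits = logits.reshape(-1, vocab_size)   # [N, vocab_size]
    flat_targets = targets.reshape(-1)             # [N]
    if flat_logits.shape[0] != flat_targets.shape[0]:
        raise ValueError(
            f"{flat_logits.shape[0]} logit rows vs {flat_targets.shape[0]} targets")

    # Numerically stable log-softmax over the vocabulary axis.
    shifted = flat_logits - flat_logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))

    # Negative log-likelihood of the correct token at every position.
    nll = -log_probs[np.arange(flat_targets.shape[0]), flat_targets]
    mask = flat_targets != pad_idx                 # skip <pad> positions
    return nll[mask].mean()

rng = np.random.default_rng(0)
batch_size, seq_len, vocab_size = 2, 4, 10
logits = rng.normal(size=(batch_size, seq_len, vocab_size))
targets = rng.integers(1, vocab_size, size=(batch_size, seq_len))
targets[1, -1] = 0                                 # one padded position

loss = token_cross_entropy(logits, targets)
print(float(loss))
```

If logits and targets disagree on the number of rows, the helper raises instead of silently producing a number, which is the kind of mismatch the comment above describes.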
