In LMOps/minillm/transformers/src/transformers/mpu/cross_entropy.py, _ParallelSoftCrossEntropyLoss (line 144) computes the loss as

ceLoss = torch.log(sum_exp_logits) - sum_targets_softmax_logits

but the full soft cross-entropy should be

ceLoss = softmax_logits_target * torch.log(sum_exp_logits) - sum_targets_softmax_logits

because -sum_i t_i * log softmax(x)_i = (sum_i t_i) * log(sum_j exp(x_j)) - sum_i t_i * x_i, i.e. the log-sum-exp term carries a factor of sum_i t_i, which equals 1 only when the target distribution is normalized.
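A minimal torch-free check of the claim above, assuming softmax_logits_target stands for the sum of the target probabilities (the function names here are hypothetical, not from the repository): when the targets sum to 1 both forms agree, but when they do not (e.g. a truncated or masked target distribution), only the form with the target-sum factor matches the reference soft cross-entropy.

```python
import math

def soft_cross_entropy(logits, targets):
    # Reference definition: -sum_i t_i * log(softmax(logits)_i)
    log_sum_exp = math.log(sum(math.exp(x) for x in logits))
    return -sum(t * (x - log_sum_exp) for t, x in zip(targets, logits))

def loss_as_on_line_144(logits, targets):
    # Mirrors the reported code: log(sum_exp_logits) - dot(targets, logits)
    log_sum_exp = math.log(sum(math.exp(x) for x in logits))
    dot = sum(t * x for t, x in zip(targets, logits))
    return log_sum_exp - dot

def loss_with_target_sum(logits, targets):
    # Proposed form: (sum of targets) * log(sum_exp_logits) - dot(targets, logits)
    log_sum_exp = math.log(sum(math.exp(x) for x in logits))
    dot = sum(t * x for t, x in zip(targets, logits))
    return sum(targets) * log_sum_exp - dot

logits = [1.0, 2.0, 0.5]
normalized = [0.2, 0.5, 0.3]   # sums to 1.0: both forms agree
truncated = [0.2, 0.5, 0.2]    # sums to 0.9: only the corrected form agrees
```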
Thanks!