ShusenTang / Dive-into-DL-PyTorch

本项目将《动手学深度学习》(Dive into Deep Learning)原书中的MXNet实现改为PyTorch实现。
http://tangshusen.me/Dive-into-DL-PyTorch
Apache License 2.0
18.32k stars 5.4k forks source link

loss计算错误 #74

Open Yzichen opened 4 years ago

Yzichen commented 4 years ago

bug描述 3.7 3.9中计算交叉熵损失函数时使用tf.nn.CrossEntropyLoss()时,已经作了平均。 但是计算每一个epoch的损失时,又除了整个训练集样本的个数n ,这样对吗?

版本信息 pytorch:1.13 torchvision:0.4.2 torchtext:无 ...

ShusenTang commented 4 years ago

你说得对,应该除以batch数,因为tf.nn.CrossEntropyLoss()时已经沿batch维作了平均。