ShusenTang / Dive-into-DL-PyTorch

This project ports the MXNet implementations in the original book "Dive into Deep Learning" (《动手学深度学习》) to PyTorch.
http://tangshusen.me/Dive-into-DL-PyTorch
Apache License 2.0
18.2k stars 5.38k forks

The implementation in 5.10.2 "Implementing Batch Normalization from Scratch" can be optimized #58

Open huiget opened 4 years ago

huiget commented 4 years ago

Bug description

The text in this project computes the mean like this:

mean = X.mean(dim=0, keepdim=True).mean(dim=2, keepdim=True).mean(dim=3, keepdim=True)

whereas the original book writes:

mean = X.mean(axis=(0, 2, 3), keepdims=True)

In fact, PyTorch supports a similar form:

mean = X.mean(dim=(0, 2, 3), keepdim=True)

Besides being more concise, this form should also be slightly more accurate numerically.

Version info

pytorch: 1.3.1 ...
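A minimal sketch for checking that the two forms agree. The input shape (8, 3, 5, 5), the seed, and the tolerance are illustrative assumptions, not values from the book. It only verifies that the results match up to floating-point rounding; it does not measure which form is more accurate.

```python
import torch

torch.manual_seed(0)
# Hypothetical input laid out as (N, C, H, W): batch of 8, 3 channels, 5x5 feature maps
X = torch.randn(8, 3, 5, 5)

# Chained form from the PyTorch port: three successive single-dimension reductions
mean_chained = (
    X.mean(dim=0, keepdim=True)
     .mean(dim=2, keepdim=True)
     .mean(dim=3, keepdim=True)
)

# Proposed form: one reduction over the batch and spatial dimensions
mean_tuple = X.mean(dim=(0, 2, 3), keepdim=True)

# Both keep one value per channel, shape (1, 3, 1, 1)
print(mean_chained.shape, mean_tuple.shape)

# The two results should match up to rounding error
same = torch.allclose(mean_chained, mean_tuple, atol=1e-6)
max_diff = (mean_chained - mean_tuple).abs().max().item()
print(same, max_diff)
```

The chained form gives the same mean here because each intermediate reduction averages groups of equal size.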

huiget commented 4 years ago

@ShusenTang

CaiyuZhang commented 4 years ago

torch.mean(input, dim, keepdim=False, out=None) → Tensor

Parameters: dim (int or tuple of ints) – the dimension or dimensions to reduce.

You are right. This is how the official documentation describes it: the dim parameter accepts a tuple.
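To illustrate the documented behaviour, here is a small sketch that passes an int and then a tuple as dim. The tensor and its shape are made up for the example.

```python
import torch

# Hypothetical 3-D tensor holding 0..23, shape (2, 3, 4)
X = torch.arange(24, dtype=torch.float32).reshape(2, 3, 4)

# int: reduce a single dimension, result has shape (3, 4)
single = X.mean(dim=0)

# tuple: reduce dimensions 0 and 2 together, result has shape (3,)
pair = X.mean(dim=(0, 2))

# keepdim=True keeps the reduced dimensions with size 1, shape (1, 3, 1)
kept = X.mean(dim=(0, 2), keepdim=True)

print(single.shape, pair.shape, kept.shape)
print(pair)
```

With keepdim=True the result broadcasts directly against X, which is what the batch normalization code needs when it subtracts the per-channel mean.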