szagoruyko / wide-residual-networks

3.8% and 18.3% on CIFAR-10 and CIFAR-100
http://arxiv.org/abs/1605.07146
BSD 2-Clause "Simplified" License

Fix initialization to use MSRA init #41

Closed: juesato closed this issue 7 years ago

juesato commented 7 years ago

I believe the stddev of the initialized weights is a factor of sqrt(2) too high in the current PyTorch implementation.

For comparison, the current Lua Torch implementation uses sqrt(2) rather than 2 in the numerator, which is what I would expect from He/MSRA initialization, since it sets Var = 2 / fan_in, i.e. stddev = sqrt(2 / fan_in). https://github.com/szagoruyko/wide-residual-networks/blob/master/models/utils.lua#L6

Does this seem right to you? I unfortunately don't have the time or resources to re-run the code right now.
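
For concreteness, here is a minimal sketch of the gap, assuming (per the description above) that the current code uses 2 rather than sqrt(2) in the numerator; the widths and kernel size are illustrative, with fan_in = ni * k * k for a conv weight:

import math
import torch

ni, no, k = 160, 640, 3            # illustrative widths and kernel size
fan_in = ni * k * k

# stddev = 2 / sqrt(fan_in), as described above for the current PyTorch code
w_current = torch.randn(no, ni, k, k) * (2 / math.sqrt(fan_in))
# stddev = sqrt(2 / fan_in), the He/MSRA scaling used by the Lua utils.lua
w_msra = torch.randn(no, ni, k, k) * math.sqrt(2 / fan_in)

print(w_current.std() / w_msra.std())   # ratio is ~ sqrt(2) ~= 1.414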

szagoruyko commented 7 years ago

@juesato thanks! I am using torch.nn.init.kaiming_normal, which does exactly that, and it doesn't seem to change the results:

import torch
from torch.nn.init import kaiming_normal  # renamed kaiming_normal_ in later PyTorch releases

# `cast` is a helper defined elsewhere in this repo (moves params to the target dtype/device)
def conv_params(ni, no, k=1):
    return cast(kaiming_normal(torch.Tensor(no, ni, k, k)))

def linear_params(ni, no):
    return cast({'weight': kaiming_normal(torch.Tensor(no, ni)), 'bias': torch.zeros(no)})
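
As a quick sanity check that this gives the expected scale (a sketch using the underscored kaiming_normal_ name from newer PyTorch releases; the shape is illustrative):

import math
import torch
from torch.nn import init

w = torch.empty(640, 160, 3, 3)
init.kaiming_normal_(w)                        # defaults give stddev = sqrt(2 / fan_in)
fan_in = 160 * 3 * 3
print(w.std().item(), math.sqrt(2 / fan_in))   # both ~ 0.037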

szagoruyko commented 7 years ago

fixed in #42, thanks!

juesato commented 7 years ago

Awesome, thanks!