Closed juesato closed 7 years ago
@juesato thanks! I am using torch.nn.init.kaiming_normal
which does exactly that, and it doesn't seem to change the results:
```python
def conv_params(ni, no, k=1):
    return cast(kaiming_normal(torch.Tensor(no, ni, k, k)))

def linear_params(ni, no):
    return cast({'weight': kaiming_normal(torch.Tensor(no, ni)), 'bias': torch.zeros(no)})
```
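As a side note on what kaiming_normal should produce for those shapes: for a conv weight of shape (no, ni, k, k), the fan-in is ni * k * k, and He/MSRA initialization gives std = sqrt(2 / fan_in). A pure-arithmetic sketch (no PyTorch needed; the example dimensions are illustrative, not from this repo):

```python
import math

def expected_kaiming_std(ni, k):
    # He/MSRA (fan_in mode): std = sqrt(2 / fan_in),
    # where fan_in = ni * k * k for a (no, ni, k, k) conv weight
    fan_in = ni * k * k
    return math.sqrt(2.0 / fan_in)

# e.g. a 3x3 conv with 16 input planes: fan_in = 144, std = sqrt(2)/12
print(expected_kaiming_std(16, 3))
```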
Fixed in #42, thanks!
Awesome, thanks!
I believe the stddev of the initialized weights is a factor of sqrt(2) too high in the current PyTorch implementation.
For comparison, the current Lua Torch implementation uses sqrt(2) rather than 2 in the numerator, i.e. std = sqrt(2)/sqrt(fan_in), which is what I would expect from He/MSRA initialization, since Var = 2 / fan_in. https://github.com/szagoruyko/wide-residual-networks/blob/master/models/utils.lua#L6
Does this seem right to you? I unfortunately don't have the time or resources to re-run the code right now.
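To make the claimed factor concrete, here is a small pure-Python sketch comparing the two numerators (the fan-in value is illustrative, not taken from the repo):

```python
import math

# Illustrative fan-in: a 3x3 conv with 64 input planes (hypothetical example)
fan_in = 64 * 3 * 3

# He/MSRA: Var(w) = 2 / fan_in, so std = sqrt(2) / sqrt(fan_in)
correct_std = math.sqrt(2.0 / fan_in)

# With 2 instead of sqrt(2) in the numerator: std = 2 / sqrt(fan_in)
too_high_std = 2.0 / math.sqrt(fan_in)

# The ratio between the two is exactly sqrt(2), the claimed discrepancy
print(too_high_std / correct_std)
```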