I try to transfer the batchnorm2d layer, but it fails......I really want to know what's the difference between the two. Thanks! code:
# Create dummy data data = torch.rand(1,3,2,2) print(data) data_keras = data.numpy() data_pytorch = Variable(data) da= nn.BatchNorm2d(3,momentum=0.99,eps=1e-5) def lenet_keras(): model = Sequential() model.add(layers.BatchNormalization(axis=1,momentum=0.99,epsilon= 1e-5, center=True, scale=True, beta_initializer=initializers.get('zeros'), gamma_initializer=initializers.get('ones'), input_shape = (3,2,2))) return model le = lenet_keras() le.summary() wei = da.weight.detach().numpy() bi = da.bias.detach().numpy() rm = da.running_mean.detach().numpy() rv = da.running_var.detach().numpy() #print(bi) #wei = np.transpose(wei,(2,3,1,0)) trans = [wei,bi,rv,rm] print(trans) weight = le.layers[0].set_weights(trans) weight = le.layers[0].get_weights() print(weight) print(da(data_pytorch)) print(da(data_pytorch).shape) dat = le.predict(data_keras) print(dat) print(dat.shape)
tensor([[[[0.6875, 0.5928], [0.2540, 0.3808]], [[0.9787, 0.3589], [0.5710, 0.3397]], [[0.2124, 0.8893], [0.0698, 0.6545]]]]) _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= batch_normalization_17 (Batc (None, 3, 2, 2) 12 ================================================================= Total params: 12 Trainable params: 6 Non-trainable params: 6 _________________________________________________________________ [array([1., 1., 1.], dtype=float32), array([0., 0., 0.], dtype=float32), array([1., 1., 1.], dtype=float32), array([0., 0., 0.], dtype=float32)] [array([1., 1., 1.], dtype=float32), array([0., 0., 0.], dtype=float32), array([1., 1., 1.], dtype=float32), array([0., 0., 0.], dtype=float32)] tensor([[[[ 1.2217, 0.6677], [-1.3158, -0.5735]], [[ 1.6204, -0.7903], [ 0.0348, -0.8649]], [[-0.7396, 1.3113], [-1.1717, 0.6000]]]], grad_fn=<NativeBatchNormBackward>) torch.Size([1, 3, 2, 2]) [[[[ -98.82477 -128.75443 ] [-235.9049 -195.80484 ]] [[ -6.7411804 -202.74635 ] [-135.65837 -208.8067 ]] [[-249.05997 -35.020264 ] [-294.15558 -109.255936 ]]]] (1, 3, 2, 2)
Thanks, solved~
