gmalivenko / pytorch2keras

PyTorch to Keras model convertor
https://pytorch2keras.readthedocs.io/en/latest/
MIT License

BatchNorm2d config not transferring properly #70

Closed. jezza770 closed this issue 5 years ago

jezza770 commented 5 years ago

The parameters for PyTorch's nn.BatchNorm2d do not copy correctly. PyTorch's default momentum is 0.1; however, the converted Keras model has a momentum of 1. This causes NaNs in the network after training. See the example below. It's not exactly minimal (copied in part from a personal project I'm working on), but it demonstrates the issue. Manually changing the momentum before compiling fixes the issue. Other parameters, such as epsilon, seem to be copied correctly.

import torch
import torch.nn as nn
import numpy as np
from torch.autograd import Variable
from pytorch2keras.converter import pytorch_to_keras
import keras


class PytorchModel(nn.Module):
    def __init__(self):
        super(PytorchModel, self).__init__()
        self.layer1 = nn.Sequential(nn.Conv2d(1, 2, kernel_size=2, padding=1), nn.BatchNorm2d(2), nn.ReLU())

    def forward(self, x):
        x = self.layer1(x)
        return x


# Build the PyTorch model and run it once on random input.
pytorch_model = PytorchModel()
x_data = np.random.uniform(low=0.0, high=1.0, size=(1, 1, 2, 2))
x = Variable(torch.Tensor(x_data))
pytorch_output = pytorch_model(x)
print(pytorch_output)

# Convert to Keras and inspect the layer configs (momentum shows up as 1 here).
keras_model = pytorch_to_keras(pytorch_model, x, [(1, None, 2)], verbose=False, names='short')
sm = keras.Sequential()
layer_configs = keras_model.get_config()['layers']
print(layer_configs)
print()

# Workaround: override the BatchNormalization momentum, then rebuild as a Sequential model.
for config, layer in zip(layer_configs, keras_model.layers):
    if config['class_name'] == 'BatchNormalization':
        layer.momentum = 0.1
    sm.add(layer)

sgd = keras.optimizers.SGD(lr=0.01, decay=1e-6, momentum=0.9, nesterov=True)
sm.compile(loss='mean_squared_error', optimizer=sgd)
for layer in keras_model.layers:
    print(layer.get_config())

# Compare predictions before and after one training step.
keras_output = sm.predict(x_data)
print(keras_output)
sm.fit(x_data, np.random.random((1, 2, 3, 3)))
keras_output = sm.predict(x_data)
print(keras_output)

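A note on the momentum values themselves: PyTorch and Keras define batch-norm momentum in opposite directions. PyTorch's momentum weights the new batch statistic, while Keras's momentum weights the existing moving average. The sketch below writes out both update rules in plain Python. The function names are made up for illustration and are not part of either library or of pytorch2keras. It does not try to explain the NaNs reported above.

```python
def pytorch_running_update(running, batch_stat, momentum=0.1):
    """PyTorch-style rule: momentum is the weight of the NEW batch statistic."""
    return (1.0 - momentum) * running + momentum * batch_stat


def keras_moving_update(moving, batch_stat, momentum=0.99):
    """Keras-style rule: momentum is the weight of the OLD moving average."""
    return momentum * moving + (1.0 - momentum) * batch_stat


def to_keras_momentum(pytorch_momentum):
    """Hypothetical helper: map a PyTorch momentum onto the Keras convention."""
    return 1.0 - pytorch_momentum


if __name__ == "__main__":
    running, batch = 2.0, 4.0
    # Both lines apply the same update once the convention is translated.
    print(pytorch_running_update(running, batch, momentum=0.1))
    print(keras_moving_update(running, batch, momentum=to_keras_momentum(0.1)))
    # With a Keras momentum of 1, the moving statistic never changes.
    print(keras_moving_update(running, batch, momentum=1.0))
```

Under these two rules, PyTorch's default of 0.1 corresponds to a Keras momentum of 0.9, so that value may be a closer match to the original model than 0.1 when applying the workaround.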
gmalivenko commented 5 years ago

Hello @jezza770. I reproduced the error and traced it to the ONNX export step. I opened a related PyTorch issue. For now, the only workaround is to set the proper momentum manually, as you do. I'll wait for the PyTorch-ONNX issue to be resolved.
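The manual workaround can be wrapped in a small helper. This is only a sketch: `set_batchnorm_momentum` is a hypothetical name, not part of pytorch2keras or Keras, and it assumes a model object exposing a `layers` list whose batch-norm layers are instances of a class named `BatchNormalization`. In the report above, the layers were also re-added to a new `keras.Sequential` after the change; whether setting the attribute alone is enough may depend on the Keras version.

```python
def set_batchnorm_momentum(model, momentum):
    """Set `momentum` on every BatchNormalization layer of `model`.

    Layers are matched by class name so the helper does not need to import
    Keras itself. Returns the list of layers that were changed.
    """
    patched = []
    for layer in model.layers:
        if type(layer).__name__ == "BatchNormalization":
            layer.momentum = momentum
            patched.append(layer)
    return patched


# Usage with a converted model (assumed to come from pytorch_to_keras):
#   changed = set_batchnorm_momentum(keras_model, 0.1)  # value used in the report
#   print(len(changed), "BatchNormalization layers updated")
```

Calling it before compiling mirrors the order used in the original example.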

gmalivenko commented 5 years ago

Hello @jezza770, I'm back with updates. It was an ONNX issue, now fixed in the master branch of PyTorch. I'll wait for a new PyTorch release so that it works out of the box.