sandeep-krishnamurthy opened this issue 6 years ago
Is this issue still open, or was it just forgotten to be closed?
For users there is no impact; this issue just tracks how the BatchNorm operator is implemented in the MXNet backend.
Since we use the MXNet BatchNorm operator directly, and MXNet BatchNorm relies on the use_global_stats flag to freeze the moving mean and variance, you need to specify trainable=False during layer construction to freeze the BatchNorm operator.
For example:
import tempfile

import numpy as np
from keras.layers import Input, Dense, BatchNormalization, Activation
from keras.models import Model

# Toy data for demonstration
data = np.random.random((100, 3))
label = np.random.randint(2, size=(100, 1))

# Build and train the original model
x = Input(shape=(3,), name='x')
f = Dense(10, name='h1')(x)
f = BatchNormalization(name='bn1')(f)
f = Activation('relu', name='r1')(f)
y = Dense(1, name='y')(f)
model = Model(inputs=[x], outputs=[y])
model.compile(loss='binary_crossentropy', optimizer='sgd')
model.fit(data, label, batch_size=5, epochs=10, verbose=1)

# Save weights to a temporary file
_, fname = tempfile.mkstemp('.h5')
model.save_weights(fname)

# Reconstruct the model with trainable=False on every layer except the last,
# then reload the weights and fine-tune only the final Dense layer
x = Input(shape=(3,), name='x')
f = Dense(10, name='h1', trainable=False)(x)
f = BatchNormalization(name='bn1', trainable=False)(f)
f = Activation('relu', name='r1', trainable=False)(f)
y = Dense(1, name='y')(f)
loaded = Model(inputs=[x], outputs=[y])
loaded.load_weights(fname)
loaded.compile(loss='binary_crossentropy', optimizer='sgd')
loaded.fit(data, label, batch_size=5, epochs=10, verbose=1)
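If you want to confirm that the BatchNorm statistics really stay fixed, one way is to snapshot the layer weights before and after an extra round of training and compare them. This is just a sketch, assuming the standard Keras weight ordering for a BatchNormalization layer (gamma, beta, moving mean, moving variance):

import numpy as np

bn = loaded.get_layer('bn1')
before = [w.copy() for w in bn.get_weights()]
loaded.fit(data, label, batch_size=5, epochs=1, verbose=0)
after = bn.get_weights()

# With trainable=False, every BatchNorm weight (including the moving
# mean and variance) should be unchanged after training.
for b, a in zip(before, after):
    assert np.allclose(b, a)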
For a full example, refer to this test.
Reference PR: https://github.com/awslabs/keras-apache-mxnet/pull/252
Reference on MXNet BatchNorm freezing: https://discuss.mxnet.io/t/how-freeze-batchnorm-layer-in-symbolblock/3949
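At the MXNet level, the same freezing behavior comes from the use_global_stats argument of the BatchNorm operator: when it is set, the layer normalizes with the stored moving mean and variance instead of the current batch statistics, and the moving averages are not updated. Below is a minimal Gluon sketch illustrating that flag; it is an illustration of the underlying MXNet API, not part of the keras-apache-mxnet code:

import mxnet as mx
from mxnet import nd
from mxnet.gluon import nn

# A BatchNorm layer that always uses the stored moving mean/variance,
# i.e. its statistics stay frozen even in training mode.
frozen_bn = nn.BatchNorm(use_global_stats=True)
frozen_bn.initialize()

x = nd.random.uniform(shape=(5, 10))
with mx.autograd.record():
    y = frozen_bn(x)  # normalized with the moving stats, which do not change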
The MXNet backend uses the MXNet BatchNorm operator directly without going through the Keras batch normalization layer. Reason: MXNet does not support