awslabs / keras-apache-mxnet

[DEPRECATED] Amazon Deep Learning's Keras with Apache MXNet support
https://github.com/awslabs/keras-apache-mxnet/wiki

MXNet backend uses native batch_norm operator #23

Open sandeep-krishnamurthy opened 6 years ago

sandeep-krishnamurthy commented 6 years ago

The MXNet backend uses the MXNet BatchNorm operator directly, without going through the Keras batch normalization layer. Reason: MXNet does not support
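For context, a minimal sketch of what "using the native operator directly" could look like in a backend function; the helper name and argument handling here are assumptions rather than the actual keras-apache-mxnet code, but mx.sym.BatchNorm is the native operator being referred to:

import mxnet as mx

def batchnorm_sketch(data, gamma, beta, moving_mean, moving_var,
                     axis=1, momentum=0.9, epsilon=1e-3, training=True):
    # Hypothetical helper: hand the whole normalization to the native
    # MXNet operator instead of composing it from Keras backend ops.
    # use_global_stats=True tells MXNet to use the stored moving statistics,
    # i.e. the layer behaves like a frozen / inference-mode BatchNorm.
    return mx.sym.BatchNorm(data=data, gamma=gamma, beta=beta,
                            moving_mean=moving_mean, moving_var=moving_var,
                            axis=axis, momentum=momentum, eps=epsilon,
                            use_global_stats=not training, fix_gamma=False)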

Daniel-M commented 5 years ago

Is this issue still open, or was it just forgotten and never closed?

sandeep-krishnamurthy commented 5 years ago

There is no impact for users; this issue just tracks how the BatchNorm operator is implemented in the MXNet backend.

roywei commented 4 years ago

Since we use the MXNet operator directly, and MXNet BatchNorm relies on the use_global_stats flag to freeze the moving mean and variance, you need to specify trainable=False during layer construction to freeze the BatchNorm operator.

For example:

import tempfile

import numpy as np
from keras.layers import Input, Dense, BatchNormalization, Activation
from keras.models import Model

# Dummy data for illustration
data = np.random.random((100, 3))
label = np.random.randint(0, 2, size=(100, 1))

# Build and train a model with a trainable BatchNorm layer
x = Input(shape=(3,), name='x')
f = Dense(10, name='h1')(x)
f = BatchNormalization(name='bn1')(f)
f = Activation('relu', name='r1')(f)
y = Dense(1, name='y')(f)

model = Model(inputs=[x], outputs=[y])
model.compile(loss='binary_crossentropy', optimizer='sgd')
model.fit(data, label, batch_size=5, epochs=10, verbose=1)

# Save the trained weights to a temporary file
_, fname = tempfile.mkstemp('.h5')
model.save_weights(fname)

# Reconstruct the model with trainable=False on every layer except the last;
# on BatchNormalization this freezes the moving mean and variance as well
x = Input(shape=(3,), name='x')
f = Dense(10, name='h1', trainable=False)(x)
f = BatchNormalization(name='bn1', trainable=False)(f)
f = Activation('relu', name='r1', trainable=False)(f)
y = Dense(1, name='y')(f)

loaded = Model(inputs=[x], outputs=[y])
loaded.load_weights(fname)
loaded.compile(loss='binary_crossentropy', optimizer='sgd')

loaded.fit(data, label, batch_size=5, epochs=10, verbose=1)
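As a quick sanity check (a sketch building on the example above, not part of the original test), you can confirm that the frozen bn1 layer's weights and moving statistics do not change during further training:

import numpy as np

# Capture the frozen BatchNorm layer's weights before and after training
before = [w.copy() for w in loaded.get_layer('bn1').get_weights()]
loaded.fit(data, label, batch_size=5, epochs=2, verbose=0)
after = loaded.get_layer('bn1').get_weights()

# gamma, beta, moving_mean and moving_var should all be unchanged
assert all(np.allclose(b, a) for b, a in zip(before, after))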

For a full example, refer to this test.

Reference PR: https://github.com/awslabs/keras-apache-mxnet/pull/252
Reference on MXNet BatchNorm freezing: https://discuss.mxnet.io/t/how-freeze-batchnorm-layer-in-symbolblock/3949