broadinstitute / keras-resnet

Keras package for deep residual networks

Add option to freeze BatchNorm layers. #31

Closed: hgaiser closed this 6 years ago

hgaiser commented 6 years ago

This PR adds the ability to freeze BatchNorm layers, which is often done when fine-tuning models (such as Faster R-CNN). The freezing is implemented in a custom BatchNorm layer that calls Keras' BatchNormalization while setting both the training argument and the trainable flag. The reason for a custom layer is that the alternative would look something like this:

    import keras

    layer = keras.layers.BatchNormalization()
    layer.trainable = not freeze_bn                  # controls whether gamma/beta receive gradient updates
    output = layer(input, training=(not freeze_bn))  # training=False uses (and never updates) the moving statistics

The custom layer changes this to:

    output = keras_resnet.layers.BatchNormalization(freeze=freeze_bn)(input)

which saves lines and reduces complexity.
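
For reference, a minimal sketch of what such a layer could look like under Keras 2.x. The freeze keyword comes from this PR; the body is illustrative and may differ from the actual keras_resnet.layers.BatchNormalization:

    import keras


    class BatchNormalization(keras.layers.BatchNormalization):
        """Like keras.layers.BatchNormalization, but can be frozen (sketch only)."""

        def __init__(self, freeze=False, *args, **kwargs):
            super(BatchNormalization, self).__init__(*args, **kwargs)
            self.freeze = freeze
            # A frozen layer keeps its gamma/beta weights out of training.
            self.trainable = not self.freeze

        def call(self, inputs, training=None):
            if self.freeze:
                # Force inference mode so the moving mean/variance are used
                # and never updated, regardless of the learning phase.
                return super(BatchNormalization, self).call(inputs, training=False)
            return super(BatchNormalization, self).call(inputs, training=training)

        def get_config(self):
            # Record `freeze` so that saved models can be restored.
            config = super(BatchNormalization, self).get_config()
            config.update({'freeze': self.freeze})
            return config

A backbone built for fine-tuning would then simply pass freeze=True to every BatchNorm layer it creates, instead of juggling the two flags at each call site.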

0x00b1 commented 6 years ago

Awesome. Thanks, @hgaiser.