titu1994 / keras-squeeze-excite-network

Implementation of Squeeze and Excitation Networks in Keras
MIT License

Squeeze and Excitation Networks in Keras

Implementation of Squeeze and Excitation Networks in Keras 2.0.3+.

[Figure: squeeze-excite block]

Models

Currently supported models:

Additional models (not from the paper; whether they improve performance has not been verified):

Squeeze and Excitation block

The block is simple to implement in Keras. It consists of a GlobalAveragePooling2D layer, two Dense layers, and an element-wise multiplication that rescales the input channels. Keras infers the shapes automatically. The block can be imported from se.py.

from tensorflow.keras.layers import GlobalAveragePooling2D, Reshape, Dense, Permute, multiply
import tensorflow.keras.backend as K

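def squeeze_excite_block(tensor, ratio=16):
    init = tensor
    channel_axis = 1 if K.image_data_format() == "channels_first" else -1
    # Static channel count; tf.keras tensors expose .shape rather than _keras_shape
    filters = init.shape[channel_axis]
    se_shape = (1, 1, filters)

    # Squeeze: one average value per channel, reshaped so it can broadcast
    se = GlobalAveragePooling2D()(init)
    se = Reshape(se_shape)(se)
    # Excite: bottleneck of filters // ratio units, then one sigmoid gate per channel
    se = Dense(filters // ratio, activation='relu', kernel_initializer='he_normal', use_bias=False)(se)
    se = Dense(filters, activation='sigmoid', kernel_initializer='he_normal', use_bias=False)(se)

    # Move the channel axis to the front so the gates line up with channels_first inputs
    if K.image_data_format() == 'channels_first':
        se = Permute((3, 1, 2))(se)

    # Scale: multiply every feature map by its gate
    x = multiply([init, se])
    return x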
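
To make the arithmetic inside the block easier to follow, here is a minimal NumPy sketch of the same three steps (squeeze, excite, scale) for channels_last data. It is an illustration, not part of se.py: the function name `squeeze_excite_numpy` and the random weights `w1` and `w2` are assumptions standing in for the trained Dense kernels.

```python
import numpy as np

def squeeze_excite_numpy(x, w1, w2):
    """Hypothetical NumPy counterpart of the Keras block (channels_last only).

    x  : (batch, height, width, channels) feature maps
    w1 : (channels, channels // ratio) kernel of the first Dense layer
    w2 : (channels // ratio, channels) kernel of the second Dense layer
    """
    # Squeeze: global average pooling over the spatial axes -> (batch, channels)
    squeezed = x.mean(axis=(1, 2))
    # Excite: bottleneck Dense + ReLU, then Dense + sigmoid (no biases)
    hidden = np.maximum(squeezed @ w1, 0.0)
    gates = 1.0 / (1.0 + np.exp(-(hidden @ w2)))
    # Scale: broadcast the per-channel gates over height and width
    return x * gates[:, None, None, :], gates

# Random stand-ins for a feature map and the two Dense kernels
rng = np.random.default_rng(0)
batch, height, width, channels, ratio = 2, 4, 4, 32, 16
x = rng.normal(size=(batch, height, width, channels))
w1 = rng.normal(size=(channels, channels // ratio))
w2 = rng.normal(size=(channels // ratio, channels))

y, gates = squeeze_excite_numpy(x, w1, w2)
print(y.shape)      # same shape as the input
print(gates.shape)  # one gate per sample and channel
```

The output keeps the input shape, so the block can be placed after any convolutional stage without changing the layers that follow.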

Addition of Squeeze and Excitation blocks to Inception and ResNet blocks

[Figures: SE architectures; SE-ResNet architecture]
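
As a rough sketch of where the block sits in a residual unit, the NumPy example below gates the residual branch and only then adds the identity shortcut, which is the placement described for SE-ResNet in the Squeeze-and-Excitation paper. The names `se_gate` and `se_residual_block` are hypothetical, `residual_fn` is a stand-in for the convolution stack, and any activation after the addition is left out. For an Inception module, the same gate would be applied to the module's output instead.

```python
import numpy as np

def se_gate(u, w1, w2):
    # Per-channel gates in (0, 1): pool -> Dense + ReLU -> Dense + sigmoid
    pooled = u.mean(axis=(1, 2))
    hidden = np.maximum(pooled @ w1, 0.0)
    return 1.0 / (1.0 + np.exp(-(hidden @ w2)))

def se_residual_block(x, residual_fn, w1, w2):
    # Residual branch; residual_fn is a placeholder for the conv layers
    u = residual_fn(x)
    # Recalibrate the residual branch only, not the shortcut
    u = u * se_gate(u, w1, w2)[:, None, None, :]
    # Identity shortcut is added after the SE scaling
    return x + u

rng = np.random.default_rng(0)
x = rng.normal(size=(2, 8, 8, 16))   # channels_last feature maps
w1 = rng.normal(size=(16, 4))        # bottleneck with ratio 4 (assumed)
w2 = rng.normal(size=(4, 16))

out = se_residual_block(x, lambda t: 0.5 * t, w1, w2)
print(out.shape)
```

Because the shortcut bypasses the gate, a residual branch that outputs zeros leaves the input unchanged.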