
batch normalization layer for data_format == 'channels_last' #1103

Open edwardzcl opened 4 years ago

edwardzcl commented 4 years ago

Issue Description

According to the batch normalization implementation in TL, which can be found at "https://github.com/tensorlayer/tensorlayer/blob/v2.2.0/tensorlayer/layers/normalization.py", the mean and variance are computed over the whole input rather than per channel when the layer is initialized with data_format='channels_last'.

You can refer to line 220 for self.channel_axis = -1 if data_format == 'channels_last' else 1 and line 282 for self.axes = [i for i in range(len(inputs.shape)) if i != self.channel_axis]. Because the loop index i is always non-negative, it never equals -1, so self.axes ends up containing every axis, including the channel axis, as the sketch below shows.
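
A minimal sketch in plain Python (the 4-D NHWC rank is only an illustrative assumption) makes the interaction of those two lines concrete:

input_rank = 4  # e.g. an NHWC input such as [None, 50, 50, 32]

# Current code (line 220): channel_axis = -1 for 'channels_last'.
channel_axis = -1
axes = [i for i in range(input_rank) if i != channel_axis]
print(axes)  # [0, 1, 2, 3] -> the channel axis is NOT excluded

# Proposed fix: use the positive index of the last axis.
channel_axis = input_rank - 1
axes = [i for i in range(input_rank) if i != channel_axis]
print(axes)  # [0, 1, 2] -> per-channel statistics, as intended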

Reproducible Code

import tensorflow as tf
from tensorflow.python.training import moving_averages

import tensorlayer as tl
from tensorlayer import logging
from tensorlayer.layers.core import Layer

# Note: `batch_normalization` called in forward() below is the module-level
# helper defined earlier in tensorlayer/layers/normalization.py.

class BatchNorm(Layer):
    """
    The :class:`BatchNorm` is a batch normalization layer for both fully-connected and convolution outputs.
    See ``tf.nn.batch_normalization`` and ``tf.nn.moments``.
    Parameters
    ----------
    decay : float
        A decay factor for `ExponentialMovingAverage`.
        A large value is suggested for large datasets.
    epsilon : float
        Epsilon, a small float added to the variance to avoid division by zero.
    act : activation function
        The activation function of this layer.
    is_train : boolean
        Is being used for training or inference.
    beta_init : initializer or None
        The initializer for initializing beta, if None, skip beta.
        Usually you should not skip beta unless you know what you are doing.
    gamma_init : initializer or None
        The initializer for initializing gamma, if None, skip gamma.
        When the batch normalization layer is used instead of 'biases', or the next layer is linear, this can be
        disabled since the scaling can be done by the next layer. See `Inception-ResNet-v2 <https://github.com/tensorflow/models/blob/master/research/slim/nets/inception_resnet_v2.py>`__.
    moving_mean_init : initializer or None
        The initializer for initializing moving mean, if None, skip moving mean.
    moving_var_init : initializer or None
        The initializer for initializing moving var, if None, skip moving var.
    num_features : int
        Number of features for input tensor. Useful to build layer if using BatchNorm1d, BatchNorm2d or BatchNorm3d,
        but should be left as None if using BatchNorm. Default None.
    data_format : str
        'channels_last' (default) or 'channels_first'.
    name : None or str
        A unique layer name.
    Examples
    ---------
    With TensorLayer
    >>> net = tl.layers.Input([None, 50, 50, 32], name='input')
    >>> net = tl.layers.BatchNorm()(net)
    Notes
    -----
    The :class:`BatchNorm` is universally suitable for 3D/4D/5D inputs in static models, but should not be used
    in dynamic models, where the layer is built upon class initialization. Therefore, the argument 'num_features'
    should only be used with the subclasses :class:`BatchNorm1d`, :class:`BatchNorm2d` and :class:`BatchNorm3d`.
    All three subclasses are suitable under all conditions.
    References
    ----------
    - `Source <https://github.com/ry/tensorflow-resnet/blob/master/resnet.py>`__
    - `stackoverflow <http://stackoverflow.com/questions/38312668/how-does-one-do-inference-with-batch-normalization-with-tensor-flow>`__
    """

    def __init__(
            self,
            decay=0.9,
            epsilon=0.00001,
            act=None,
            is_train=False,
            beta_init=tl.initializers.zeros(),
            gamma_init=tl.initializers.random_normal(mean=1.0, stddev=0.002),
            moving_mean_init=tl.initializers.zeros(),
            moving_var_init=tl.initializers.zeros(),
            num_features=None,
            data_format='channels_last',
            name=None,
    ):
        super(BatchNorm, self).__init__(name=name, act=act)
        self.decay = decay
        self.epsilon = epsilon
        self.data_format = data_format
        self.beta_init = beta_init
        self.gamma_init = gamma_init
        self.moving_mean_init = moving_mean_init
        self.moving_var_init = moving_var_init
        self.num_features = num_features

        # removed: with channel_axis = -1, the axes computed in forward() never
        # exclude the channel axis, since range() only yields non-negative indices
        #self.channel_axis = -1 if data_format == 'channels_last' else 1
        ## add ##
        self.data_format = data_format  # already assigned above; kept to mirror the proposed patch
        self.axes = None

        if num_features is not None:
            self.build(None)
            self._built = True

        if self.decay < 0.0 or 1.0 < self.decay:
            raise ValueError("decay should be between 0 and 1")

        logging.info(
            "BatchNorm %s: decay: %f epsilon: %f act: %s is_train: %s" %
            (self.name, decay, epsilon, self.act.__name__ if self.act is not None else 'No Activation', is_train)
        )

    ## skip ##  (other methods of BatchNorm omitted)

    def forward(self, inputs):
        self._check_input_shape(inputs)
        ## add ##
        # compute the channel axis as a non-negative index, so the list
        # comprehension below can actually exclude it
        self.channel_axis = len(inputs.shape) - 1 if self.data_format == 'channels_last' else 1

        if self.axes is None:
            self.axes = [i for i in range(len(inputs.shape)) if i != self.channel_axis]

        mean, var = tf.nn.moments(inputs, self.axes, keepdims=False)
        if self.is_train:
            # update moving_mean and moving_var
            self.moving_mean = moving_averages.assign_moving_average(
                self.moving_mean, mean, self.decay, zero_debias=False
            )
            self.moving_var = moving_averages.assign_moving_average(self.moving_var, var, self.decay, zero_debias=False)
            outputs = batch_normalization(inputs, mean, var, self.beta, self.gamma, self.epsilon, self.data_format)
        else:
            outputs = batch_normalization(
                inputs, self.moving_mean, self.moving_var, self.beta, self.gamma, self.epsilon, self.data_format
            )
        if self.act:
            outputs = self.act(outputs)
        return outputs

In short: delete the line-220 code, make sure self.data_format = data_format is set in __init__, and then compute self.channel_axis = len(inputs.shape) - 1 if self.data_format == 'channels_last' else 1 in forward(), where the input rank is known.
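
As a quick sanity check (a sketch assuming TensorFlow 2.x; the shape [8, 50, 50, 32] is just an example), tf.nn.moments shows how the choice of axes changes the resulting statistics:

import tensorflow as tf

x = tf.random.normal([8, 50, 50, 32])  # example NHWC batch

# Buggy axes (channel axis included): one scalar mean/var for the whole tensor.
mean_all, var_all = tf.nn.moments(x, axes=[0, 1, 2, 3])
print(mean_all.shape)  # ()

# Correct axes for data_format='channels_last': one mean/var per channel.
mean_c, var_c = tf.nn.moments(x, axes=[0, 1, 2])
print(mean_c.shape)  # (32,)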

Laicheng0830 commented 4 years ago

Thanks! There is indeed a problem with this code; we will fix it immediately.