
Feature request: let tf.layers.batch_normalization normalize over multiple axes #7091

Closed carlthome closed 6 years ago

carlthome commented 7 years ago

When using tf.nn.batch_normalization and tf.nn.moments it's possible to choose which axes are aggregated and thus which dimensions are normalized. For tf.contrib.layers.batch_norm and tf.contrib.layers.layer_norm the assumption is that only one dimension is normalized. There are use cases where it would be nice to independently normalize across both width and filters (for example audio spectrograms, where "width" could be seen as the time axis). Could the layers API be changed to be as flexible as tf.nn.batch_normalization?
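
For reference, here's a minimal sketch of what the tf.nn level already allows (the spectrogram shape and the 1e-3 epsilon are just illustrative):

```python
import tensorflow as tf

# Audio spectrograms: [batch, time, frequency, filters].
x = tf.placeholder(tf.float32, [None, 100, 128, 32])

# Aggregate over batch, time and frequency, so each filter gets its own
# statistics and normalization spans both "width" (time) and frequency.
mean, variance = tf.nn.moments(x, axes=[0, 1, 2], keep_dims=True)
y = tf.nn.batch_normalization(x, mean, variance,
                              offset=None, scale=None,
                              variance_epsilon=1e-3)
```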

yaroslavvb commented 7 years ago

It seems useful. E.g., there's Layer Normalization, which is essentially a transpose of batch normalization and got good performance, and also Weight Normalization.

I wonder if it should be called something else though, since "batch" in batch_norm refers to the batch dimension (0th dimension).

I don't see a way to specify an axis in tf.nn.batch_normalization:

tf.nn.batch_normalization(x, mean, variance, offset, scale, variance_epsilon, name=None)

carlthome commented 7 years ago

According to its docstring, tf.nn.batch_normalization infers the axes from the input tensors (I assume x, mean, variance, etc. are expected to match in shape), so it doesn't really need an axis parameter, I guess?

By the way, my feature request applies to the layer norm layer too. It's also useful to be able to apply that across several dimensions instead of just the last, particularly for RNNs where batch norm is... tricky.

yaroslavvb commented 7 years ago

Oh, good point, it looks like it can normalize over arbitrary axes, taking the specification from the missing dimensions in the mean/variance tensors.
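
A quick way to see that inference at work (shapes illustrative): the op broadcasts mean/variance against x, so the size-1 dimensions in the statistics mark the axes that were aggregated.

```python
import tensorflow as tf

x = tf.placeholder(tf.float32, [None, 20, 30, 16])

# Per-channel statistics (batch norm style), aggregated over axes 0-2.
m1, v1 = tf.nn.moments(x, axes=[0, 1, 2], keep_dims=True)  # [1, 1, 1, 16]
y1 = tf.nn.batch_normalization(x, m1, v1, None, None, 1e-3)

# Per-example statistics (layer norm style), aggregated over axes 1-3.
m2, v2 = tf.nn.moments(x, axes=[1, 2, 3], keep_dims=True)  # [?, 1, 1, 1]
y2 = tf.nn.batch_normalization(x, m2, v2, None, None, 1e-3)
```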

drpngx commented 7 years ago

What is the problem with using tf.nn.batch_normalization?

carlthome commented 7 years ago

tf.nn.batch_normalization is more low-level. Unlike the layers version, it doesn't handle estimating population statistics as decaying moving averages during training; it doesn't distinguish between train-time and test-time statistics on its own; and it requires you to manually introduce trainable variables for beta and gamma, and so on. It's rather tricky to set up, in fact, as it's basically just the mathematical operation and nothing more.
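
To make that concrete, here's a rough sketch of the bookkeeping the raw op leaves to you (the variable names, decay value and NHWC reduction axes are made up for illustration):

```python
import tensorflow as tf

def manual_batch_norm(x, is_training, decay=0.99, epsilon=1e-3):
    """Everything tf.contrib.layers.batch_norm normally hides."""
    channels = x.get_shape().as_list()[-1]
    # Trainable offset/scale you have to create yourself.
    beta = tf.get_variable('beta', [channels],
                           initializer=tf.zeros_initializer())
    gamma = tf.get_variable('gamma', [channels],
                            initializer=tf.ones_initializer())
    # Non-trainable population statistics for test time.
    moving_mean = tf.get_variable('moving_mean', [channels],
                                  initializer=tf.zeros_initializer(),
                                  trainable=False)
    moving_var = tf.get_variable('moving_var', [channels],
                                 initializer=tf.ones_initializer(),
                                 trainable=False)

    def train_fn():
        # Batch statistics, plus manual moving-average updates.
        mean, var = tf.nn.moments(x, axes=[0, 1, 2])
        update_mean = tf.assign(moving_mean,
                                decay * moving_mean + (1 - decay) * mean)
        update_var = tf.assign(moving_var,
                               decay * moving_var + (1 - decay) * var)
        with tf.control_dependencies([update_mean, update_var]):
            return tf.nn.batch_normalization(x, mean, var,
                                             beta, gamma, epsilon)

    def test_fn():
        # Frozen population statistics at test time.
        return tf.nn.batch_normalization(x, moving_mean, moving_var,
                                         beta, gamma, epsilon)

    return tf.cond(is_training, train_fn, test_fn)
```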

drpngx commented 7 years ago

Could you outline what the API change would look like?

carlthome commented 7 years ago

API change: tf.contrib.layers.batch_norm(..., axis=[0, 1, 2]), just like tf.nn.moments.

For backwards compatibility (where all but the last dimension should be reduced), axis=None is necessary as the default argument because the shape of the input tensor is not known when the layer is constructed. The documented behavior would be that it aggregates every dimension except the last, regardless of whether the layer is preceded by a fully-connected operation, a convolution, or anything else.
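
A hypothetical sketch of how such a default could be resolved once the input's rank is known at call time (the helper name is made up, and axis follows the reduction-axes convention of tf.nn.moments):

```python
def _resolve_axes(axis, rank):
    """Map the user-facing axis argument to reduction axes."""
    if axis is None:
        # Backwards-compatible default: aggregate everything except
        # the last dimension, whatever op preceded the layer.
        return list(range(rank - 1))
    if isinstance(axis, int):
        return [axis]
    return list(axis)

# _resolve_axes(None, 4)      -> [0, 1, 2]
# _resolve_axes([0, 1, 2], 5) -> [0, 1, 2]
```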

martinwicke commented 7 years ago

@fchollet what do you think about this?

fchollet commented 7 years ago

> For backwards compatibility (where all but the last dimension should be reduced), axis=None is necessary as the default argument because the shape of the input tensor is not known when the layer is constructed.

I don't understand this part? What does this mean?

The default value for axis should be -1, which is the feature axis in essentially every situation except image features in data_format NCHW.

I had not heard before that one would want to normalize on more than the features axis. But we can consider the API change if it seems worth it. By which I mean we can consider allowing a tuple of axes as the axis argument: the axes to normalize over, e.g. [-2, -1] or [1, 2].

carlthome commented 7 years ago

> What does this mean?

Say tensors are [batch_size, frames, height, width, channels] (video). Then axis=None could assume the user wants means/variances as [batch_size, frames * height * width * channels] and not [batch_size, frames, height, width] (axis=-1).

> I had not heard before that one would want to normalize on more than the features axis. But we can consider the API change if it seems worth it. By which I mean we can consider allowing a tuple of axes as the axis argument: the axes to normalize over, e.g. [-2, -1] or [1, 2].

A common use case could be RNNs for audio and video, where you'd like to normalize whole sequences.
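
For example, with [batch, time, features] audio it might look like this (hypothetical, since tf.contrib.layers.batch_norm has no axis argument today; axis again lists the axes to aggregate):

```python
# Hypothetical usage of the proposed API: aggregate over batch and time
# so each feature is normalized across whole sequences.
y = tf.contrib.layers.batch_norm(x, axis=[0, 1])
```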

martinwicke commented 7 years ago

This would be a good feature to add to the existing batch_norm (in contrib).

dksb commented 6 years ago

We are going to close this issue. Contributions are welcome! If you want to contribute, feel free to reopen this issue and link your PR to it.