It seems useful; e.g., there's Layer Normalization, which got good performance and is essentially a transpose of batch normalization, and there's Weight Normalization.
I wonder if it should be called something else though, since `batch` in `batch_norm` refers to the batch dimension (the 0th dimension).
I don't see a way to specify an axis in `tf.nn.batch_normalization`:
`tf.nn.batch_normalization(x, mean, variance, offset, scale, variance_epsilon, name=None)`
According to its docstring, `tf.nn.batch_normalization` infers the axes from the input tensors (I assume `x`, `mean`, `variance` etc., and that they're expected to match), so it didn't really need an `axis` parameter, I guess?
By the way, my feature request applies to the layer norm layer too. It too would be useful to apply to several dimensions instead of just the last, particularly for RNNs where batch norm is... tricky.
Oh, good point, it looks like it can normalize over arbitrary axes, taking the specification from which dimensions are missing (or of size 1) in the `mean`/`variance` tensors.
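For what it's worth, a minimal sketch (assuming the TF 1.x API) of how the normalized axes follow from the shapes of the statistics tensors via broadcasting rather than from an explicit argument:

```python
import tensorflow as tf

# Sketch: the normalized axes are implied by the stats' shapes.
x = tf.random_normal([32, 10, 64])  # [batch, time, features]

# Reducing over batch and time leaves stats of shape [1, 1, 64], so
# broadcasting normalizes each feature over the batch and time axes.
mean, variance = tf.nn.moments(x, axes=[0, 1], keep_dims=True)
y = tf.nn.batch_normalization(x, mean, variance,
                              offset=None, scale=None,
                              variance_epsilon=1e-3)
```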
What is the problem with using `tf.nn.batch_normalization`?
`tf.nn.batch_normalization` is more low-level. Unlike the `layers` version, it doesn't handle estimation of population statistics as decaying moving averages during training; it doesn't distinguish between test-time and train-time statistics on its own; it requires you to manually introduce trainable variables for beta and gamma, and so on. It's rather tricky to set up, in fact, as it's basically just the mathematical operation and nothing more.
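To make that concrete, a rough sketch (TF 1.x, training path only, assuming 4-D NHWC input; the moving-average and train/test bookkeeping is only noted in comments) of the boilerplate the low-level op leaves to the caller:

```python
import tensorflow as tf

def manual_batch_norm(x, epsilon=1e-3):
    """Rough sketch of what tf.contrib.layers.batch_norm automates."""
    num_features = x.get_shape().as_list()[-1]
    # The trainable offset/scale (beta/gamma) must be created by hand.
    beta = tf.get_variable('beta', [num_features],
                           initializer=tf.zeros_initializer())
    gamma = tf.get_variable('gamma', [num_features],
                            initializer=tf.ones_initializer())
    # Batch statistics over all axes but the last; at test time these
    # should be replaced by moving averages the caller must maintain
    # (e.g. with tf.train.ExponentialMovingAverage), plus a train/test
    # switch such as tf.cond -- none of which the op provides.
    mean, variance = tf.nn.moments(x, axes=[0, 1, 2])
    return tf.nn.batch_normalization(x, mean, variance,
                                     beta, gamma, epsilon)
```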
Could you outline what the API change would look like?
API change: `tf.contrib.layers.batch_norm(..., axis=[0, 1, 2])`, just as in `tf.nn.moments`.
For backwards compatibility (where all but the last dimension should be reduced), `axis=None` is necessary as the default argument because the shape of the input tensor is not known in advance, with the documented behavior that it aggregates every dimension except the last, regardless of whether it's preceded by a fully connected operation, a convolution, or anything else.
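A hypothetical sketch of how that default could resolve (the `axis` parameter and the `_resolve_axes` helper are illustrative, not existing contrib API):

```python
# Hypothetical sketch: `axis` is the proposed parameter and
# `_resolve_axes` an illustrative helper, not existing contrib API.
def _resolve_axes(x, axis=None):
    ndims = x.get_shape().ndims
    if axis is None:
        # Backwards-compatible default: reduce every dimension except
        # the last, whatever the rank of the input turns out to be.
        return list(range(ndims - 1))
    return list(axis)
```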
@fchollet what do you think about this?
> For backwards compatibility (where all but the last dimension should be reduced) axis=None is necessary as the default argument because the shape of the input tensor is not known

I don't understand this part? What does this mean?
The default value for `axis` should be `-1`, which is the feature axis in essentially every situation except image features in `data_format` NCHW.
I had not heard before that one would want to normalize over more than the features axis. But we can consider the API change if it seems worth it. By which I mean we can consider allowing a tuple of axes as the `axis` argument: the axes to normalize over, e.g. `[-2, -1]` or `[1, 2]`.
> What does this mean?
Say tensors are `[batch_size, frames, height, width, channels]` (video). Then `axis=None` could assume the user wants means/variances as `[batch_size, frames * height * width * channels]` and not `[batch_size, frames, height, width]` (`axis=-1`).
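For concreteness, a small sketch (illustrative sizes) of how the statistics' shapes change with the reduced axes:

```python
import tensorflow as tf

# [batch_size, frames, height, width, channels] -- illustrative sizes.
video = tf.random_normal([8, 16, 64, 64, 3])

# Reduce everything but the channel axis: stats have shape [3],
# i.e. one mean/variance per channel.
m1, v1 = tf.nn.moments(video, axes=[0, 1, 2, 3])

# Reduce everything but the batch axis: stats have shape [8],
# i.e. one mean/variance per example over its whole volume.
m2, v2 = tf.nn.moments(video, axes=[1, 2, 3, 4])
```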
> I had not heard before that one would want to normalize on more than the features axis. But we can consider the API change if it seems worth it. By which I mean we can consider allowing a tuple of axes as the axis argument: the axes to normalize over, e.g. [-2, -1] or [1, 2].
A common use case could be RNNs for audio and video, where you'd like to normalize whole sequences.
This would be a good feature to add to the existing `batch_norm` (in contrib).
We are going to close this issue. Contributions are welcome! If you want to contribute, feel free to reopen this issue and link your PR to this issue.
When using `tf.nn.batch_normalization` and `tf.nn.moments` it's possible to choose which axes are aggregated and thus which dimensions are normalized. For `tf.contrib.layers.batch_norm` and `tf.contrib.layers.layer_norm` the assumption is that only one dimension is normalized. There are use cases where it would be nice to independently normalize across both width and filters (for example audio spectrograms, where "width" could be seen as the time axis). Could the layers API be changed to be as flexible as `tf.nn.batch_normalization`?
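For reference, a minimal sketch of the spectrogram case with the low-level ops (assuming `[batch, time, freq]` inputs and illustrative sizes):

```python
import tensorflow as tf

# [batch, time, freq] spectrograms -- illustrative sizes.
spec = tf.random_normal([32, 1000, 128])

# Reduce only the batch axis, so each (time, freq) bin -- i.e. both
# width and filters -- gets its own statistics, broadcast back over
# the batch.
mean, variance = tf.nn.moments(spec, axes=[0], keep_dims=True)
normalized = tf.nn.batch_normalization(spec, mean, variance,
                                       offset=None, scale=None,
                                       variance_epsilon=1e-3)
```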