tensorflow / addons

Useful extra functionality for TensorFlow 2.x maintained by SIG-addons
Apache License 2.0

Unexpected sensitivity of group normalisation results to batch size #2745

Open wsmigaj opened 2 years ago

wsmigaj commented 2 years ago

System information

Describe the bug

Results produced by group normalisation differ markedly depending on whether the batch size is 1 or greater than 1.

Code to reproduce the issue

The following code applies group normalisation along the last axis, first to a batch of 4 slices, then to individual slices, and finally to pairs of consecutive slices. In principle, the results obtained in each case should be the same, since group normalisation is not done along the batch dimension.

import tensorflow as tf
import tensorflow_addons as tfa

# Dimensions
b = 4
h = 600
w = 400
g = 3
c = g * 8

gn = tfa.layers.GroupNormalization(groups=g, axis=-1)
gn.build([None, None, None, c])

input = tf.random.stateless_uniform([b, h, w, c], seed=[1, 2])

# Apply group normalisation to the whole batch (4 slices) at once
output_b4 = gn.call(input)

# Apply group normalisation to each slice individually
output_0 = gn.call(input[0:1])
output_1 = gn.call(input[1:2])
output_2 = gn.call(input[2:3])
output_3 = gn.call(input[3:4])
output_b1 = tf.concat([output_0, output_1, output_2, output_3], axis=0)
tf.print("Batch size 1 vs batch size 4:", tf.reduce_max(tf.abs(output_b1 - output_b4)))

# Apply group normalisation to pairs of slices
output_01 = gn.call(input[0:2])
output_23 = gn.call(input[2:4])
output_b2 = tf.concat([output_01, output_23], axis=0)
tf.print("Batch size 2 vs batch size 4:", tf.reduce_max(tf.abs(output_b2 - output_b4)))

Output:

Batch size 1 vs batch size 4: 0.000871777534
Batch size 2 vs batch size 4: 2.38418579e-07

So the difference between the results obtained for batch sizes 2 and 4 is on the order of machine precision, whereas the difference between the results obtained for batch sizes 1 and 4 is three orders of magnitude larger.

Other info / logs

The difference is introduced by the call to tf.nn.moments() in GroupNormalization._apply_normalization(). The Reduce operations executed by moments() reshape the input tensor differently depending on whether its first dimension is 1 (in which case it does not matter whether that dimension is reduced over) or larger (in which case it must not be reduced over). This determines whether all elements with the same batch and group index are contiguous in memory, and probably affects the order in which those elements are summed by Eigen::Tensor::reduce(). The difference in the final result is then a consequence of the non-associativity of floating-point addition.
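To illustrate the underlying mechanism (this is a standalone sketch, not the TFA code): summing the same float32 values in a different order can change the result, which is the effect the differing Reduce layouts expose.

import numpy as np

# Illustration only: floating-point addition is not associative, so accumulating
# the same values in a different order can give slightly different results.
rng = np.random.default_rng(0)
values = rng.uniform(size=600 * 400).astype(np.float32)

acc_forward = np.float32(0.0)
for v in values:
    acc_forward += v

acc_reverse = np.float32(0.0)
for v in values[::-1]:
    acc_reverse += v

print(acc_forward - acc_reverse)  # typically a small non-zero value

The reduction inside tf.nn.moments() accumulates in whatever order the chosen layout dictates, so two layouts of the same data can yield means and variances that differ slightly, and those differences propagate through the normalisation.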

The problem can be worked around by transposing the input tensor to a channels-first format, in which all axes not participating in the reduction (batch and group index) are located at the start of the axes list. However, it would be more user-friendly for this transpose to be done automatically inside GroupNormalization. If this sounds reasonable, I'm happy to open a PR patching GroupNormalization in this way.
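For reference, a rough sketch of the user-side workaround described above, reusing the dimensions from the reproduction script. The channels-first layer configuration (axis=1) and the explicit transposes are my own illustration, not the proposed patch:

import tensorflow as tf
import tensorflow_addons as tfa

g = 3
c = g * 8

# Workaround sketch: normalise in a channels-first layout so that the axes not
# reduced over (batch and channels/groups) come first in the axes list.
gn_cf = tfa.layers.GroupNormalization(groups=g, axis=1)
gn_cf.build([None, c, None, None])

x = tf.random.stateless_uniform([4, 600, 400, c], seed=[1, 2])

x_cf = tf.transpose(x, [0, 3, 1, 2])   # NHWC -> NCHW
y_cf = gn_cf(x_cf)                     # moments reduced over H and W per group
y = tf.transpose(y_cf, [0, 2, 3, 1])   # back to NHWC

The proposal in the issue is to perform an equivalent transpose automatically inside GroupNormalization, so that users get consistent results without changing their data layout themselves.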

bhack commented 2 years ago

Can you try to add a new test to cover this case?

wsmigaj commented 2 years ago

> Can you try to add a new test to cover this case?

Yes -- I've added a new test in #2746.