keras-team / keras

BatchNormalization moving_mean goes to infinity then nan. #9646

Closed. bharris47 closed this issue 5 years ago.

bharris47 commented 6 years ago

Hi, I am running into an issue where the moving_mean parameters of the BatchNormalization layers that follow my Conv1D layers quickly explode to infinity. I have checked the output of the Conv1D layers, and the moving mean simply appears to be wrong. I am using the Adam optimizer, and clipnorm and clipvalue have no effect.

I am able to reproduce the issue on two separate machines, one with a Titan V and one with a Titan Xp.
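For reference, the optimizer setup looks roughly like this (the clipping values here are placeholders, not the exact ones I used). Note that moving_mean and moving_variance are updated directly from batch statistics via assign ops rather than through gradients, so gradient clipping would not be expected to affect them.

from keras.optimizers import Adam

# Gradient clipping as mentioned above; the specific values are assumptions.
# BatchNormalization's moving_mean/moving_variance are updated by assign ops
# from batch statistics, not by gradients, so clipping does not touch them.
optimizer = Adam(lr=1e-3, clipnorm=1.0, clipvalue=0.5)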

from keras.layers import (Input, Embedding, Dropout, Conv1D, BatchNormalization,
                          Activation, MaxPooling1D, Bidirectional, CuDNNLSTM,
                          Lambda, Dense, Subtract, Concatenate)
from keras.models import Model
from keras import backend as K


def cnn_rnn(max_length, vocab_size,
            char_embedding=16,
            filter_counts=(256, 256, 256, 256, 256, 256),
            sizes=(7, 7, 3, 3, 3, 3),
            pool_sizes=(3, 3, None, None, None, 3),
            rnn=CuDNNLSTM,
            rnn_sizes=(1024, 1024)):
    """
    Character level CNN based on https://arxiv.org/pdf/1502.01710.pdf
    """
    document = Input(shape=(max_length,))
    features = Embedding(vocab_size, char_embedding)(document)
    features = Dropout(0.2)(features)
    for i, (filters, size, pool_size) in enumerate(zip(filter_counts, sizes, pool_sizes), 1):
        features = Conv1D(filters, size, activation='linear', name='conv_' + str(i), padding='same')(features)
        features = BatchNormalization(scale=False)(features)
        features = Activation('relu')(features)
        if pool_size:
            features = MaxPooling1D(pool_size)(features)

    for i, rnn_size in enumerate(rnn_sizes, 1):
        return_sequences = i < len(rnn_sizes)  # return single vector for last rnn in stack
        rnn_layer = Bidirectional(rnn(rnn_size, return_sequences=return_sequences), name='rnn_' + str(i))
        features = rnn_layer(features)

    features = Lambda(lambda x: K.l2_normalize(x, axis=-1))(features)
    model = Model(document, features)
    return model

I have tried moving BatchNormalization after the activation, before the pooling, and after the pooling. Nothing helps.

The output of this model is fed into a triplet model for training.

def triplet_model(encoder, input_shape):
    x_anchor = Input(shape=input_shape, name='anchor')
    x_related = Input(shape=input_shape, name='related')
    x_unrelated = Input(shape=input_shape, name='unrelated')

    h_anchor = encoder(x_anchor)
    h_related = encoder(x_related)
    h_unrelated = encoder(x_unrelated)

    summation = Dense(1, activation='linear', kernel_initializer='ones', bias_initializer='zeros', name='summation')
    summation.trainable = False

    d_related = Subtract()([h_anchor, h_related])
    d_unrelated = Subtract()([h_anchor, h_unrelated])

    # (a-b)**2
    d_related = Lambda(lambda val: val ** 2)(d_related)
    d_unrelated = Lambda(lambda val: val ** 2)(d_unrelated)

    # sum((a-b)**2)
    d_related = summation(d_related)
    d_unrelated = summation(d_unrelated)

    # concatenate both distances and apply softmax so we get values from 0-1
    output = Concatenate()([d_related, d_unrelated])
    output = Activation('softmax')(output)

    inputs = [x_anchor, x_related, x_unrelated]
    model = Model(inputs=inputs, outputs=output)
    return model
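For context, the two models are wired together roughly like the sketch below. The concrete values, loss, and compile settings are assumptions; they are not shown above. The important point is that the encoder is a single shared model invoked three times per training step, once per input branch.

from keras.optimizers import Adam

# Assumed wiring of the two models above; max_length, vocab_size, the loss
# and the metrics are placeholders, not taken from the original post.
encoder = cnn_rnn(max_length=1014, vocab_size=70)
model = triplet_model(encoder, input_shape=(1014,))
model.compile(optimizer=Adam(clipnorm=1.0),
              loss='categorical_crossentropy',
              metrics=['accuracy'])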

I added a Callback which logs the BatchNormalization statistics after each batch; a sketch of such a callback is shown below, followed by some sample logs of the BatchNormalization weights.
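A minimal sketch of that callback (a reconstruction, not the exact code; the BatchNormMonitor name and its encoder argument are made up for illustration):

import numpy as np
from keras import backend as K
from keras.callbacks import Callback
from keras.layers import BatchNormalization


class BatchNormMonitor(Callback):
    """Print the mean of every BatchNormalization weight after each batch."""

    def __init__(self, encoder):
        super(BatchNormMonitor, self).__init__()
        self.encoder = encoder

    def on_batch_end(self, batch, logs=None):
        for layer in self.encoder.layers:
            if isinstance(layer, BatchNormalization):
                for weight, value in zip(layer.weights,
                                         K.batch_get_value(layer.weights)):
                    print(layer.name, weight.name, np.mean(value))

# Usage: model.fit(..., callbacks=[BatchNormMonitor(encoder)])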

   1/1000 [..............................] - ETA: 2:03:47 - loss: 0.6831 - acc: 0.5625
batch_normalization_1 batch_normalization_1/beta:0 -3.478276e-05
batch_normalization_1 batch_normalization_1/moving_mean:0 0.00068033265
batch_normalization_1 batch_normalization_1/moving_variance:0 4.599308e-05

batch_normalization_2 batch_normalization_2/beta:0 2.8935632e-05
batch_normalization_2 batch_normalization_2/moving_mean:0 -0.010991309
batch_normalization_2 batch_normalization_2/moving_variance:0 -1.921591

batch_normalization_3 batch_normalization_3/beta:0 -8.179595e-05
batch_normalization_3 batch_normalization_3/moving_mean:0 0.01017458
batch_normalization_3 batch_normalization_3/moving_variance:0 0.51089174

batch_normalization_4 batch_normalization_4/beta:0 -6.206893e-05
batch_normalization_4 batch_normalization_4/moving_mean:0 0.06638892
batch_normalization_4 batch_normalization_4/moving_variance:0 -0.32122394

batch_normalization_5 batch_normalization_5/beta:0 -4.149045e-05
batch_normalization_5 batch_normalization_5/moving_mean:0 -0.06003024
batch_normalization_5 batch_normalization_5/moving_variance:0 0.3666916

batch_normalization_6 batch_normalization_6/beta:0 -0.00025191344
batch_normalization_6 batch_normalization_6/moving_mean:0 -0.04935637
batch_normalization_6 batch_normalization_6/moving_variance:0 -0.30286956

...

 151/1000 [===>..........................] - ETA: 7:49 - loss: 0.5681 - acc: 0.7556
batch_normalization_1 batch_normalization_1/beta:0 -0.0027314727
batch_normalization_1 batch_normalization_1/moving_mean:0 -5.404506e+33
batch_normalization_1 batch_normalization_1/moving_variance:0 0.000113652655

batch_normalization_2 batch_normalization_2/beta:0 0.0034339842
batch_normalization_2 batch_normalization_2/moving_mean:0 4.6608827e+35
batch_normalization_2 batch_normalization_2/moving_variance:0 1.927948

batch_normalization_3 batch_normalization_3/beta:0 0.0060493248
batch_normalization_3 batch_normalization_3/moving_mean:0 -4.1352656e+35
batch_normalization_3 batch_normalization_3/moving_variance:0 5.40427

batch_normalization_4 batch_normalization_4/beta:0 0.0015374201
batch_normalization_4 batch_normalization_4/moving_mean:0 -1.1334452e+36
batch_normalization_4 batch_normalization_4/moving_variance:0 4.2213097

batch_normalization_5 batch_normalization_5/beta:0 0.00087401713
batch_normalization_5 batch_normalization_5/moving_mean:0 -7.723977e+35
batch_normalization_5 batch_normalization_5/moving_variance:0 3.892437

batch_normalization_6 batch_normalization_6/beta:0 -0.0050984416
batch_normalization_6 batch_normalization_6/moving_mean:0 -1.1864312e+36
batch_normalization_6 batch_normalization_6/moving_variance:0 5.0928397

 152/1000 [===>..........................] - ETA: 7:48 - loss: 0.5678 - acc: 0.7569
batch_normalization_1 batch_normalization_1/beta:0 -0.0028489209
batch_normalization_1 batch_normalization_1/moving_mean:0 1.0809012e+34
batch_normalization_1 batch_normalization_1/moving_variance:0 0.00011406097

batch_normalization_2 batch_normalization_2/beta:0 0.0032967457
batch_normalization_2 batch_normalization_2/moving_mean:0 -9.321765e+35
batch_normalization_2 batch_normalization_2/moving_variance:0 1.9499233

batch_normalization_3 batch_normalization_3/beta:0 0.005909444
batch_normalization_3 batch_normalization_3/moving_mean:0 8.270531e+35
batch_normalization_3 batch_normalization_3/moving_variance:0 5.4213495

batch_normalization_4 batch_normalization_4/beta:0 0.0014220099
batch_normalization_4 batch_normalization_4/moving_mean:0 inf
batch_normalization_4 batch_normalization_4/moving_variance:0 4.233643

batch_normalization_5 batch_normalization_5/beta:0 0.00084148924
batch_normalization_5 batch_normalization_5/moving_mean:0 inf
batch_normalization_5 batch_normalization_5/moving_variance:0 4.616567

batch_normalization_6 batch_normalization_6/beta:0 -0.005221432
batch_normalization_6 batch_normalization_6/moving_mean:0 nan
batch_normalization_6 batch_normalization_6/moving_variance:0 5.510005

 153/1000 [===>..........................] - ETA: 7:47 - loss: 0.5671 - acc: 0.7584
batch_normalization_1 batch_normalization_1/beta:0 -0.0029070186
batch_normalization_1 batch_normalization_1/moving_mean:0 -2.1618023e+34
batch_normalization_1 batch_normalization_1/moving_variance:0 0.00014326644

batch_normalization_2 batch_normalization_2/beta:0 0.0032181628
batch_normalization_2 batch_normalization_2/moving_mean:0 nan
batch_normalization_2 batch_normalization_2/moving_variance:0 1.5394695

batch_normalization_3 batch_normalization_3/beta:0 0.0058730226
batch_normalization_3 batch_normalization_3/moving_mean:0 nan
batch_normalization_3 batch_normalization_3/moving_variance:0 5.4471445

batch_normalization_4 batch_normalization_4/beta:0 0.00138916
batch_normalization_4 batch_normalization_4/moving_mean:0 nan
batch_normalization_4 batch_normalization_4/moving_variance:0 4.452125

batch_normalization_5 batch_normalization_5/beta:0 0.0008047862
batch_normalization_5 batch_normalization_5/moving_mean:0 nan
batch_normalization_5 batch_normalization_5/moving_variance:0 3.9046226

batch_normalization_6 batch_normalization_6/beta:0 -0.005167995
batch_normalization_6 batch_normalization_6/moving_mean:0 nan
batch_normalization_6 batch_normalization_6/moving_variance:0 5.542812

bharris47 commented 6 years ago

Could there be some issue with running three texts per training example through the BatchNormalization layers? If my batch size is 128, could Keras/TensorFlow be summing the BatchNormalization inputs over the (3 * 128) samples but only dividing by the 128 training examples?
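To illustrate the suspected arithmetic (this is only a sketch of the hypothesis, not confirmed Keras behaviour):

import numpy as np

# Activations from all three branches are pooled, but the mean is
# (hypothetically) normalized by the nominal batch size of 128 instead of
# the 3 * 128 samples actually seen.
batch_size = 128
x = np.random.randn(3 * batch_size, 256) + 5.0  # fake activations, true mean ~5

correct_mean = x.sum(axis=0) / (3 * batch_size)  # ~5
suspected_mean = x.sum(axis=0) / batch_size      # ~15, inflated threefold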

AZweifels commented 6 years ago

I'm facing a similar issue with multi-stream image classification when my branches contain trainable BN layers.
@bharris47 Did you find a solution to your issue?

bharris47 commented 6 years ago

I did not find a solution to this particular issue. I ended up using a different loss function that only required running one batch through the model per step, which worked around the problem.

boozyguo commented 6 years ago

@AZweifels, @bharris47, @fchollet: I have the same issue. When using multi-label outputs, the trainable BN layers end up with a NaN moving_mean. I have tried MobileNetV2, Xception, and Inception-ResNet-v2. Is this a Keras bug?

makseq commented 6 years ago

Try tf.layers.batch_normalization(x, training=self.is_training, renorm=True), i.e. batch normalization with renorm=True.
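A minimal sketch of that suggestion with the TF 1.x graph API (the placeholder shapes, dummy loss, and optimizer are made up for illustration). tf.layers.batch_normalization registers its moving-statistics updates in the UPDATE_OPS collection, so the train op has to depend on them:

import tensorflow as tf

is_training = tf.placeholder(tf.bool, name='is_training')
x = tf.placeholder(tf.float32, shape=(None, 256))

# Batch renormalization, as suggested above.
h = tf.layers.batch_normalization(x, training=is_training, renorm=True)
loss = tf.reduce_mean(tf.square(h))  # dummy loss for the sketch

# The moving statistics are updated via UPDATE_OPS, so make the train op
# depend on them explicitly.
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    train_op = tf.train.AdamOptimizer().minimize(loss)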

gabrieldemarmiesse commented 5 years ago

We'll close this issue and track the bug with #11927.