apple / tensorflow_macos

TensorFlow for macOS 11.0+ accelerated using Apple's ML Compute framework.

NaN Training Loss After Introducing BatchNormalization() #231

Open konempty opened 3 years ago

konempty commented 3 years ago

I am not good at English, so I ask for your understanding in advance. The model trained normally before I added BatchNormalization() to it. Also, when I run it on the CPU after adding it, the loss does not become NaN. The code works fine on other devices; strangely, it only fails on the M1 Mac. Below is my code.

```python
from tensorflow.keras.layers import (Input, Conv2D, MaxPooling2D,
                                     BatchNormalization, Dropout, Flatten, Dense)
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint

def vanilla_model():
    input = Input(shape=(imag_size, imag_size, 3))
    conv1 = Conv2D(16, (3, 3), padding='same', activation='relu')(input)
    pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)
    batch1 = BatchNormalization()(pool1)

    conv2 = Conv2D(32, (3, 3), padding='same', activation='relu')(batch1)
    pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)
    batch2 = BatchNormalization()(pool2)

    conv3 = Conv2D(64, (3, 3), padding='same', activation='relu')(batch2)
    pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)
    batch3 = BatchNormalization()(pool3)

    conv4 = Conv2D(128, (3, 3), padding='same', activation='relu')(batch3)
    pool4 = MaxPooling2D(pool_size=(2, 2))(conv4)
    batch4 = BatchNormalization()(pool4)

    flatten = Flatten()(batch4)
    dense1 = Dense(256, activation='relu')(flatten)
    drop1 = Dropout(0.3)(dense1)
    dense2 = Dense(64, activation='relu')(drop1)
    drop2 = Dropout(0.4)(dense2)
    output = Dense(5, activation='softmax')(drop2)

    model = Model(inputs=input, outputs=output)
    return model

model = vanilla_model()
opt = Adam(lr=0.001, decay=1e-6)

model.compile(loss='categorical_crossentropy', optimizer=opt, metrics=['accuracy'])

early_stopping = EarlyStopping(monitor='val_loss', patience=10)
checkpoint_callback = ModelCheckpoint('multiclass_weights.h5', monitor='val_loss',
                                      verbose=1, save_best_only=True, mode='min')

history = model.fit(
    train_generator,
    steps_per_epoch=total // batch_size,
    validation_data=valid_generator,
    epochs=100,
    callbacks=[early_stopping, checkpoint_callback],
)
```
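For anyone debugging this: a minimal NumPy sketch (not the ML Compute kernel itself, just an illustration of the failure mode) of where batch normalization can produce NaN. The layer computes `(x - mean) / sqrt(var + eps)`; if a backend effectively drops or mishandles the epsilon on a near-zero-variance batch, the division produces NaN, which then propagates into the loss:

```python
import numpy as np

# A batch where every activation in a channel is identical -> variance is 0.
x = np.zeros((4, 3), dtype=np.float32)
mean = x.mean(axis=0)
var = x.var(axis=0)
eps = 1e-3  # Keras BatchNormalization default epsilon is 1e-3

# Correct normalization: epsilon keeps the denominator away from zero.
safe = (x - mean) / np.sqrt(var + eps)

# Broken normalization (epsilon lost): 0 / 0 -> NaN.
with np.errstate(invalid='ignore'):
    unsafe = (x - mean) / np.sqrt(var)

assert np.isfinite(safe).all()   # stays finite
assert np.isnan(unsafe).all()    # NaN everywhere
```

A single NaN from such a division is enough to poison every subsequent gradient update, which matches the "loss suddenly becomes nan" symptom.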
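Since the reporter notes the loss stays finite when running on the CPU, one possible workaround sketch on this fork (assuming the `mlcompute` device-selection hook described in the tensorflow_macos README; call it before building the model) is to pin ML Compute to the CPU:

```python
# Workaround sketch: force ML Compute to use the CPU instead of the M1 GPU.
# `set_mlc_device` is the device-selection hook shipped with the
# tensorflow_macos fork; it must be called before any model is constructed.
from tensorflow.python.compiler.mlcompute import mlcompute

mlcompute.set_mlc_device(device_name='cpu')  # other documented values: 'gpu', 'any'
```

This trades training speed for numerical stability until the GPU kernel issue is fixed.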