keras-team / keras

Deep Learning for humans
Apache License 2.0

The .fit() method passes unsized batches to a custom loss function. #19732

Open Ybisalt opened 2 weeks ago

Ybisalt commented 2 weeks ago

The .fit() method passes batches of undefined size (a None batch dimension) to a custom loss function if the dataset size is not a multiple of the batch size. This happens twice; after that, everything works fine. But if you need the batch size to normalize or reshape data, an error occurs.

import numpy as np
from tensorflow import keras
import keras.backend as K

inp_data = np.full((1000, 8), [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8])
out_data = np.full((1000, 4), [0.2, 0.4, 0.6, 0.8])

inp = keras.layers.Input((8,))
x = keras.layers.Dense(32, activation='relu')(inp)
out = keras.layers.Dense(4, activation='sigmoid')(x)

model = keras.Model(inp, out, name='My_model')

def my_loss(x, y):
    if x.shape[0] is None: print("Batch size = None!", x.shape)
    return K.sum(K.square(x-y), axis=0)   # /x.shape[0] <---ERROR, because None!

#model.compile(optimizer='adam', loss='mean_squared_error')
#log = model.fit(inp_data, out_data, epochs=3, batch_size=50)   # OK! 1000/50=20
#log = model.fit(inp_data, out_data, epochs=3, batch_size=30)   # OK! 1000/30=33.33~ -> 34

model.compile(optimizer='adam', loss=my_loss)
#log = model.fit(inp_data, out_data, epochs=3, batch_size=50)   # OK!
log = model.fit(inp_data, out_data, epochs=3, batch_size=30)    # Batch size = None!

Epoch 1/3
Batch size = None! (None, 4)
Batch size = None! (None, 4)
34/34 [==============================] - 1s 2ms/step - loss: 0.0961
Epoch 2/3
34/34 [==============================] - 0s 2ms/step - loss: 0.0268
Epoch 3/3
34/34 [==============================] - 0s 2ms/step - loss: 0.0055
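The 34/34 step count in the log follows from splitting 1,000 samples into batches of 30: 33 full batches plus one final batch of 10. The minimal sketch below, in plain Python without Keras, lists the batch size at each step. `batch_sizes` is a hypothetical helper that assumes the smaller remainder batch is kept rather than dropped, which is consistent with the 34 steps reported above. It shows that with batch_size=30 the batch dimension is not one fixed number across all steps, so a loss that hard-codes 30 would be wrong on the last step.

```python
def batch_sizes(num_samples: int, batch_size: int) -> list[int]:
    """Size of each batch when num_samples are split in order.

    Hypothetical helper for illustration only; assumes the final,
    smaller remainder batch is kept instead of being dropped.
    """
    full, remainder = divmod(num_samples, batch_size)
    sizes = [batch_size] * full
    if remainder:
        sizes.append(remainder)  # last, partial batch
    return sizes


sizes_30 = batch_sizes(1000, 30)
print(len(sizes_30), sizes_30[-1])  # 34 10 -> 33 batches of 30, then one of 10

sizes_50 = batch_sizes(1000, 50)
print(len(sizes_50), sizes_50[-1])  # 20 50 -> every batch has 50 samples
```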
SuryanarayanaY commented 1 week ago

Hi @Ybisalt ,

I have tested the given code with Keras 3.3.3 and it executes fine. Note that I changed K.sum to keras.ops.sum, and likewise K.square to keras.ops.square. Please refer to the attached gist.

Ybisalt commented 1 week ago

> I have tested the given code with Keras 3.3.3v and it executes fine. Please note that I have changed the code K.sum to keras.ops.sum and same for K.square also.

No! The same problem occurs there. You missed the "Batch size = None! (None, 4)" line in the last output log. How many times the batch size shows up as None depends on the order in which the fit() calls are run.

Try a single run of the fit() method (comment out the other fit() lines): log = model.fit(inp_data, out_data, epochs=3, batch_size=30) # Batch size = None!

My gist

fchollet commented 1 week ago

x.shape is the "static shape" of x. Its entries are often not numbers; the batch dimension, for instance, can be None. If you want the actual numeric value, you must use keras.ops.shape(x), e.g.:

keras.ops.sum(keras.ops.square(x - y), axis=0) / keras.ops.shape(x)[0]