keras-team / keras

Deep Learning for humans
http://keras.io/
Apache License 2.0

model.fit() doesn't update loss values properly with a custom loss function in Keras 3 #20496

Open youmemonk opened 1 day ago

youmemonk commented 1 day ago

Hi,

I recently upgraded from Keras 2 to Keras 3 and noticed some strange behavior when using a custom loss function with model.fit(). The model trains without throwing any errors, but the loss values during training stay almost constant across epochs, even when the model’s predictions are changing.

This problem doesn’t happen if I use a built-in loss function like mse, and the same custom loss function worked perfectly fine in Keras 2.

Steps to Reproduce

Here is a minimal example that shows the issue:

import tensorflow as tf  
from tensorflow import keras  
import numpy as np  

# Custom loss function  
def custom_loss(y_true, y_pred):  
    diff = y_true - y_pred  
    return tf.reduce_mean(diff ** 2)  

# Dummy dataset  
X = np.random.rand(100, 10)  
y = np.random.rand(100, 1)  

# Simple model  
model = keras.Sequential([  
    keras.layers.Dense(32, activation='relu', input_shape=(10,)),  
    keras.layers.Dense(1)  
])  

# Compile with custom loss  
model.compile(optimizer='adam', loss=custom_loss)  

# Fit the model  
history = model.fit(X, y, epochs=10, verbose=1)  

When I run this code, the loss value printed during training doesn’t change much:

Epoch 1/10  
loss: 0.2567  
Epoch 2/10  
loss: 0.2567  
...  

However, if I check the model predictions before and after training, I can see that they’re changing.

y_pred_before = model.predict(X)  
model.fit(X, y, epochs=1, verbose=0)  
y_pred_after = model.predict(X)  
assert not np.allclose(y_pred_before, y_pred_after)  
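
One way to cross-check the printed loss is to recompute it by hand from the predictions. The sketch below is a NumPy stand-in for custom_loss; the helper name manual_custom_loss and the synthetic "before" and "after" arrays are assumptions for illustration, not part of the original report. In a real check, the arrays would be y and the outputs of model.predict(X) taken before and after training.

```python
import numpy as np


def manual_custom_loss(y_true, y_pred):
    """NumPy equivalent of custom_loss: mean of the squared differences."""
    y_true = np.asarray(y_true, dtype=np.float64)
    y_pred = np.asarray(y_pred, dtype=np.float64)
    # Refuse mismatched shapes so that broadcasting cannot hide a problem.
    if y_true.shape != y_pred.shape:
        raise ValueError(f"shape mismatch: {y_true.shape} vs {y_pred.shape}")
    return float(np.mean((y_true - y_pred) ** 2))


# Synthetic stand-ins (assumed data, not real model output).
rng = np.random.default_rng(0)
y = rng.random((100, 1))
y_pred_before = np.zeros((100, 1))
# Pretend training moved every prediction halfway toward its target.
y_pred_after = 0.5 * y_pred_before + 0.5 * y

loss_before = manual_custom_loss(y, y_pred_before)
loss_after = manual_custom_loss(y, y_pred_after)
print(f"before: {loss_before:.4f}  after: {loss_after:.4f}")
```

If the values recomputed this way move between epochs while the loss reported by model.fit() stays flat, that would point at how the loss is tracked or displayed rather than at the optimization itself.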

Expected Behavior

The loss values should decrease during training as the model learns.

Environment

Using built-in loss functions (e.g., keras.losses.MeanSquaredError()) works fine.

Could this be related to changes in graph execution or how custom loss functions are handled in Keras 3?
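
A rough way to test the graph-execution hypothesis is to train the same model twice, once with the default compiled execution and once with run_eagerly=True, and compare the loss histories. This is only a sketch: the build_model helper, the fixed seed, and the reduced epoch count are assumptions added for the comparison, and it uses keras.Input instead of input_shape.

```python
import numpy as np
import tensorflow as tf
from tensorflow import keras


def custom_loss(y_true, y_pred):
    # Same custom loss as in the report: mean squared difference.
    diff = y_true - y_pred
    return tf.reduce_mean(diff ** 2)


def build_model():
    # Hypothetical helper so both runs start from a fresh model.
    return keras.Sequential([
        keras.Input(shape=(10,)),
        keras.layers.Dense(32, activation="relu"),
        keras.layers.Dense(1),
    ])


X = np.random.rand(100, 10).astype("float32")
y = np.random.rand(100, 1).astype("float32")

losses = {}
for eager in (False, True):
    keras.utils.set_random_seed(0)  # same initial weights and shuffling seed
    model = build_model()
    model.compile(optimizer="adam", loss=custom_loss, run_eagerly=eager)
    history = model.fit(X, y, epochs=3, verbose=0)
    losses[eager] = history.history["loss"]

print("graph:", losses[False])
print("eager:", losses[True])
```

If both histories decrease in a similar way, graph execution is unlikely to be the cause.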

Let me know if you need any more details to debug this. Thanks!

fchollet commented 1 day ago

I'm not able to reproduce this. Here's what you get with the latest Keras version:

Epoch 1/10
4/4 ━━━━━━━━━━━━━━━━━━━━ 2s 16ms/step - loss: 0.4584
Epoch 2/10
4/4 ━━━━━━━━━━━━━━━━━━━━ 0s 4ms/step - loss: 0.3704 
Epoch 3/10
4/4 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.2906  
Epoch 4/10
4/4 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.2306
Epoch 5/10
4/4 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - loss: 0.1722
Epoch 6/10
4/4 ━━━━━━━━━━━━━━━━━━━━ 0s 4ms/step - loss: 0.1350  
Epoch 7/10
4/4 ━━━━━━━━━━━━━━━━━━━━ 0s 4ms/step - loss: 0.1156 
Epoch 8/10
4/4 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - loss: 0.0997  
Epoch 9/10
4/4 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.0856  
Epoch 10/10
4/4 ━━━━━━━━━━━━━━━━━━━━ 0s 4ms/step - loss: 0.0880

Link: https://colab.research.google.com/drive/1vSA_iBg-6AW4hFQdXHJf3HuxwIZ6GUTA#scrollTo=DxzDg4lbkyUc
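
Since the run above used the latest Keras version, comparing the locally installed versions against it may help narrow things down. The snippet below is a small sketch using only the standard library; the helper name installed_versions and the list of packages are assumptions chosen for illustration.

```python
from importlib import metadata


def installed_versions(packages=("keras", "tensorflow", "numpy")):
    """Return a dict mapping each package name to its version, or None if missing."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


versions = installed_versions()
for name, version in versions.items():
    print(f"{name}: {version or 'not installed'}")
```

Pasting this output into the issue makes it easier to tell whether the difference comes from the Keras release in use.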