keras-team / keras

Keras predict/predict_on_batch giving different answers than predict_step/__call__() #19403

Open cmdlinebeep opened 7 months ago

cmdlinebeep commented 7 months ago

From my understanding, all four of these methods (predict, predict_on_batch, predict_step, and a direct forward pass through the model, e.g. model(x, training=False) or __call__()) should give the same results; some are just more efficient than others in how they handle batches of data versus a single sample.

But I am actually getting different results on an image super-resolution (upscaling) task I'm working on:

# `model` and `val` (the validation dataset) are defined earlier in my script
for lowres, _ in val.take(1):
    # Get a randomly cropped region of the lowres image for upscaling
    lowres = tf.image.random_crop(lowres, (150, 150, 3))  # uint8

    # Need to add a dummy batch dimension for the predict step    
    model_inputs = tf.expand_dims(lowres, axis=0)  # (1, 150, 150, 3), uint8

    # And convert the uint8 image values to float32 for input to the model
    model_inputs = tf.cast(model_inputs, tf.float32)  # float32

    preds = model.predict_on_batch(model_inputs)
    min_val = tf.reduce_min(preds).numpy()
    max_val = tf.reduce_max(preds).numpy()
    print("Min value: ", min_val)
    print("Max value: ", max_val)

    preds = model.predict(model_inputs)
    min_val = tf.reduce_min(preds).numpy()
    max_val = tf.reduce_max(preds).numpy()
    print("Min value: ", min_val)
    print("Max value: ", max_val)

    preds = model.predict_step(model_inputs)
    min_val = tf.reduce_min(preds).numpy()
    max_val = tf.reduce_max(preds).numpy()
    print("Min value: ", min_val)
    print("Max value: ", max_val)

    preds = model(model_inputs, training=False)  # __call__()
    min_val = tf.reduce_min(preds).numpy()
    max_val = tf.reduce_max(preds).numpy()
    print("Min value: ", min_val)
    print("Max value: ", max_val)

Prints:

Min value:  -6003.622
Max value:  5802.6826

Min value:  -6003.622
Max value:  5802.6826

Min value:  -53.7696
Max value:  315.1499

Min value:  -53.7696
Max value:  315.1499

Both predict_step and a direct forward pass (__call__()) give the "correct" answers, in the sense that the resulting upscaled images look right.

I'm happy to share more details on the model if that's helpful, but for now I thought I'd just leave it at this to not overcomplicate the question. At first I wondered if these methods had different results based on training/inference modes, but my model doesn't use any BatchNorm or Dropout layers, so that shouldn't make a difference here. It's completely composed of: Conv2D, Add, tf.nn.depth_to_space (pixel shuffle), and Rescaling layers. That's it. It also doesn't use any subclassing or override any methods, just uses keras.Model(inputs, outputs).

FWIW, I'm using Keras 2, not 3, if that matters. I couldn't get my code to work with Keras 3.

Any ideas why these prediction methods would give different answers?

This very well could just be a misunderstanding on my part, but from my reading of the documentation, all these methods should give the same answer, so I'm inclined to think this is a bug.

UPDATE 1: I've been able to create a minimal reproducible example where you can see the issue. Please see: https://www.kaggle.com/code/quackaddict7/really-minimum-reproducible-example

I initially couldn't reproduce the problem in a minimal example. I added back the dataset, batching, data augmentation, training, and model saving/restoring, and eventually discovered the issue is simply GPU vs. CPU! So I stripped all of that back out of my minimal example. If you run the Kaggle notebook you'll see that on CPU, all four methods give the same inference answer with randomly initialized weights. But if you switch to the P100 GPU, predict/predict_on_batch differ from predict_step and the forward pass (__call__).

So I guess at this point, my question is, why are CPU vs. GPU results different here? Is this a bug?
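For reference, a minimal way to force every prediction path onto the CPU for this comparison (just the standard TensorFlow device API, nothing specific to the notebook) is:

import tensorflow as tf

# Hide all GPUs from TensorFlow before building the model, so predict,
# predict_on_batch, predict_step, and __call__ all run on the CPU.
# Comment this out to compare against the GPU results.
tf.config.set_visible_devices([], "GPU")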

UPDATE 2: Please see my cross-post to Stack Overflow; this does seem to be a bug, particularly related to the ReLU activation: https://stackoverflow.com/questions/78242094/keras-predict-predict-on-batch-giving-different-answers-than-predict-step-call

innat commented 7 months ago

cc @sachinprasadhs

Below is a reproducible example in Keras 3. Also, removing the relu activation makes the Keras 2 code work; please check this.

import tensorflow as tf
import numpy as np
import keras
from keras import layers

def ResBlock(inputs):
    x = layers.Conv2D(64, 3, padding="same", activation="relu")(inputs)
    x = layers.Conv2D(64, 3, padding="same")(x)
    x = layers.Add()([inputs, x])
    return x

def Upsampling(inputs, factor=2, **kwargs):
    # Two conv + pixel shuffle (tf.nn.depth_to_space) stages, each upscaling by `factor`
    x = layers.Conv2D(64 * (factor ** 2), 3, padding="same", **kwargs)(inputs)
    x = layers.Lambda(lambda t: tf.nn.depth_to_space(t, factor))(x)
    x = layers.Conv2D(64 * (factor ** 2), 3, padding="same", **kwargs)(x)
    x = layers.Lambda(lambda t: tf.nn.depth_to_space(t, factor))(x)
    return x

def make_model(num_filters, num_of_residual_blocks):
    input_layer = layers.Input(shape=(None, None, 3))
    x = layers.Rescaling(scale=1.0 / 255)(input_layer)
    x = x_new = layers.Conv2D(num_filters, 3, padding="same")(x)

    for i in range(num_of_residual_blocks):
        x_new = ResBlock(x_new)

    x_new = layers.Conv2D(num_filters, 3, padding="same")(x_new)
    x = layers.Add()([x, x_new])
    x = Upsampling(x)
    x = layers.Conv2D(3, 3, padding="same")(x)

    output_layer = layers.Rescaling(scale=255)(x)
    return keras.Model(input_layer, output_layer)

model = make_model(num_filters=64, num_of_residual_blocks=16)
lowres = tf.random.uniform(
    shape=(150, 150, 3), 
    minval=0, 
    maxval=256, 
    dtype='float32'
)
model_inputs = tf.expand_dims(lowres, axis=0)
predict_out = model.predict(model_inputs)
predict_on_batch_out = model.predict_on_batch(model_inputs)
predict_call_out = model(model_inputs, training=False).numpy()
predict_step_out = model.predict_step(model_inputs).numpy()
print(
    predict_out.shape, 
    predict_on_batch_out.shape, 
    predict_call_out.shape, 
    predict_step_out.shape
)
# not OK
np.testing.assert_allclose(
    predict_out,
    predict_on_batch_out,
    1e-5, 1e-5
)

# not OK
np.testing.assert_allclose(
    predict_on_batch_out,
    predict_call_out,
    1e-5, 1e-5
)

# OK
np.testing.assert_allclose(
    predict_call_out,
    predict_step_out,
    1e-5, 1e-5
)

fchollet commented 6 months ago

When you are calling __call__ or predict_step, you are using eager execution by default. When you are calling predict or predict_on_batch you are using a compiled function.

So it sounds like the problem is at the intersection of compiled function usage + GPU usage + relu?
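One way to test that hypothesis (a sketch assuming the TensorFlow backend; not verified against this model) is to force predict() to run eagerly and see whether the difference disappears:

# If the compiled function is the culprit, predict() should match __call__()
# once it runs eagerly (jit_compile=False is a less drastic variant that only
# disables XLA compilation).
model.compile(run_eagerly=True)
eager_preds = model.predict(model_inputs)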

fchollet commented 6 months ago

My recommendation here would be to try another backend, e.g. torch or JAX. It is likely to be a TF-specific issue.
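With Keras 3, switching the backend only requires setting KERAS_BACKEND before the first keras import, for example:

import os
os.environ["KERAS_BACKEND"] = "jax"  # or "torch" / "tensorflow"

import keras  # the backend is fixed once keras has been imported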

github-actions[bot] commented 6 months ago

This issue is stale because it has been open for 14 days with no activity. It will be closed if no further activity occurs. Thank you.

cmdlinebeep commented 6 months ago

I apologize, but it's not really clear what's expected of me.

We have the reproducible example in Keras 2 (see above), and innat has confirmed this is still an issue in Keras 3 (notebook here).

But to change to a different backend, we'd have to swap out all the tf.nn.depth_to_space() calls for PyTorch equivalents. In my opinion that would be changing too many variables at once (both the backend AND the specific functions/layers used) and could just muddy the waters.

It could also be a PyTorch issue, but haven't we done enough to show definitively that it's a CPU/GPU issue?
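One possible middle ground, sketched here as an untested assumption rather than something from the notebook: the pixel shuffle can be written with backend-agnostic keras.ops calls, so that only that one Lambda changes while the rest of the model stays identical across backends. Assuming NHWC layout and statically known height/width/channels:

from keras import ops

def depth_to_space(x, block_size):
    # Pixel shuffle equivalent to tf.nn.depth_to_space (NHWC, default "DCR" order).
    h, w, c = x.shape[1], x.shape[2], x.shape[3]
    r = block_size
    c_out = c // (r * r)
    x = ops.reshape(x, (-1, h, w, r, r, c_out))    # split channels into (r, r, c_out)
    x = ops.transpose(x, (0, 1, 3, 2, 4, 5))       # interleave blocks with rows/cols
    return ops.reshape(x, (-1, h * r, w * r, c_out))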

shoaib42 commented 2 months ago

I can add that I'm seeing this too, using the TensorFlow backend on CPU with all activations set to tanh.

P = model.predict(X)
P_on_batch = model.predict_on_batch(X)
P_call_out = model(X, training=False).numpy()
P_step_out = model.predict_step(X).numpy()

# (~(A == B)).sum() counts element-wise mismatches; zero means identical.
# These two match:
(~(P_call_out == P_on_batch)).sum()  # 0, matches
(~(P_step_out == P_on_batch)).sum()  # 0, matches

# predict() mismatches: not zero, about 3.11% of elements differ
(~(P == P_on_batch)).sum()
# and this fails:
np.testing.assert_allclose(
    P,
    P_on_batch,
    1e-10, 1e-10
)
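One more variable worth ruling out here (a suggestion, not something verified in this thread): predict() splits the input into batches of 32 by default and concatenates the results, while predict_on_batch runs the whole array in a single step. A quick check:

# If this matches P_on_batch exactly, the ~3% mismatch above comes from
# predict()'s internal batch splitting rather than from the forward pass itself.
P_full_batch = model.predict(X, batch_size=len(X))
print((~(P_full_batch == P_on_batch)).sum())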