tensorflow / hub

A library for transfer learning by reusing parts of TensorFlow models.
https://tensorflow.org/hub
Apache License 2.0

Vastly different output between tensorflow hub and keras for EfficientNetB0 due to possibly missing normalization #821

Closed. Zahlii closed this issue 1 year ago.

Zahlii commented 2 years ago

I am attempting to migrate an existing EfficientNet model from keras.applications to the corresponding TF Hub model. However, I see vastly different behavior in training and final performance, so I am trying to find the root cause.

Even with the correct input scaling ([0-255] for keras.applications, [0-1] for TF Hub), and after checking that the weights of the loaded feature extractors are exactly the same, I get different results.

However, as shown below, keras.applications actually applies a Rescaling(1/255.0) followed by a Normalization() layer before passing the inputs to the rest of the layers, whereas the TF Hub model (as far as I understand) only receives the already rescaled inputs. So I am wondering whether there is a discrepancy caused by the different ways of normalizing.


import collections

import tensorflow as tf
import tensorflow_hub as tfh

img2 = tf.io.read_file("tests/test2.jpg")
img2 = tf.image.decode_jpeg(img2, channels=3)  # force RGB so the shape matches (224, 224, 3)
img2 = tf.image.convert_image_dtype(img2, tf.float32)
img2 = tf.image.resize(img2, (224, 224))
print(img2.numpy().min(), img2.numpy().max(), img2.numpy().mean())
# 0.0 1.0 0.14994633

img2 = tf.expand_dims(img2, axis=0)

model = tf.keras.applications.efficientnet.EfficientNetB0(include_top=False, input_shape=(224, 224, 3), pooling="avg")
model.trainable = False

# Note: each Keras Application expects a specific kind of input preprocessing.
# For EfficientNet, input preprocessing is included as part of the model (as a Rescaling layer),
# and thus tf.keras.applications.efficientnet.preprocess_input is actually a pass-through function.
# EfficientNet models expect their inputs to be float tensors of pixels with values in the [0-255] range.

print(model(img2 * 255.0, training=False))
# [[-0.1674832   0.06747124 -0.00998851 ... -0.09722453 -0.05333998
#   -0.09630407]], shape=(1, 1280), dtype=float32)

inp = tf.keras.layers.Input(shape=(224, 224, 3))
layer = tfh.KerasLayer("https://tfhub.dev/tensorflow/efficientnet/b0/feature-vector/1", trainable=False)(inp)

model_hub = tf.keras.models.Model(inputs=inp, outputs=layer)
print(model_hub(img2, training=False))
# tf.Tensor(
# [[-0.1534781  -0.00044355  0.02527641 ...  0.03095446 -0.07474983
#   -0.12997748]], shape=(1, 1280), dtype=float32)

weights = collections.defaultdict(lambda: collections.defaultdict(dict))

for w in model.non_trainable_weights:
    m = w.numpy().mean()
    weights[tuple(w.shape)][m][f"{w.name}_KERAS"] = w.numpy()

for w in model_hub.non_trainable_weights:
    m = w.numpy().mean()
    weights[tuple(w.shape)][m][f"{w.name}_HUB"] = w.numpy()

for s, data in weights.items():
    for mn, data2 in data.items():
        n = len(data2)
        has_hub = any("HUB" in name for name in data2.keys())
        has_keras = any("KERAS" in name for name in data2.keys())
        if n % 2 != 0 or not has_hub or not has_keras:
            print(s, mn, data2)

# (3,) 0.449 {'normalization/mean:0_KERAS': array([0.485, 0.456, 0.406], dtype=float32)}
# (3,) 0.226 {'normalization/variance:0_KERAS': array([0.229, 0.224, 0.225], dtype=float32)}
# () 0.0 {'normalization/count:0_KERAS': 0}
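Using the `normalization/mean` and `normalization/variance` values dumped above, the per-pixel arithmetic of the Keras preprocessing can be sketched in plain Python. This is a simplified model, not the Keras code itself: it assumes [0, 255] input as in the snippet, that the Normalization layer computes `(x - mean) / max(sqrt(variance), epsilon)` (as in the Keras source linked later in this thread), and that epsilon is the Keras default of `1e-7`. The helper name `keras_preprocess_pixel` is hypothetical.

```python
import math

# Values dumped from the Keras EfficientNetB0 Normalization layer above.
KERAS_MEAN = [0.485, 0.456, 0.406]
KERAS_VARIANCE = [0.229, 0.224, 0.225]
EPSILON = 1e-7  # assumed: Keras' default backend epsilon


def keras_preprocess_pixel(rgb_0_255):
    """Hypothetical helper: Rescaling(1/255) followed by Normalization for one RGB pixel."""
    rescaled = [c / 255.0 for c in rgb_0_255]
    return [
        (c - m) / max(math.sqrt(v), EPSILON)
        for c, m, v in zip(rescaled, KERAS_MEAN, KERAS_VARIANCE)
    ]


# The layer divides by sqrt(variance), so the effective per-channel std is sqrt(0.229) etc.
effective_std = [math.sqrt(v) for v in KERAS_VARIANCE]
print([round(s, 4) for s in effective_std])
print([round(v, 4) for v in keras_preprocess_pixel([255, 255, 255])])
```

Because the stored "variance" values are the usual ImageNet standard deviations, the effective divisor is roughly 0.47 instead of roughly 0.23 for every channel.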
UsharaniPagadala commented 2 years ago

@Zahlii Could you please refer to the similar issues #42506 and 60251715 and let us know if this helps. Thanks

Zahlii commented 2 years ago

@UsharaniPagadala as you can already see in the code, I had stumbled across both of these issues before. I am scaling the inputs accordingly (range 0-255 for Keras, 0-1 for hub), and I am not even dealing with the classification variant here, just the feature vectors (which I assume return values without activation in both cases).

alenarepina commented 2 years ago

Please follow the tutorial https://www.tensorflow.org/hub/tutorials/tf2_image_retraining and let us know if you still see quality issues.

Zahlii commented 2 years ago

Copying the tutorial code leads to exactly the same (erroneous) outputs as in my example; see code snippet below.

The reason WHY I opened this ticket is that I saw a drop in performance (from 98% to roughly 93%) when switching from Keras to TF Hub (while following the posted tutorial exactly) and keeping all other hyperparameters the same (learning rate, dropout, train/test splits, etc.). To me, this indicates some kind of underlying issue with the outputs of the (not fine-tuned) EfficientNet layer on TF Hub, hence my attempt to debug whether this is caused by different weights or by some other problem.

import tensorflow as tf
import tensorflow_hub as hub

print("TF version:", tf.__version__)
print("Hub version:", hub.__version__)
print("GPU is", "available" if tf.config.list_physical_devices('GPU') else "NOT AVAILABLE")

model_name = "efficientnet_b0"

model_handle_map = {
  "efficientnet_b0": "https://tfhub.dev/tensorflow/efficientnet/b0/feature-vector/1",
}

model_image_size_map = {
  "efficientnet_b0": 224,
}

model_handle = model_handle_map.get(model_name)
pixels = model_image_size_map.get(model_name, 224)

print(f"Selected model: {model_name} : {model_handle}")

IMAGE_SIZE = (pixels, pixels)
print(f"Input size {IMAGE_SIZE}")

BATCH_SIZE = 16

do_fine_tuning = False
print("Building model with", model_handle)
model = tf.keras.Sequential([
    # Explicitly define the input shape so the model can be properly
    # loaded by the TFLiteConverter
    tf.keras.layers.InputLayer(input_shape=IMAGE_SIZE + (3,)),
    hub.KerasLayer(model_handle, trainable=do_fine_tuning),
])
model.build((None,)+IMAGE_SIZE+(3,))
model.summary()

img2 = tf.io.read_file("tests/test2.jpg")
img2 = tf.image.decode_jpeg(img2, channels=3)  # force RGB so the shape matches (224, 224, 3)
img2 = tf.image.convert_image_dtype(img2, tf.float32)
img2 = tf.image.resize(img2, (224, 224))
img2 = tf.expand_dims(img2, axis=0)

print(img2.numpy().min(), img2.numpy().max(), img2.numpy().mean())
print(model.predict(img2))
# [[-0.1534781  -0.00044355  0.02527641 ...  0.03095446 -0.07474983
#  -0.12997748]]
maringeo commented 2 years ago

Hi @Zahlii, the preprocessing for the TF Hub model is done as a tf.data transform, and it is hidden by default: if you go to the "Setup the flowers dataset" section of the Colab and click "Toggle code", you will see the transforms.

Most of the transforms are only applied during training in order to generate more training data by augmenting the input images. But tf.keras.layers.Rescaling(1. / 255) is applied during inference too.

I'm not sure what dataset and metric you are referring to in "I saw a drop in performance (98% to roughly 93%)", but adding tf.keras.layers.Rescaling(1. / 255) to the TF Hub model will likely improve the score (tf.keras.applications.efficientnet.EfficientNetB0 also uses this layer).

I can't repro the exact behavior since I am not certain what dataset and metric are used but could you try adding tf.keras.layers.Rescaling(1. / 255) and let us know if it helps?

Zahlii commented 2 years ago

Hi @maringeo, as I already stated above, I am aware that the rescaling needs to be done before feeding the image to TF Hub, as per the image conventions. However, rescaling is only required when the input is in integer format, i.e. 0...255. With the tf.image.convert_image_dtype approach this rescaling has already been applied, so adding a Rescaling() layer in front of the TF Hub layer does not help; it actually hurts, because the data then ends up completely off scale.

The code mentioned above is directly based on https://www.tensorflow.org/hub/common_signatures/images#input .

I already checked this by confirming the input data is correctly converted into the [0...1] range just prior to calling model.predict():

print(img2.numpy().min(), img2.numpy().max(), img2.numpy().mean())
# 0.0 1.0 0.14994633

To confirm, the following modification of the code above produces the values below, which are WAY further away from the keras.applications values than those of the unmodified hub model.

model = tf.keras.Sequential([
    # Explicitly define the input shape so the model can be properly
    # loaded by the TFLiteConverter
    tf.keras.layers.InputLayer(input_shape=IMAGE_SIZE + (3,)),
    tf.keras.layers.experimental.preprocessing.Rescaling(1/255.0),
    hub.KerasLayer(model_handle, trainable=do_fine_tuning),
])
...
model.predict(...)
# [[-0.06889984 -0.09886284 -0.13236746 ... -0.19460136 -0.04948389
#   0.04994771]]
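Why the extra Rescaling layer makes the output worse can be checked with plain arithmetic. This sketch assumes that `tf.image.convert_image_dtype` converts `uint8` to `float32` by dividing by 255 (which is what produces the [0, 1] range printed above) and that `Rescaling(scale)` simply multiplies its input by `scale`; both helper functions are hypothetical stand-ins for the TF ops, not the TF implementations.

```python
def convert_uint8_to_float(pixels):
    """Hypothetical stand-in for tf.image.convert_image_dtype(uint8 -> float32): divide by 255."""
    return [p / 255.0 for p in pixels]


def rescaling(pixels, scale):
    """Hypothetical stand-in for tf.keras.layers.Rescaling(scale): multiply every value by scale."""
    return [p * scale for p in pixels]


raw = [0, 128, 255]  # example uint8 pixel values
once = convert_uint8_to_float(raw)   # already in [0, 1], as TF Hub image models expect
twice = rescaling(once, 1 / 255.0)   # an extra Rescaling(1/255) squeezes it into [0, 1/255]
print(min(once), max(once))
print(min(twice), max(twice))
```

After the second division the whole image is close to black from the model's point of view, which is consistent with the outputs drifting further from the Keras values.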

It seems to me that the keras.applications model might apply some kind of channel-wise normalization that the hub version is perhaps not doing...

Zahlii commented 2 years ago

I checked the code that is supposed to be used during training and found the following traces: https://github.com/tensorflow/models/blob/30e6e03f66efad4e43f1b98ec8680451f5a86a72/official/vision/image_classification/preprocessing.py#L168

Note that they essentially call


MEAN_RGB = (0.485 * 255, 0.456 * 255, 0.406 * 255)
STDDEV_RGB = (0.229 * 255, 0.224 * 255, 0.225 * 255)

# the image is supposed to be in the [0-1] range already, since tf.image.convert_image_dtype() is called before this
# yet they STILL treat it as if it were a [0-255] image?
image = (image - MEAN_RGB) / STDDEV_RGB

(See https://github.com/tensorflow/models/blob/30e6e03f66efad4e43f1b98ec8680451f5a86a72/official/vision/image_classification/efficientnet/efficientnet_model.py#L361 )

Compare to the normal preprocessing layer inside keras (which is used by applications): https://github.com/keras-team/keras/blob/2c48a3b38b6b6139be2da501982fd2f61d7d48fe/keras/layers/preprocessing/normalization.py#L258

MEAN = [0.485, 0.456, 0.406]
STDDEV = [0.229, 0.224, 0.225]  # stored as the layer's "variance" weight (see the dump above)
# note the similarity to the one from model garden
image = (image - MEAN) / max(sqrt(STDDEV), epsilon)
# here, they divide by the square root of these values, NOT by the values themselves!

Now, it seems there are two main differences:

a) Normalization in Keras divides by the sqrt() of the variance (with some epsilon), whereas Model Garden divides by the values directly.
b) Normalization in Keras is done on the [0-1] scale, while Model Garden divides by a much too large value.

Now, I do not know which is the "right" thing to do (most likely whichever matches what was used to obtain the weights in the first place), but there seems to be a HUGE difference between these two approaches.
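Writing both normalizations for the same raw pixel value makes the size of the gap concrete. This worked sketch rests on one stated assumption: that the Model Garden formula receives input in [0, 255] (the point questioned in b); if it does not, the relation below does not hold). Here $x_c$ is the raw [0, 255] pixel of channel $c$, $\mu = (0.485, 0.456, 0.406)$, $\sigma = (0.229, 0.224, 0.225)$, and $\epsilon$ is negligible next to $\sqrt{\sigma_c}$.

```latex
\begin{aligned}
y^{\mathrm{MG}}_c &= \frac{x_c - 255\,\mu_c}{255\,\sigma_c} = \frac{x_c/255 - \mu_c}{\sigma_c},\\
y^{\mathrm{Keras}}_c &= \frac{x_c/255 - \mu_c}{\max(\sqrt{\sigma_c}, \epsilon)} = \frac{x_c/255 - \mu_c}{\sqrt{\sigma_c}},\\
\Rightarrow\quad y^{\mathrm{Keras}}_c &= \frac{\sigma_c}{\sqrt{\sigma_c}}\, y^{\mathrm{MG}}_c = \sqrt{\sigma_c}\, y^{\mathrm{MG}}_c,
\qquad \sqrt{\sigma_c} \in [0.473,\ 0.479].
\end{aligned}
```

Under that assumption, the Keras backbone sees every channel shrunk toward zero by a factor of roughly 0.47 compared with the normalization used for the original checkpoint.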

maringeo commented 2 years ago

Thank you @Zahlii, I overlooked the call to convert_image_dtype. I reached out to the Model Garden authors to comment on your analysis in https://github.com/tensorflow/hub/issues/821#issuecomment-969122282.

saberkun commented 2 years ago

Hi, Mingxing and I wrote the original efficientnet: https://github.com/tensorflow/tpu/blob/b24729de804fdb751b06467d3dce0637fa652060/models/official/efficientnet/efficientnet_builder.py#L31 and the model garden version should reproduce the efficientnet training. The efficientnet is trained with

MEAN_RGB = [0.485 * 255, 0.456 * 255, 0.406 * 255]
STDDEV_RGB = [0.229 * 255, 0.224 * 255, 0.225 * 255]

https://github.com/tensorflow/tpu/blob/b24729de804fdb751b06467d3dce0637fa652060/models/official/efficientnet/main.py#L369

For the TF-Hub export, here is the code that was used to export the TF-Hub model: https://github.com/tensorflow/models/blob/master/official/vision/image_classification/efficientnet/tfhub_export.py#L38

The inputs are expected to be in [0., 1.]. We then simply rescale them to [0., 255.] with x = image_input * 255.0.
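A minimal per-pixel sketch of the preprocessing described here: the [0, 1] input is multiplied by 255 and then normalized with the 255-scaled `MEAN_RGB`/`STDDEV_RGB` constants quoted above. It assumes no other input transformation happens before the first convolution, and the helper name `hub_preprocess_pixel` is hypothetical.

```python
# 255-scaled ImageNet statistics used by the original EfficientNet training code (quoted above).
MEAN_RGB = [0.485 * 255, 0.456 * 255, 0.406 * 255]
STDDEV_RGB = [0.229 * 255, 0.224 * 255, 0.225 * 255]


def hub_preprocess_pixel(rgb_0_1):
    """Hypothetical helper: [0, 1] input -> [0, 255] -> per-channel (x - mean) / stddev."""
    scaled = [c * 255.0 for c in rgb_0_1]
    return [(c - m) / s for c, m, s in zip(scaled, MEAN_RGB, STDDEV_RGB)]


# A pixel equal to the ImageNet mean maps to ~0; one standard deviation above it maps to ~1.
mean_pixel = [0.485, 0.456, 0.406]
one_std_above = [0.485 + 0.229, 0.456 + 0.224, 0.406 + 0.225]
print([round(v, 6) for v in hub_preprocess_pixel(mean_pixel)])
print([round(v, 6) for v in hub_preprocess_pixel(one_std_above)])
```

Because the constants and the input are both on the [0, 255] scale, this is a standard zero-mean, unit-std normalization, i.e. the scales are consistent once the `* 255.0` step is taken into account.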

We don't know whether the Keras implementation is the same as the original EfficientNet.

Zahlii commented 2 years ago

Hi @saberkun,

Thanks for clarifying this. To summarize my understanding: TF Hub/Model Garden expects the input to be in [0...1], which is converted to [0...255] as a first step and then normalized using the (x - MEAN_RGB) / STDDEV_RGB approach mentioned above. Keras, on the other hand, expects the input to be in [0...255], which is then rescaled to [0...1] (Rescaling layer) and normalized using (x - MEAN_RGB/255.0) / sqrt(STDDEV_RGB/255.0) (Normalization layer). Since all weights come from the same checkpoint, it seems to me that the issue may lie with the Keras implementation, even though performance on my proprietary dataset was higher with the "incorrect" approach?

saberkun commented 2 years ago

@Zahlii I investigated further as follows:

import tensorflow as tf
import tensorflow_hub as tfh

# Part 1: TF2 SavedModel loaded through hub.KerasLayer (eager mode).
image = tf.ones((1, 224, 224, 3), dtype=tf.float32) / 2.0
inp = tf.keras.layers.Input(shape=(224, 224, 3))
layer = tfh.KerasLayer("https://tfhub.dev/tensorflow/efficientnet/b0/feature-vector/1", trainable=False)(inp)

model_hub = tf.keras.models.Model(inputs=inp, outputs=layer)
print(model_hub(image, training=False))

# Part 2: TF1 hub.Module (graph mode). disable_eager_execution() is documented to be called
# before any graphs, ops or tensors are created, so this part is safer to run in a separate process.
tf.compat.v1.disable_eager_execution()
module = tfh.Module("https://tfhub.dev/google/efficientnet/b0/feature-vector/1")
image = tf.ones((1, 224, 224, 3), dtype=tf.float32) / 2.0
outputs = module(image)  # Feature vectors with shape [batch_size, 1280].
init = tf.compat.v1.global_variables_initializer()

with tf.compat.v1.Session() as sess:
    sess.run(init)
    print(sess.run(outputs))

So far we can confirm that the TF2 SavedModel produces the same output as the TF1 hub module, since I exported both from the same released EfficientNet checkpoints. And yes, according to your findings, the Keras implementation differs from our original EfficientNet code.

As for why your performance drops from 98% to 93%: did you enable training for the backbone? I am a bit worried about the batch norm behavior. @maringeo, are we able to test whether the hub module's batch norm works correctly? We implemented global/group batch norm in EfficientNet (the Keras implementation likely does not have it), but I feel it does not matter for the B0 model.

Zahlii commented 2 years ago

I will try to find some time to provide a side by side example/code for training showcasing the performance changes; however this may take a while due to work priorities.

Zahlii commented 2 years ago

I managed to (somewhat!) reproduce the performance gap on the tf_flowers dataset below (note the differences in loss and val_loss between the two models). Interestingly, the hub model trains a little faster? A small note: if I switch from SGD to Adam, the differences become much bigger and the Adam model converges to a worse result (most likely because the internal Adam state is freshly initialized and does not fit the pretrained weights?).


from functools import partial
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds

(train_ds, val_ds), metadata = tfds.load(
    'tf_flowers',
    split=['train[:80%]', 'train[80%:]'],
    with_info=True,
    as_supervised=True,
)
n_train = train_ds.cardinality().numpy()
n_val = val_ds.cardinality().numpy()

# image is uint8 [0..255]
# label is 5 classes, numerically encoded
n_classes = metadata.features['label'].num_classes

def resize(img, label, multiply=False):
    # images are now in [0...1] range if multiply=False, else in [0..255]
    img = tf.image.resize(tf.image.convert_image_dtype(img, tf.float32), [224, 224])
    if multiply:
        img = img * 255.0
    return img, tf.one_hot(label, depth=n_classes)

batch_size = 32
# Keras expects integer step counts
n_steps_train = int(np.ceil(n_train / batch_size))
n_steps_val = int(np.ceil(n_val / batch_size))

def prep_dataset(ds, multiply=False, batch_size=8):
    ds = ds.map(partial(resize, multiply=multiply), num_parallel_calls=tf.data.AUTOTUNE)
    ds = ds.cache()
    ds = ds.shuffle(buffer_size=1000)
    ds = ds.batch(batch_size)
    ds = ds.prefetch(buffer_size=tf.data.AUTOTUNE)
    return ds.repeat()

# START APPLICATIONS
train_keras = prep_dataset(train_ds, multiply=True, batch_size=batch_size)
val_keras = prep_dataset(val_ds, multiply=True, batch_size=batch_size)

model = tf.keras.Sequential([
    tf.keras.applications.efficientnet.EfficientNetB0(
        include_top=False, input_shape=(224, 224, 3), pooling="avg"
    ),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(n_classes, activation=None)
])
model.summary()

def train(m, ds_train, ds_val):
    m.compile(
        optimizer=tf.keras.optimizers.SGD(learning_rate=0.005, momentum=0.9),
        loss=tf.losses.CategoricalCrossentropy(from_logits=True, label_smoothing=0.1),
        metrics=["accuracy"]
    )

    m.fit(
        ds_train, 
        steps_per_epoch=n_steps_train, 
        validation_data=ds_val, 
        validation_steps=n_steps_val, 
        epochs=50, 
        callbacks=[
            tf.keras.callbacks.ReduceLROnPlateau(patience=7, verbose=1),
            tf.keras.callbacks.EarlyStopping(patience=23, verbose=1),
        ]
    )

train(model, train_keras, val_keras)
# with SGD
# Epoch 49/50
# 92/92 [==============================] - 31s 333ms/step - loss: 0.4148 - accuracy: 1.0000 - val_loss: 0.4950 - val_accuracy: 0.9619
# Epoch 50/50
# 92/92 [==============================] - 31s 333ms/step - loss: 0.4157 - accuracy: 0.9997 - val_loss: 0.4948 - val_accuracy: 0.9605

# with Adam
# Epoch 49/50
# 92/92 [==============================] - 31s 337ms/step - loss: 0.3986 - accuracy: 0.9997 - val_loss: 0.6604 - val_accuracy: 0.8747
# Epoch 50/50
# 92/92 [==============================] - 31s 337ms/step - loss: 0.3987 - accuracy: 0.9990 - val_loss: 0.6603 - val_accuracy: 0.8733

# SGD, NO label smoothing
# Epoch 30/50
# 92/92 [==============================] - 31s 338ms/step - loss: 0.0103 - accuracy: 0.9986 - val_loss: 0.1377 - val_accuracy: 0.9619
# Epoch 31/50
# 92/92 [==============================] - 31s 338ms/step - loss: 0.0084 - accuracy: 0.9980 - val_loss: 0.1378 - val_accuracy: 0.9619
# Epoch 00031: early stopping

# Adam, low learning rate
# Epoch 47/50
# 92/92 [==============================] - 30s 330ms/step - loss: 0.3937 - accuracy: 1.0000 - val_loss: 0.4719 - val_accuracy: 0.9646
# Epoch 48/50
# 92/92 [==============================] - 30s 330ms/step - loss: 0.3944 - accuracy: 0.9990 - val_loss: 0.4719 - val_accuracy: 0.9619
# Epoch 00048: early stopping

# START TFHUB
import tensorflow_hub as tfh

train_hub = prep_dataset(train_ds, multiply=False, batch_size=batch_size)
val_hub = prep_dataset(val_ds, multiply=False, batch_size=batch_size)

model_hub = tf.keras.Sequential([
    tf.keras.layers.InputLayer((224, 224, 3)),
    tfh.KerasLayer(
        "https://tfhub.dev/tensorflow/efficientnet/b0/feature-vector/1", trainable=True
    ),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(metadata.features['label'].num_classes, activation=None)
])
model_hub.summary()

train(model_hub, train_hub, val_hub)

# with SGD
# Epoch 49/50
# 92/92 [==============================] - 29s 316ms/step - loss: 0.5472 - accuracy: 0.9993 - val_loss: 0.6250 - val_accuracy: 0.9619
# Epoch 50/50
# 92/92 [==============================] - 29s 316ms/step - loss: 0.5476 - accuracy: 0.9986 - val_loss: 0.6246 - val_accuracy: 0.9659

# with Adam
# Epoch 49/50
# 92/92 [==============================] - 29s 318ms/step - loss: 0.4883 - accuracy: 0.9976 - val_loss: 0.7955 - val_accuracy: 0.8569
# Epoch 50/50
# 92/92 [==============================] - 29s 319ms/step - loss: 0.4863 - accuracy: 0.9990 - val_loss: 0.7759 - val_accuracy: 0.8624

# with SGD, NO label smoothing
# Epoch 39/50
# 92/92 [==============================] - 29s 317ms/step - loss: 0.1466 - accuracy: 0.9976 - val_loss: 0.2873 - val_accuracy: 0.9605
# Epoch 40/50
# 92/92 [==============================] - 29s 318ms/step - loss: 0.1398 - accuracy: 0.9986 - val_loss: 0.2870 - val_accuracy: 0.9605
# Epoch 00040: early stopping

# Adam, low lr
# Epoch 49/50
# 92/92 [==============================] - 29s 318ms/step - loss: 0.5214 - accuracy: 0.9997 - val_loss: 0.5890 - val_accuracy: 0.9673
# Epoch 50/50
# 92/92 [==============================] - 29s 316ms/step - loss: 0.5213 - accuracy: 0.9990 - val_loss: 0.5892 - val_accuracy: 0.9659