tensorflow / hub

A library for transfer learning by reusing parts of TensorFlow models.
https://tensorflow.org/hub
Apache License 2.0

Confusing numbers of trainable variables with hub.KerasLayer. #795

Closed sebastian-sz closed 3 years ago

sebastian-sz commented 3 years ago

When loading a model from TF Hub as a hub.KerasLayer and passing trainable=False, one would expect all of the layer's variables in the model to be non-trainable. Upon closer inspection, however, this seems not to be the case.

Consider loading a simple model from TFHub:

import tensorflow as tf
import tensorflow_hub as hub

# Let's load EfficientNetV2 B3
URL = "https://tfhub.dev/google/imagenet/efficientnet_v2_imagenet21k_ft1k_b3/feature_vector/2"
layer = hub.KerasLayer(URL, trainable=False)
model = tf.keras.Sequential([
    tf.keras.Input((300, 300, 3)),
    layer,
])

The model summary reports 0 trainable parameters:

model.summary()

But if we iterate over the variables, we see that some are marked as trainable:

sum([1 for v in model.variables if v.trainable])
# 360
for v in model.variables:
    if v.trainable:
        print(v.name)
# efficientnetv2-b3/blocks_0/conv2d/kernel:0
# efficientnetv2-b3/blocks_0/tpu_batch_normalization/gamma:0
# ...

What is the true number of trainable variables in the above scenario?

MorganR commented 3 years ago

Thank you for the question @sebastian-sz. Although some of the variables may still be marked trainable on their own, you'll find that the model does not train them because they are part of a non-trainable layer.

As shown in the model summary, the model does not consider any weights to be trainable. This can also be confirmed by iterating over the model.trainable_weights attribute, which is empty.
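The same distinction can be seen without downloading anything, since this is how Keras layers behave in general. A minimal sketch with a plain Dense layer (not the Hub layer itself): freezing the layer empties trainable_weights, even though the underlying tf.Variable objects still report trainable=True.

```python
import tensorflow as tf

# The layer-level `trainable` flag decides what Keras trains;
# the underlying tf.Variable objects keep their own `trainable=True`.
dense = tf.keras.layers.Dense(4)
dense.build((None, 8))  # create the kernel and bias variables
dense.trainable = False

print(len(dense.trainable_weights))                    # 0
print(len(dense.non_trainable_weights))                # 2 (kernel and bias)
print(sum(1 for v in dense.variables if v.trainable))  # 2
```

This is why model.summary() and model.trainable_weights are the authoritative view: they filter on the layer-level flag, not on the per-variable attribute.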

For completeness, you can also try training the model and confirming that the weights do not change. Building on your sample code:

pre_train_values = {var.name: var.read_value() for var in model.variables if var.trainable}

sample_data = tf.stack([tf.zeros((300, 300, 3), dtype=tf.float32), tf.ones((300, 300, 3), dtype=tf.float32)])
sample_labels = tf.stack([tf.zeros((1536,)), tf.ones((1536,))])

model.compile(loss=tf.keras.losses.MeanSquaredError())
model.fit(sample_data, sample_labels, batch_size=2, epochs=10)

changed_values = {
    var.name: var.read_value()
    for var in model.variables
    if var.trainable
    and not tf.math.reduce_all(tf.math.equal(var.read_value(), pre_train_values[var.name]))
}
if len(changed_values) == 0:
  print('No variables changed')

You should see that the loss remains the same throughout training, and that this prints 'No variables changed' at the end.
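The same check can also be run locally without the Hub download. A minimal sketch (layer names here are illustrative, not from the thread) with a frozen Dense layer plus a trainable head: fit() updates the head but leaves the frozen layer's weights untouched, even though its variables report trainable=True.

```python
import numpy as np
import tensorflow as tf

# Freeze one layer, train briefly, and verify its weights never change.
frozen = tf.keras.layers.Dense(4)
head = tf.keras.layers.Dense(2)
model = tf.keras.Sequential([tf.keras.Input((3,)), frozen, head])
frozen.trainable = False  # set before compile(), which snapshots trainability

before = [v.numpy().copy() for v in frozen.variables]

model.compile(optimizer="sgd", loss="mse")
model.fit(np.ones((8, 3)), np.zeros((8, 2)), epochs=3, verbose=0)

after = [v.numpy() for v in frozen.variables]
if all(np.array_equal(b, a) for b, a in zip(before, after)):
    print("No variables changed")
```

Here the optimizer only ever sees model.trainable_variables, which excludes the frozen layer regardless of the per-variable flags.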

Does that resolve your question?

sebastian-sz commented 3 years ago

Yes it does! Thank you for the detailed answer!