Closed: sebastian-sz closed this issue 3 years ago
Thank you for the question @sebastian-sz. Although some of the variables may still be marked trainable on their own, you'll find that the model does not train them because they are part of a non-trainable layer.
As shown in the model summary, the model does not consider any weights to be trainable. This can also be confirmed by iterating over the model.trainable_weights attribute, which is empty.
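For example (a minimal sketch, assuming `model` has already been built from the hub.KerasLayer as in your snippet):

```python
print(len(model.trainable_weights))  # 0
print(model.trainable_weights)       # []
```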
For completeness, you can also try training the model and confirming that the weights do not change. Building on your sample code:
```python
import tensorflow as tf

# `model` is the hub.KerasLayer-based model built above.
# Snapshot every trainable-flagged variable before training.
pre_train_values = {var.name: var.read_value() for var in model.variables if var.trainable}

# Two dummy 300x300 RGB images with matching 1536-dim targets.
sample_data = tf.stack([tf.zeros((300, 300, 3), dtype=tf.float32),
                        tf.ones((300, 300, 3), dtype=tf.float32)])
sample_labels = tf.stack([tf.zeros((1536,)), tf.ones((1536,))])

model.compile(loss=tf.keras.losses.MeanSquaredError())
model.fit(sample_data, sample_labels, batch_size=2, epochs=10)

# Collect any variable whose value differs from its pre-training snapshot.
changed_values = {var.name: var.read_value() for var in model.variables
                  if var.trainable and not tf.math.reduce_all(
                      tf.math.equal(var.read_value(), pre_train_values[var.name]))}
if len(changed_values) == 0:
    print('No variables changed')
```
You should see that the loss remains the same throughout training and that this prints 'No variables changed' at the end: model.fit only updates the weights in model.trainable_weights, and that list is empty here.
Does that resolve your question?
Yes it does! Thank you for the detailed answer!
When loading a model from TFHub as hub.KerasLayer and passing the parameter trainable=False, one would expect all the layer-related variables in the model to be non-trainable. Upon further inspection, however, this seems not to be true. Consider loading a simple model from TFHub:
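The original snippet is not reproduced here; the following is a minimal sketch. The model URL is an assumption (EfficientNet-B3 feature vector, chosen because its 300x300 input and 1536-dimensional output match the shapes used in the training example above); any TF Hub feature-vector model behaves the same way:

```python
import tensorflow as tf
import tensorflow_hub as hub

# Hypothetical choice of model; trainable=False is the setting under discussion.
model = tf.keras.Sequential([
    tf.keras.layers.InputLayer(input_shape=(300, 300, 3)),
    hub.KerasLayer(
        "https://tfhub.dev/tensorflow/efficientnet/b3/feature-vector/1",
        trainable=False,
    ),
])
model.summary()
```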
Model summary will output 0 trainable parameters:
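Illustratively, the tail of the summary looks like this (the exact parameter counts depend on the model loaded; only the trainable count matters here):

```
Total params: 10,783,535
Trainable params: 0
Non-trainable params: 10,783,535
```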
But if we iterate over model.variables, we will see that some have their trainable attribute set to True:
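A sketch of the check, assuming the `model` from the snippet above:

```python
trainable_flagged = [var.name for var in model.variables if var.trainable]
print(len(trainable_flagged))  # non-zero, despite the summary reporting 0
```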
What is the true number of trainable variables in the above scenario?