Closed: arisliang closed this issue 1 year ago.
Thanks for the suggestion. We think this could be a useful feature, so please feel free to open a PR to add an optional trainable
column to the summary (let's make it as narrow as possible since it's just a boolean flag).
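As an illustration of how narrow such a boolean column can be, here is a minimal pure-Python sketch. It is not Keras code: LayerRow and format_summary are hypothetical names, and the only thing borrowed from Keras is the idea that each layer has a name and a boolean trainable flag.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class LayerRow:
    """Hypothetical stand-in for a layer: a name plus a boolean trainable flag."""
    name: str
    trainable: bool


def format_summary(layers: List[LayerRow]) -> str:
    """Render a plain-text table whose trainable column holds a single Y/N character."""
    # Name column is as wide as the longest layer name (or the header).
    name_width = max([len("Layer")] + [len(layer.name) for layer in layers])
    lines = [f"{'Layer':<{name_width}}  Trainable"]
    for layer in layers:
        flag = "Y" if layer.trainable else "N"  # one character per row
        lines.append(f"{layer.name:<{name_width}}  {flag}")
    return "\n".join(lines)
```

For example, format_summary([LayerRow("conv2d", False), LayerRow("dense", True)]) yields a three-line table whose rows end in N and Y respectively.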
This can be closed, since a patch implementing it has already been merged.
Did this patch get merged? I cannot see the show_trainable argument in keras master: https://github.com/mfidabel/keras/blob/master/keras/utils/layer_utils.py
I tried using show_trainable in a simple example from a TF tutorial (with tf-nightly), and it throws the error shown below. Please check the gist (https://colab.research.google.com/gist/jvishnuvardhan/59688cb7063772ff9e8a3827ac946044/untitled1174.ipynb) for reference.
TypeError: plot_model() got an unexpected keyword argument 'show_trainable'
@arisliang,
The related PR has been merged, and I can see the show_trainable argument in keras master: https://github.com/mfidabel/keras/blob/master/keras/utils/layer_utils.py
Could you please take a look at the layer_utils.py file and let us know if you are facing a similar issue? Thank you!
I think he is trying to show which layers are trainable in the plot_model() visualization rather than in model.summary(). AFAIK, this hasn't been implemented yet.
Added show_trainable to plot_model. In the resulting plot, 'T' means the layer is trainable and 'NT' means the layer isn't trainable.
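To make the T/NT convention concrete, here is a small sketch that emits a Graphviz DOT description of a linear stack of layers, tagging each node as T or NT. This is a hypothetical illustration (to_dot is not a Keras function) and does not reproduce how plot_model actually builds or renders its graph; it only assumes each layer is given as a (name, trainable) pair.

```python
from typing import Iterable, Tuple


def to_dot(layers: Iterable[Tuple[str, bool]]) -> str:
    """Build a Graphviz DOT graph for a linear stack of layers.

    Each node uses a record shape with two fields: the layer name and
    "T" (trainable) or "NT" (not trainable).
    """
    layers = list(layers)
    lines = ["digraph model {", "  node [shape=record];"]
    for i, (name, trainable) in enumerate(layers):
        tag = "T" if trainable else "NT"
        lines.append(f'  n{i} [label="{name} | {tag}"];')
    # Connect consecutive layers, as in a Sequential model.
    for i in range(len(layers) - 1):
        lines.append(f"  n{i} -> n{i + 1};")
    lines.append("}")
    return "\n".join(lines)
```

The output of to_dot([("conv2d", False), ("dense", True)]) can be rendered with the Graphviz dot tool to get a two-node diagram with the flags shown beside the layer names.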
@arisliang & @mfidabel, the related PR #15459 has been merged, and the show_trainable argument is available in keras master: https://github.com/mfidabel/keras/blob/master/keras/utils/layer_utils.py
Another PR, #17145, was also merged, so both model.summary(show_trainable=True) and plot_model(model, show_trainable=True) are now implemented. I tried the sample code below on tf-nightly and got the expected output. Kindly find the gist of it here.
import tensorflow as tf

model = tf.keras.Sequential(
    [
        tf.keras.Input(shape=(230, 230, 3)),
        tf.keras.layers.Conv2D(3, 1),
        tf.keras.layers.MaxPooling2D(2),
        tf.keras.layers.Dense(10),
    ]
)

# Freeze the first layer. The Input is not part of model.layers,
# so index 0 refers to the Conv2D layer.
model.get_layer(index=0).trainable = False

# Draw the model with a trainable column (requires pydot and Graphviz).
tf.keras.utils.plot_model(
    model,
    to_file="model_trainable.png",
    show_trainable=True,
)
@tilakrayal Yes, it is already merged. I think you can close the issue.
This issue has been automatically marked as stale because it has no recent activity. It will be closed if no further activity occurs. Thank you.
Closing as stale. Please reopen if you'd like to work on this further.
System information.
TensorFlow version (you are using): 2.6
Are you willing to contribute it (Yes/No): No
Describe the feature and the current behavior/state.
For GAN-style models, whether each layer's parameters are correctly marked as trainable is critical to training the model correctly. Therefore it would be useful to include that information in the summary and in model plots, if possible.
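Short of a visual flag, the kind of mistake described here (for example, forgetting to freeze the discriminator inside a combined GAN model) can also be caught with a simple check. The sketch below is hypothetical: find_trainable_mismatches is not a Keras API, and the layer names are made up. It only assumes each layer object exposes .name and .trainable, which Keras layers do, so with a real model one could pass model.layers.

```python
from types import SimpleNamespace
from typing import Dict, Iterable, List


def find_trainable_mismatches(layers: Iterable, expected: Dict[str, bool]) -> List[str]:
    """Return names of layers whose .trainable flag differs from `expected`.

    Layers not listed in `expected` are ignored.
    """
    return [
        layer.name
        for layer in layers
        if layer.name in expected and bool(layer.trainable) != expected[layer.name]
    ]


# Hypothetical GAN setup: while training the generator through the combined
# model, the discriminator's layers should be frozen.
layers = [
    SimpleNamespace(name="gen_dense", trainable=True),
    SimpleNamespace(name="disc_dense", trainable=True),  # forgot to freeze
]
expected = {"gen_dense": True, "disc_dense": False}
print(find_trainable_mismatches(layers, expected))  # ['disc_dense']
```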
Will this change the current API? How? It should not change the model.summary function. A show_trainable flag may be added to the plot_model function.
Who will benefit from this feature? Users who develop GAN-style models, or anyone for whom a layer's trainable setting matters for model correctness.