b-remy opened this issue 2 years ago.
My proposal to solve this is to re-export the model as a plain SavedModel from TF1 (with `SavedModelBuilder`), so that it can be reloaded with `tf.saved_model.load` in TF2. Here is a demo of how to do that:
```python
%tensorflow_version 1.x
import tensorflow_hub as hub
import tensorflow as tf

# Load the original TF1 model (path as in the Colab)
morph_model = hub.load('/content/model')

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    # Placeholders matching the inputs of the model's 'default' signature
    random_normal = tf.placeholder(tf.float32, shape=[1, 16])
    mag_auto = tf.placeholder(tf.float32, shape=[1])
    zphot = tf.placeholder(tf.float32, shape=[1])
    flux_radius = tf.placeholder(tf.float32, shape=[1])
    output = morph_model.signatures['default'](mag_auto=mag_auto,
                                               zphot=zphot,
                                               flux_radius=flux_radius,
                                               random_normal=random_normal)

    # Save with SavedModelBuilder as a plain TF1 SavedModel
    builder = tf.saved_model.Builder('saved-model-builder')
    sig_def = tf.saved_model.predict_signature_def(
        inputs={'random_normal': random_normal,
                'mag_auto': mag_auto,
                'zphot': zphot,
                'flux_radius': flux_radius},
        outputs={'output': output['default']})
    builder.add_meta_graph_and_variables(
        sess, tags=["serve"], signature_def_map={
            tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY: sig_def
        })
    builder.save()
```
Then, in TF2:
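```python
%tensorflow_version 2.x
import tensorflow as tf

# Reload the re-exported SavedModel and grab its serving signature
loaded = tf.saved_model.load('saved-model-builder/')
model = loaded.signatures['serving_default']

z = tf.random.normal(shape=[1, 16])
with tf.GradientTape() as tape:
    tape.watch(z)  # z is a plain tensor, so it must be watched explicitly
    res = model(mag_auto=24. * tf.ones([1]),
                zphot=0.5 * tf.ones([1]),
                flux_radius=10. * tf.ones([1]),
                random_normal=z)

# Gradient of the output with respect to the latent input z
grad = tape.gradient(res['output'], z)
```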
Full demo here: https://colab.research.google.com/drive/1uZroAiNufkNOGiSrAhVM7zr-1Mj8D1jN?usp=sharing
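The same round trip can be reproduced without the galaxy model files. The sketch below assumes only a TF 2.x install and builds a tiny, made-up TF1-style graph (`random_normal` times a variable `w`, scaled by `mag_auto`) as a hypothetical stand-in for the hub model. It exports that graph with `tf.compat.v1.saved_model.Builder`, as in the demo above, reloads it eagerly with `tf.saved_model.load`, and takes the gradient with respect to the latent input. The names `toy_dir` and `w` and the toy arithmetic are illustrative assumptions, not part of the original models.

```python
import os
import tempfile

import tensorflow as tf

# Hypothetical export location; Builder requires that it does not exist yet.
toy_dir = os.path.join(tempfile.mkdtemp(), 'toy-builder')

# 1) Build and export a toy TF1-style graph (stand-in for the hub model).
graph = tf.Graph()
with graph.as_default():
    random_normal = tf.compat.v1.placeholder(tf.float32, shape=[1, 16])
    mag_auto = tf.compat.v1.placeholder(tf.float32, shape=[1])
    w = tf.Variable(tf.ones([16, 1]), name='w')  # made-up weights
    # output = (random_normal @ w) * mag_auto, shape [1, 1]
    output = tf.matmul(random_normal, w) * tf.expand_dims(mag_auto, axis=1)

    with tf.compat.v1.Session(graph=graph) as sess:
        sess.run(tf.compat.v1.global_variables_initializer())
        builder = tf.compat.v1.saved_model.Builder(toy_dir)
        sig_def = tf.compat.v1.saved_model.predict_signature_def(
            inputs={'random_normal': random_normal, 'mag_auto': mag_auto},
            outputs={'output': output})
        builder.add_meta_graph_and_variables(
            sess, tags=['serve'],
            signature_def_map={'serving_default': sig_def})
        builder.save()

# 2) Reload eagerly in TF2 and differentiate through the signature.
loaded = tf.saved_model.load(toy_dir)
model = loaded.signatures['serving_default']

z = tf.random.normal(shape=[1, 16])
with tf.GradientTape() as tape:
    tape.watch(z)
    res = model(random_normal=z, mag_auto=24. * tf.ones([1]))

# d(output)/dz = transpose(w) * mag_auto, i.e. 24.0 in every entry here
grad = tape.gradient(res['output'], z)
print(grad.numpy())
```

Because the toy output is linear in `z`, the expected gradient is known in closed form. That makes it a quick check that gradients flow through a TF1-built SavedModel reloaded in TF2.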
I was able to load the models saved in https://github.com/McWilliamsCenter/deep_galaxy_models. The forward pass works, and I can run the repo's notebooks with TF 2.6.0, which I am currently using.
However, taking gradients through these models with `tf.GradientTape()` does not seem to work: it fails with an error. Looking for similar issues, it seems that models in the TF1 Hub format (`hub.Module`) cannot be fine-tuned in TF2: https://www.tensorflow.org/hub/model_compatibility#compatibility_of_hubmodule
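Since what works in TF2 depends on which format a model directory is in, a quick check can help decide between loading it directly and taking the re-export route. The sketch below is a hypothetical helper (`classify_model_dir` is not a TensorFlow or TF Hub API). It assumes that a TF1 Hub format module contains a `tfhub_module.pb` file next to `saved_model.pb`, while a plain SavedModel, such as the `saved-model-builder` export, has only the latter. It inspects file names only, so treat the result as a heuristic.

```python
import os
import tempfile


def classify_model_dir(model_dir):
    """Hypothetical helper: guess the on-disk format of a model directory.

    Assumption: TF1 Hub format modules ship 'tfhub_module.pb' alongside the
    SavedModel proto, while a plain SavedModel only has the SavedModel proto.
    """
    has_saved_model = any(
        os.path.isfile(os.path.join(model_dir, name))
        for name in ('saved_model.pb', 'saved_model.pbtxt'))
    has_hub_proto = os.path.isfile(os.path.join(model_dir, 'tfhub_module.pb'))
    if has_saved_model and has_hub_proto:
        return 'tf1_hub_module'
    if has_saved_model:
        return 'saved_model'
    return 'unknown'


def _touch(path):
    # Create an empty file as a stand-in for the real protobuf.
    with open(path, 'wb'):
        pass


# Usage on fake directories that mimic each layout.
hub_dir = tempfile.mkdtemp()
_touch(os.path.join(hub_dir, 'saved_model.pb'))
_touch(os.path.join(hub_dir, 'tfhub_module.pb'))

plain_dir = tempfile.mkdtemp()
_touch(os.path.join(plain_dir, 'saved_model.pb'))

print(classify_model_dir(hub_dir))             # tf1_hub_module
print(classify_model_dir(plain_dir))           # saved_model
print(classify_model_dir(tempfile.mkdtemp()))  # unknown
```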