b-remy / gems

GEnerative Morphology for Shear
MIT License

Using `deep_galaxy_models` in our framework #12

Open b-remy opened 2 years ago

b-remy commented 2 years ago

I was able to load the models saved in https://github.com/McWilliamsCenter/deep_galaxy_models. The forward pass works: I can run the repo's notebooks with TF 2.6.0, which I am currently using.

However, taking gradients through these models with tf.GradientTape() does not seem to work. I get the following error:

AttributeError: 'NoneType' object has no attribute 'outer_context'

I looked for similar issues, and it seems that models saved in the TF1 Hub format cannot be fine-tuned in TF2, which is presumably why the gradient computation fails: https://www.tensorflow.org/hub/model_compatibility#compatibility_of_hubmodule

EiffL commented 2 years ago

My proposal is to export the model as a SavedModel in TF1, so that it can be reloaded as such in TF2. Here is a demo of how to do that:

%tensorflow_version 1.x

import tensorflow_hub as hub
import tensorflow as tf

morph_model = hub.load('/content/model')

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())

  # Placeholders for the four inputs of the module's default signature
  random_normal = tf.placeholder(tf.float32, shape=[1, 16])
  mag_auto = tf.placeholder(tf.float32, shape=[1,])
  zphot = tf.placeholder(tf.float32, shape=[1,])
  flux_radius = tf.placeholder(tf.float32, shape=[1,])

  # Build the model graph on top of the placeholders
  output = morph_model.signatures['default'](mag_auto=mag_auto,
                                             zphot=zphot,
                                             flux_radius=flux_radius,
                                             random_normal=random_normal)

  # Save with SavedModelBuilder
  builder = tf.saved_model.Builder('saved-model-builder')
  # Map the named inputs and the single output to a prediction signature
  sig_def = tf.saved_model.predict_signature_def(
      inputs={'random_normal': random_normal,
              'mag_auto': mag_auto,
              'zphot': zphot,
              'flux_radius': flux_radius},
      outputs={'output': output['default']})
  # Register the signature under the default serving key ('serving_default')
  builder.add_meta_graph_and_variables(
      sess, tags=["serve"], signature_def_map={
          tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY: sig_def
  })
  builder.save()

Then, in TF2:

%tensorflow_version 2.x

import tensorflow as tf

loaded = tf.saved_model.load('saved-model-builder/')
model = loaded.signatures['serving_default']

z = tf.random.normal(shape=[1,16])

with tf.GradientTape() as tape:
  tape.watch(z)
  res = model(mag_auto=24.*tf.ones([1]),
              zphot=0.5*tf.ones([1]),
              flux_radius=10.*tf.ones([1]),
              random_normal=z)

# Gradient of the model output with respect to the input z
grad = tape.gradient(res['output'], z)

Full demo here: https://colab.research.google.com/drive/1uZroAiNufkNOGiSrAhVM7zr-1Mj8D1jN?usp=sharing
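
For anyone who wants to check the reload-and-differentiate pattern without the exported model at hand, here is a minimal sketch on a stand-in. `ToyGenerator`, its single weight matrix and the formula inside `generate` are made up for illustration and have nothing to do with the real morphology model; only the input names, shapes and the `'output'` key mirror the signature exported above. It assumes a TF 2.x install.

```python
import tempfile

import tensorflow as tf


class ToyGenerator(tf.Module):
    """Hypothetical stand-in for the exported model (not the real network)."""

    def __init__(self):
        super().__init__()
        # Arbitrary weights mapping the 16-d input to a 4-value "image".
        self.w = tf.Variable(tf.random.normal([16, 4], seed=0), name="w")

    @tf.function(input_signature=[
        tf.TensorSpec([1, 16], tf.float32, name="random_normal"),
        tf.TensorSpec([1], tf.float32, name="mag_auto"),
        tf.TensorSpec([1], tf.float32, name="zphot"),
        tf.TensorSpec([1], tf.float32, name="flux_radius"),
    ])
    def generate(self, random_normal, mag_auto, zphot, flux_radius):
        # Made-up differentiable function of all four inputs.
        scale = flux_radius * (1.0 + zphot) / mag_auto
        return {"output": tf.matmul(random_normal, self.w) * scale}


# Export the stand-in as a TF2 SavedModel with a 'serving_default' signature.
export_dir = tempfile.mkdtemp()
toy = ToyGenerator()
tf.saved_model.save(toy, export_dir,
                    signatures={"serving_default": toy.generate})

# Reload it the same way as the exported model in the demo.
# Keep `loaded` referenced so its variables stay alive.
loaded = tf.saved_model.load(export_dir)
model = loaded.signatures["serving_default"]

z = tf.random.normal(shape=[1, 16])

with tf.GradientTape() as tape:
    tape.watch(z)  # z is a plain tensor, so it has to be watched explicitly
    res = model(mag_auto=24. * tf.ones([1]),
                zphot=0.5 * tf.ones([1]),
                flux_radius=10. * tf.ones([1]),
                random_normal=z)
    loss = tf.reduce_sum(res["output"])

# Gradient of the summed output with respect to the input z.
grad = tape.gradient(loss, z)
print(grad.shape)
```

If `grad` comes back as a tensor with the same shape as `z` rather than `None`, the signature is differentiable with respect to its inputs, which is the property needed from the re-exported model.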