tensorflow / lucid

A collection of infrastructure and tools for research in neural network interpretability.
Apache License 2.0
4.65k stars 655 forks source link

Example of loading a tensorflow model from scratch #223

Open mencia opened 4 years ago

mencia commented 4 years ago

Back in 2018 there was a discussion on how to load your own tensorflow model https://github.com/tensorflow/lucid/issues/34. Later a new way of doing it was suggested https://github.com/tensorflow/lucid/pull/152.

It would be very helpful if there was a minimal example where: 1) a model is built, 2) trained, 3) saved and 4) visualized. I figured out the first three steps, but I am stuck on the fourth. Below I will show the first three steps:

1. Build a VAE model

import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
tfd = tf.contrib.distributions

class VAE:

    """VAE"""

    def __init__(self, data):

        self.data = data
        self.loss = self.build_loss()
        self.sample = self.sample()

    def make_encoder(self, data, code_size):
      x = tf.layers.flatten(data)
      x = tf.layers.dense(x, 200, tf.nn.relu)
      x = tf.layers.dense(x, 200, tf.nn.relu)
      loc = tf.layers.dense(x, code_size)
      scale = tf.layers.dense(x, code_size, tf.nn.softplus)
      return tfd.MultivariateNormalDiag(loc, scale)

    def make_prior(self, code_size):
      loc = tf.zeros(code_size)
      scale = tf.ones(code_size)
      return tfd.MultivariateNormalDiag(loc, scale)

    def make_decoder(self, code, data_shape):
      x = code
      x = tf.layers.dense(x, 200, tf.nn.relu)
      x = tf.layers.dense(x, 200, tf.nn.relu)
      logit = tf.layers.dense(x, np.prod(data_shape))
      logit = tf.reshape(logit, [-1] + data_shape)
      return tfd.Independent(tfd.Bernoulli(logit), 2)

    def build_loss(self):
        """We sample the posterior to input the decoder"""
        prior = self.make_prior(code_size=2)
        posterior = self.make_encoder(self.data, code_size=2)
        code = posterior.sample()
        likelihood = self.make_decoder(code, [28, 28]).log_prob(self.data)
        divergence = tfd.kl_divergence(posterior, prior)
        elbo = tf.reduce_mean(likelihood - divergence)
        return -elbo

    def sample(self):
        """Decodes a random code"""
        prior = self.make_prior(code_size=2)
        return self.make_decoder(prior.sample(10), [28, 28]).mean()

2. Train the model

mnist = input_data.read_data_sets('MNIST_data/')
data = tf.placeholder(tf.float32, [None, 28, 28])
model = VAE(data)
loss = model.loss
optimize = tf.train.AdamOptimizer(0.001).minimize(loss)
init = tf.global_variables_initializer()
saver = tf.train.Saver()
with tf.Session() as sess:
    sess.run(init)
    for epoch in range(2):
        for _ in range(60):
            feed_dict = {data: mnist.train.next_batch(100)[0].reshape([-1, 28, 28])}
            sess.run(optimize, feed_dict)
            saver.save(sess, './logging/model_final')

3. Load the trained model and save it for lucid visualization

from lucid.modelzoo.vision_models import Model

with tf.Graph().as_default() as graph, tf.Session() as sess:

    path = './logging/'
    ckpt_state = tf.train.get_checkpoint_state(path)
    data = tf.placeholder(tf.float32, [None, 28, 28], name='images')
    model = VAE(data)
    saver = tf.train.Saver()
    saver.restore(sess, ckpt_state.model_checkpoint_path)

    Model.save("saved_model.pb",  
     input_name='images', 
     output_names=[graph.as_graph_def().node[-1].name], 
     image_shape=[28,28],
     image_value_range=[0,1])

4. Visualize

I get an error when trying to visualize it.

from lucid.modelzoo.vision_models import Model
import lucid.optvis.render as render

model = Model.load("saved_model.pb")
_ = render.render_vis(model, "dense_9/kernel:0")

The raised error:

---------------------------------------------------------------------------
InvalidArgumentError                      Traceback (most recent call last)
~/env_py36/lib/python3.6/site-packages/tensorflow/python/framework/importer.py in import_graph_def(graph_def, input_map, return_elements, name, op_dict, producer_op_list)
    426         results = c_api.TF_GraphImportGraphDefWithResults(
--> 427             graph._c_graph, serialized, options)  # pylint: disable=protected-access
    428         results = c_api_util.ScopedTFImportGraphDefResults(results)

InvalidArgumentError: Input 0 of node import/save/Assign was passed float from import/dense/bias:0 incompatible with expected float_ref.

During handling of the above exception, another exception occurred:

ValueError                                Traceback (most recent call last)
<ipython-input-6-dc71ccdc8e67> in <module>
      1 import lucid.optvis.render as render
----> 2 _ = render.render_vis(model, "dense_9/kernel:0")

~/env_py36/lib/python3.6/site-packages/lucid/optvis/render.py in render_vis(model, objective_f, param_f, optimizer, transforms, thresholds, print_objectives, verbose, relu_gradient_override, use_fixed_seed)
     93 
     94     T = make_vis_T(model, objective_f, param_f, optimizer, transforms,
---> 95                    relu_gradient_override)
     96     print_objective_func = make_print_objective_func(print_objectives, T)
     97     loss, vis_op, t_image = T("loss"), T("vis_op"), T("input")

~/env_py36/lib/python3.6/site-packages/lucid/optvis/render.py in make_vis_T(model, objective_f, param_f, optimizer, transforms, relu_gradient_override)
    175     with gradient_override_map({'Relu': redirected_relu_grad,
    176                                 'Relu6': redirected_relu6_grad}):
--> 177       T = import_model(model, transform_f(t_image), t_image)
    178   else:
    179     T = import_model(model, transform_f(t_image), t_image)

~/env_py36/lib/python3.6/site-packages/lucid/optvis/render.py in import_model(model, t_image, t_image_raw, scope, input_map)
    255     t_image_raw = t_image
    256 
--> 257   model.import_graph(t_image, scope=scope, forget_xy_shape=True, input_map=input_map)
    258 
    259   def T(layer):

~/env_py36/lib/python3.6/site-packages/lucid/modelzoo/vision_base.py in import_graph(self, t_input, scope, forget_xy_shape, input_map)
    198       final_input_map.update(input_map)
    199     tf.import_graph_def(
--> 200         self.graph_def, final_input_map, name=scope)
    201     self.post_import(scope)
    202 

~/env_py36/lib/python3.6/site-packages/tensorflow/python/util/deprecation.py in new_func(*args, **kwargs)
    505                 'in a future version' if date is None else ('after %s' % date),
    506                 instructions)
--> 507       return func(*args, **kwargs)
    508 
    509     doc = _add_deprecated_arg_notice_to_docstring(

~/env_py36/lib/python3.6/site-packages/tensorflow/python/framework/importer.py in import_graph_def(graph_def, input_map, return_elements, name, op_dict, producer_op_list)
    429       except errors.InvalidArgumentError as e:
    430         # Convert to ValueError for backwards compatibility.
--> 431         raise ValueError(str(e))
    432 
    433     # Create _DefinedFunctions for any imported functions.

ValueError: Input 0 of node import/save/Assign was passed float from import/dense/bias:0 incompatible with expected float_ref.

Could someone provide a working minimal example please?