allenai / bilm-tf

Tensorflow implementation of contextualized word representations from bi-directional language models
Apache License 2.0

problem with freezing graph #107

Open ohwe opened 6 years ago

ohwe commented 6 years ago

Hi, ELMo team. To deploy TF models in C++, I typically freeze the graph (via graph_util.convert_variables_to_constants) into a constant GraphDef, so that I have a single .pb GraphDef file for applying the model.

However, in this case a set of 'bilm/Variable_*' variables (assigned through tf.assign) prevents me from freezing the model. These variables arise as the init_states for the LSTMs that are passed between layers.

My question is: why did you avoid using tf.nn.rnn_cell.MultiRNNCell here? It seems like it would make things much simpler.

matt-peters commented 6 years ago

I haven't tried to use graph_util.convert_variables_to_constants so can't comment on the source of the error.

As far as the particular implementation details, I honestly can't remember, I wrote much of this code nearly two years ago :-/

limohanlmh commented 6 years ago

Hi, I have encountered a similar problem when freezing the graph for C++ deployment. To be specific, the ELMo feature is the output node of the computation graph and, after freezing, the token id is the entry node of the graph. However, when I called the TensorFlow C++ API to create a session, an error was triggered: "Invalid argument: Input 0 of node bilm/Assign_7 was passed float from bilm/Variable_7:0 incompatible with expected float_ref."
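The float_ref error names an Assign node whose first input is a variable. The likely mechanism is that freezing turns the variable into a Const, which produces a plain float tensor, while the Assign op still expects a mutable reference. A minimal sketch for listing such nodes before freezing, assuming `nodes` are the NodeDef entries of a GraphDef (each has `name`, `op`, and `input` fields); `find_assign_targets` is a hypothetical helper, not a TensorFlow API:

```python
def find_assign_targets(nodes):
    """Map each Assign node name to the name of the node it writes to.

    `nodes` is any iterable of objects with `.name`, `.op` and `.input`
    attributes, for example `graph_def.node`.
    """
    targets = {}
    for node in nodes:
        if node.op == "Assign" and node.input:
            # NodeDef inputs look like "name", "name:1" or "^name"
            # (control dependency); keep only the bare node name.
            ref = node.input[0].lstrip("^").split(":")[0]
            targets[node.name] = ref
    return targets
```

Note that the result also contains the initializer Assign node that every variable has. The entries that matter for freezing are the ones that are upstream of the chosen output node.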

Rusiecki commented 5 years ago

Hello, I'm running into a similar problem when running:

python3.7 freeze_graph.py \
  --input_meta_graph='/Users/banana/Downloads/2/model.ckpt-113750.meta' \
  --input_checkpoint='/Users/banana/Downloads/2/model.ckpt-113750' \
  --output_graph='/Users/banana/Downloads/2/frozengraph.pb' \
  --output_node_names='bilm/Assign_7' \
  --input_binary=True

I get an assertion error for any value I pass to --output_node_names.

Error message :

Loaded meta graph file '/Users/master/Downloads/2/model.ckpt-113750.meta
2019-03-30 13:31:57.446975: I tensorflow/core/platform/cpu_feature_guard.cc:141] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2 FMA
Traceback (most recent call last):
  File "freeze_graph.py", line 495, in <module>
    run_main()
  File "freeze_graph.py", line 492, in run_main
    app.run(main=my_main, argv=[sys.argv[0]] + unparsed)
  File "/usr/local/lib/python3.7/site-packages/tensorflow/python/platform/app.py", line 125, in run
    _sys.exit(main(argv))
  File "freeze_graph.py", line 491, in <lambda>
    my_main = lambda unused_args: main(unused_args, flags)
  File "freeze_graph.py", line 385, in main
    flags.saved_model_tags, checkpoint_version)
  File "freeze_graph.py", line 367, in freeze_graph
    checkpoint_version=checkpoint_version)
  File "freeze_graph.py", line 229, in freeze_graph_with_def_protos
    variable_names_blacklist=variable_names_blacklist)
  File "/usr/local/lib/python3.7/site-packages/tensorflow/python/framework/graph_util_impl.py", line 232, in convert_variables_to_constants
    inference_graph = extract_sub_graph(input_graph_def, output_node_names)
  File "/usr/local/lib/python3.7/site-packages/tensorflow/python/framework/graph_util_impl.py", line 174, in extract_sub_graph
    _assert_nodes_are_present(name_to_node, dest_nodes)
  File "/usr/local/lib/python3.7/site-packages/tensorflow/python/framework/graph_util_impl.py", line 133, in _assert_nodes_are_present
    assert d in name_to_node, "%s is not in graph" % d
AssertionError: bilm/Assign_7 is not in graph
Dominiks-MacBook-Pro:Downloads master$ python3.7 freeze_graph.py --input_meta_graph='/Users/master/Downloads/2/model.ckpt-113750.meta' --input_checkpoint='/Users/master/Downloads/2/model.ckpt-113750' --output_graph='/Users/master/Downloads/2/frozengraph.pb' --output_node_names='bilm/Assign_7' --input_binary=True

The code I run is:

`# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Converts checkpoint variables into Const ops in a standalone GraphDef file.

This script is designed to take a GraphDef proto, a SaverDef proto, and a set of
variable values stored in a checkpoint file, and output a GraphDef with all of
the variable ops converted into const ops containing the values of the
variables.

It's useful to do this when we need to load a single file in C++, especially in
environments like mobile or embedded where we may not have access to the
RestoreTensor ops and file loading calls that they rely on.

An example of command-line usage is:
bazel build tensorflow/python/tools:freeze_graph && \
bazel-bin/tensorflow/python/tools/freeze_graph \
--input_graph=some_graph_def.pb \
--input_checkpoint=model.ckpt-8361242 \
--output_graph=/tmp/frozen_graph.pb --output_node_names=softmax

You can also look at freeze_graph_test.py for an example of how to use it.

"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import re
import sys

from google.protobuf import text_format

from tensorflow.core.framework import graph_pb2
from tensorflow.core.protobuf import saver_pb2
from tensorflow.core.protobuf.meta_graph_pb2 import MetaGraphDef
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.client import session
from tensorflow.python.framework import graph_util
from tensorflow.python.framework import importer
from tensorflow.python.platform import app
from tensorflow.python.platform import gfile
from tensorflow.python.saved_model import loader
from tensorflow.python.saved_model import tag_constants
from tensorflow.python.tools import saved_model_utils
# from tensorflow.python.training import checkpoint_management
from tensorflow.train import checkpoint_exists
from tensorflow.python.training import saver as saver_lib
import tensorflow as tf

def _has_no_variables(sess):
  """Determines if the graph has any variables.

  Args:
    sess: TensorFlow Session.

  Returns:
    Bool.
  """
  for op in sess.graph.get_operations():
    if op.type.startswith("Variable") or op.type.endswith("VariableOp"):
      return False
  return True

def freeze_graph_with_def_protos(input_graph_def,
                                 input_saver_def,
                                 input_checkpoint,
                                 output_node_names,
                                 restore_op_name,
                                 filename_tensor_name,
                                 output_graph,
                                 clear_devices,
                                 initializer_nodes,
                                 variable_names_whitelist="",
                                 variable_names_blacklist="",
                                 input_meta_graph_def=None,
                                 input_saved_model_dir=None,
                                 saved_model_tags=None,
                                 checkpoint_version=saver_pb2.SaverDef.V2):
  """Converts all variables in a graph and checkpoint into constants.

  Args:
    input_graph_def: A GraphDef.
    input_saver_def: A SaverDef (optional).
    input_checkpoint: The prefix of a V1 or V2 checkpoint, with V2 taking
      priority. Typically the result of Saver.save() or that of
      tf.train.latest_checkpoint(), regardless of sharded/non-sharded or
      V1/V2.
    output_node_names: The name(s) of the output nodes, comma separated.
    restore_op_name: Unused.
    filename_tensor_name: Unused.
    output_graph: String where to write the frozen GraphDef.
    clear_devices: A Bool whether to remove device specifications.
    initializer_nodes: Comma separated string of initializer nodes to run
      before freezing.
    variable_names_whitelist: The set of variable names to convert (optional,
      by default, all variables are converted).
    variable_names_blacklist: The set of variable names to omit converting
      to constants (optional).
    input_meta_graph_def: A MetaGraphDef (optional),
    input_saved_model_dir: Path to the dir with TensorFlow 'SavedModel' file
      and variables (optional).
    saved_model_tags: Group of comma separated tag(s) of the MetaGraphDef to
      load, in string format (optional).
    checkpoint_version: Tensorflow variable file format (saver_pb2.SaverDef.V1
      or saver_pb2.SaverDef.V2)

  Returns:
    Location of the output_graph_def.
  """
  del restore_op_name, filename_tensor_name  # Unused by updated loading code.

  # 'input_checkpoint' may be a prefix if we're using Saver V2 format
  if (not input_saved_model_dir and not checkpoint_exists(input_checkpoint)):
    # not tf.train.checkpoint_exists(input_checkpoint)):
    print("Input checkpoint '" + input_checkpoint + "' doesn't exist!")
    return -1

  if not output_node_names:
    print("You need to supply the name of a node to --output_node_names.")
    return -1

  # Remove all the explicit device specifications for this node. This helps to
  # make the graph more portable.
  if clear_devices:
    if input_meta_graph_def:
      for node in input_meta_graph_def.graph_def.node:
        node.device = ""
    elif input_graph_def:
      for node in input_graph_def.node:
        node.device = ""

  if input_graph_def:
    _ = importer.import_graph_def(input_graph_def, name="")
  with session.Session() as sess:
    if input_saver_def:
      saver = saver_lib.Saver(
          saver_def=input_saver_def, write_version=checkpoint_version)
      saver.restore(sess, input_checkpoint)
    elif input_meta_graph_def:
      restorer = saver_lib.import_meta_graph(
          input_meta_graph_def, clear_devices=True)
      restorer.restore(sess, input_checkpoint)
      if initializer_nodes:
        sess.run(initializer_nodes.replace(" ", "").split(","))
    elif input_saved_model_dir:
      if saved_model_tags is None:
        saved_model_tags = []
      loader.load(sess, saved_model_tags, input_saved_model_dir)
    else:
      var_list = {}
      reader = pywrap_tensorflow.NewCheckpointReader(input_checkpoint)
      var_to_shape_map = reader.get_variable_to_shape_map()

      # List of all partition variables. Because the condition is heuristic
      # based, the list could include false positives.
      all_parition_variable_names = [
          tensor.name.split(":")[0]
          for op in sess.graph.get_operations()
          for tensor in op.values()
          if re.search(r"/part_\d+/", tensor.name)
      ]
      has_partition_var = False

      for key in var_to_shape_map:
        try:
          tensor = sess.graph.get_tensor_by_name(key + ":0")
          if any(key in name for name in all_parition_variable_names):
            has_partition_var = True
        except KeyError:
          # This tensor doesn't exist in the graph (for example it's
          # 'global_step' or a similar housekeeping element) so skip it.
          continue
        var_list[key] = tensor

      try:
        saver = saver_lib.Saver(
            var_list=var_list, write_version=checkpoint_version)
      except TypeError as e:
        # `var_list` is required to be a map of variable names to Variable
        # tensors. Partition variables are Identity tensors that cannot be
        # handled by Saver.
        if has_partition_var:
          print("Models containing partition variables cannot be converted "
                "from checkpoint files. Please pass in a SavedModel using "
                "the flag --input_saved_model_dir.")
          return -1
        # Models that have been frozen previously do not contain Variables.
        elif _has_no_variables(sess):
          print("No variables were found in this model. It is likely the model "
                "was frozen previously. You cannot freeze a graph twice.")
          return 0
        else:
          raise e

      saver.restore(sess, input_checkpoint)
      if initializer_nodes:
        sess.run(initializer_nodes.replace(" ", "").split(","))

    variable_names_whitelist = (
        variable_names_whitelist.replace(" ", "").split(",")
        if variable_names_whitelist else None)
    variable_names_blacklist = (
        variable_names_blacklist.replace(" ", "").split(",")
        if variable_names_blacklist else None)

    if input_meta_graph_def:
      output_graph_def = graph_util.convert_variables_to_constants(
          sess,
          input_meta_graph_def.graph_def,
          output_node_names.replace(" ", "").split(","),
          variable_names_whitelist=variable_names_whitelist,
          variable_names_blacklist=variable_names_blacklist)
    else:
      output_graph_def = graph_util.convert_variables_to_constants(
          sess,
          input_graph_def,
          output_node_names.replace(" ", "").split(","),
          variable_names_whitelist=variable_names_whitelist,
          variable_names_blacklist=variable_names_blacklist)

  # Write GraphDef to file if output path has been given.
  if output_graph:
    with gfile.GFile(output_graph, "wb") as f:
      f.write(output_graph_def.SerializeToString())

  return output_graph_def

def _parse_input_graph_proto(input_graph, input_binary):
  """Parses input tensorflow graph into GraphDef proto."""
  if not gfile.Exists(input_graph):
    print("Input graph file '" + input_graph + "' does not exist!")
    return -1
  input_graph_def = graph_pb2.GraphDef()
  mode = "rb" if input_binary else "r"
  with gfile.GFile(input_graph, mode) as f:
    if input_binary:
      input_graph_def.ParseFromString(f.read())
    else:
      text_format.Merge(f.read(), input_graph_def)
  return input_graph_def


def _parse_input_meta_graph_proto(input_graph, input_binary):
  """Parses input tensorflow graph into MetaGraphDef proto."""
  if not gfile.Exists(input_graph):
    print("Input meta graph file '" + input_graph + "' does not exist!")
    return -1
  input_meta_graph_def = MetaGraphDef()
  mode = "rb" if input_binary else "r"
  with gfile.GFile(input_graph, mode) as f:
    if input_binary:
      input_meta_graph_def.ParseFromString(f.read())
    else:
      text_format.Merge(f.read(), input_meta_graph_def)
  print("Loaded meta graph file '" + input_graph)
  return input_meta_graph_def


def _parse_input_saver_proto(input_saver, input_binary):
  """Parses input tensorflow Saver into SaverDef proto."""
  if not gfile.Exists(input_saver):
    print("Input saver file '" + input_saver + "' does not exist!")
    return -1
  mode = "rb" if input_binary else "r"
  with gfile.GFile(input_saver, mode) as f:
    saver_def = saver_pb2.SaverDef()
    if input_binary:
      saver_def.ParseFromString(f.read())
    else:
      text_format.Merge(f.read(), saver_def)
  return saver_def

def freeze_graph(input_graph,
                 input_saver,
                 input_binary,
                 input_checkpoint,
                 output_node_names,
                 restore_op_name,
                 filename_tensor_name,
                 output_graph,
                 clear_devices,
                 initializer_nodes,
                 variable_names_whitelist="",
                 variable_names_blacklist="",
                 input_meta_graph=None,
                 input_saved_model_dir=None,
                 saved_model_tags=tag_constants.SERVING,
                 checkpoint_version=saver_pb2.SaverDef.V2):
  """Converts all variables in a graph and checkpoint into constants.

  Args:
    input_graph: A GraphDef file to load.
    input_saver: A TensorFlow Saver file.
    input_binary: A Bool. True means input_graph is .pb, False indicates .pbtxt.
    input_checkpoint: The prefix of a V1 or V2 checkpoint, with V2 taking
      priority. Typically the result of Saver.save() or that of
      tf.train.latest_checkpoint(), regardless of sharded/non-sharded or
      V1/V2.
    output_node_names: The name(s) of the output nodes, comma separated.
    restore_op_name: Unused.
    filename_tensor_name: Unused.
    output_graph: String where to write the frozen GraphDef.
    clear_devices: A Bool whether to remove device specifications.
    initializer_nodes: Comma separated list of initializer nodes to run before
      freezing.
    variable_names_whitelist: The set of variable names to convert (optional,
      by default, all variables are converted),
    variable_names_blacklist: The set of variable names to omit converting
      to constants (optional).
    input_meta_graph: A MetaGraphDef file to load (optional).
    input_saved_model_dir: Path to the dir with TensorFlow 'SavedModel' file
      and variables (optional).
    saved_model_tags: Group of comma separated tag(s) of the MetaGraphDef to
      load, in string format.
    checkpoint_version: Tensorflow variable file format (saver_pb2.SaverDef.V1
      or saver_pb2.SaverDef.V2).

  Returns:
    String that is the location of frozen GraphDef.
  """
  input_graph_def = None
  if input_saved_model_dir:
    input_graph_def = saved_model_utils.get_meta_graph_def(
        input_saved_model_dir, saved_model_tags).graph_def
  elif input_graph:
    input_graph_def = _parse_input_graph_proto(input_graph, input_binary)
  input_meta_graph_def = None
  if input_meta_graph:
    input_meta_graph_def = _parse_input_meta_graph_proto(
        input_meta_graph, input_binary)
  input_saver_def = None
  if input_saver:
    input_saver_def = _parse_input_saver_proto(input_saver, input_binary)
  freeze_graph_with_def_protos(
      input_graph_def,
      input_saver_def,
      input_checkpoint,
      output_node_names,
      restore_op_name,
      filename_tensor_name,
      output_graph,
      clear_devices,
      initializer_nodes,
      variable_names_whitelist,
      variable_names_blacklist,
      input_meta_graph_def,
      input_saved_model_dir,
      saved_model_tags.replace(" ", "").split(","),
      checkpoint_version=checkpoint_version)

def main(unused_args, flags):
  if flags.checkpoint_version == 1:
    checkpoint_version = saver_pb2.SaverDef.V1
  elif flags.checkpoint_version == 2:
    checkpoint_version = saver_pb2.SaverDef.V2
  else:
    print("Invalid checkpoint version (must be '1' or '2'): %d" %
          flags.checkpoint_version)
    return -1
  freeze_graph(flags.input_graph, flags.input_saver, flags.input_binary,
               flags.input_checkpoint, flags.output_node_names,
               flags.restore_op_name, flags.filename_tensor_name,
               flags.output_graph, flags.clear_devices, flags.initializer_nodes,
               flags.variable_names_whitelist, flags.variable_names_blacklist,
               flags.input_meta_graph, flags.input_saved_model_dir,
               flags.saved_model_tags, checkpoint_version)

def run_main():
  parser = argparse.ArgumentParser()
  parser.register("type", "bool", lambda v: v.lower() == "true")
  parser.add_argument(
      "--input_graph",
      type=str,
      default="",
      help="TensorFlow \'GraphDef\' file to load.")
  parser.add_argument(
      "--input_saver",
      type=str,
      default="",
      help="TensorFlow saver file to load.")
  parser.add_argument(
      "--input_checkpoint",
      type=str,
      default="",
      help="TensorFlow variables file to load.")
  parser.add_argument(
      "--checkpoint_version",
      type=int,
      default=2,
      help="Tensorflow variable file format")
  parser.add_argument(
      "--output_graph",
      type=str,
      default="",
      help="Output \'GraphDef\' file name.")
  parser.add_argument(
      "--input_binary",
      nargs="?",
      const=True,
      type="bool",
      default=False,
      help="Whether the input files are in binary format.")
  parser.add_argument(
      "--output_node_names",
      type=str,
      default="",
      help="The name of the output nodes, comma separated.")
  parser.add_argument(
      "--restore_op_name",
      type=str,
      default="save/restore_all",
      help="""\
      The name of the master restore operator. Deprecated, unused by updated \
      loading code.
      """)
  parser.add_argument(
      "--filename_tensor_name",
      type=str,
      default="save/Const:0",
      help="""\
      The name of the tensor holding the save path. Deprecated, unused by \
      updated loading code.
      """)
  parser.add_argument(
      "--clear_devices",
      nargs="?",
      const=True,
      type="bool",
      default=True,
      help="Whether to remove device specifications.")
  parser.add_argument(
      "--initializer_nodes",
      type=str,
      default="",
      help="Comma separated list of initializer nodes to run before freezing.")
  parser.add_argument(
      "--variable_names_whitelist",
      type=str,
      default="",
      help="""\
      Comma separated list of variables to convert to constants. If specified, \
      only those variables will be converted to constants.\
      """)
  parser.add_argument(
      "--variable_names_blacklist",
      type=str,
      default="",
      help="""\
      Comma separated list of variables to skip converting to constants.\
      """)
  parser.add_argument(
      "--input_meta_graph",
      type=str,
      default="",
      help="TensorFlow \'MetaGraphDef\' file to load.")
  parser.add_argument(
      "--input_saved_model_dir",
      type=str,
      default="",
      help="Path to the dir with TensorFlow \'SavedModel\' file and variables.")
  parser.add_argument(
      "--saved_model_tags",
      type=str,
      default="serve",
      help="""\
      Group of tag(s) of the MetaGraphDef to load, in string format,\
      separated by \',\'. For tag-set contains multiple tags, all tags \
      must be passed in.\
      """)
  flags, unparsed = parser.parse_known_args()

  my_main = lambda unused_args: main(unused_args, flags)
  app.run(main=my_main, argv=[sys.argv[0]] + unparsed)


if __name__ == '__main__':
  run_main() `

The files I have are :

-rwxrwxrwx@ 1 master staff 91 Mar 30 03:39 checkpoints
-rwxrwxrwx@ 1 master staff 34533780 Mar 27 20:53 events.out.tfevents.1553461572.elmo-
-rwxrwxrwx 1 master staff 1569619104 Mar 27 21:16 model.ckpt-113750.data
-rwxrwxrwx 1 master staff 1569619104 Mar 27 21:16 model.ckpt-113750.data-00000-of-00001
-rwxrwxrwx 1 master staff 3275 Mar 27 21:16 model.ckpt-113750.index
-rwxrwxrwx@ 1 master staff 7509658 Mar 27 20:21 model.ckpt-113750.meta

Is there any update on how to get a .pb file out of the ELMo training?
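The AssertionError in the traceback ("bilm/Assign_7 is not in graph") means that no node with the requested name exists in the GraphDef being frozen, so a reasonable first step is to compare the requested names with the names that do exist. One possible explanation, not confirmed in this thread, is that the name was taken from a different graph than the one stored in the training checkpoint's .meta file. A minimal sketch, assuming `nodes` are the NodeDef entries of a GraphDef (each has a `name` field); `missing_output_nodes` is a hypothetical helper, not part of TensorFlow:

```python
import difflib


def missing_output_nodes(nodes, requested):
    """Return {requested_name: [close matches]} for names absent from the graph.

    `nodes` is any iterable of objects with a `.name` attribute, for example
    `meta_graph_def.graph_def.node`. `requested` is the comma-separated string
    that would be passed to --output_node_names.
    """
    present = [node.name for node in nodes]
    present_set = set(present)
    # Same splitting rule that freeze_graph.py applies to the flag value.
    wanted = requested.replace(" ", "").split(",")
    return {
        name: difflib.get_close_matches(name, present, n=3)
        for name in wanted
        if name not in present_set
    }
```

With the MetaGraphDef parsed by the script above, this could be called as `missing_output_nodes(input_meta_graph_def.graph_def.node, flags.output_node_names)`. An empty dict means every requested name is present.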

carolmanderson commented 4 years ago

@Rusiecki did you ever find a solution for this?

mohammedayub44 commented 4 years ago

I wanted to create a SavedModel from the trained model so it can be served as a REST endpoint with TensorFlow Serving. Is there an easy way to do so?

Thanks!

carolmanderson commented 4 years ago

@mohammedayub44 Assuming what you want to do is use the trained model to generate embeddings: I figured out a way to do it by removing the ops that pass the state forward between batches. My model is now running on TF Serving in production. There doesn't seem to be an easy way to deploy the model in its stateful form. If you wanted to maintain state between batches, you would need to modify the graph to take the previous batch's states as input and produce the current states as output, and then explicitly pass them forward with each call to your endpoint.

In my case, I wanted to turn off statefulness anyway, because it causes non-deterministic behavior and makes testing of the deployed model difficult. Depending on the application, I found that turning off statefulness caused a 0-0.5% decrease in the F1 score of the model consuming the ELMo embeddings.
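The explicit state passing described above can be sketched without TensorFlow. This is a toy illustration only: `embed_batch` and `StatefulClient` are hypothetical names, and the recurrence is a placeholder rather than an LSTM. The point is the shape of the interface, where the model call takes the previous state as input, returns the new state as output, and the caller carries it between requests:

```python
def embed_batch(token_ids, prev_state):
    """Toy stand-in for a stateless model call: returns (outputs, new_state)."""
    outputs = []
    state = prev_state
    for token in token_ids:
        # Placeholder recurrence, not an LSTM: mixes old state with the input.
        state = 0.5 * state + token
        outputs.append(state)
    return outputs, state


class StatefulClient:
    """Keeps the recurrent state on the caller's side between requests."""

    def __init__(self, initial_state=0.0):
        self.state = initial_state

    def embed(self, token_ids):
        # Pass the stored state in, then store the state that comes back.
        outputs, self.state = embed_batch(token_ids, self.state)
        return outputs
```

Calling `embed_batch(tokens, 0.0)` every time corresponds to the stateless deployment: the same input always gives the same output, which is what makes the deployed model easy to test.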

To turn off statefulness when computing embeddings, these lines in model.py should be removed or commented out: https://github.com/allenai/bilm-tf/blob/7cffee2b0986be51f5e2a747244836e1047657f4/bilm/model.py#L587-L593

This is the code I used to export the graph. I did it in two steps -- I found that after the first step, my model didn't have the TF serving tags, and the second step was necessary to add the tags. There's probably a more direct way of doing this.

Step 1:

import tensorflow as tf

# Note: comment out the lines referenced above in your copy of bilm-tf before this import.
from bilm import BidirectionalLanguageModel

elmo_weight_file = '/path/to/my_ckpt_weights.hd5'
elmo_options_file = '/path/to/my_options.json'
output_file = '/path/to/my_saved_model.pb'

model = BidirectionalLanguageModel(elmo_options_file, elmo_weight_file)

graph = tf.Graph()

with graph.as_default():
    ids_placeholder = tf.placeholder('int32', shape=(None, None, 50))
    ops = model(ids_placeholder)
    session = tf.Session()
    session.run(tf.global_variables_initializer())

output_node_names="concat_3"

input_graph_def = session.graph.as_graph_def()
output_graph_def = tf.graph_util.convert_variables_to_constants(
    session,
    input_graph_def,
    output_node_names.split(","))

with tf.gfile.GFile(output_file, "wb") as f:
    f.write(output_graph_def.SerializeToString())

Step 2:

import tensorflow as tf
from tensorflow.saved_model import simple_save

def load_graph(model_file, returnElements= None):
    graph = tf.Graph()
    graph_def = tf.GraphDef()
    with open(model_file, "rb") as f:
        graph_def.ParseFromString(f.read())
    returns = None
    with graph.as_default():
        returns = tf.import_graph_def(graph_def, return_elements= returnElements)
    if returnElements is None:
        return graph
    return graph, returns

old_graph = "/path/to/my_saved_model.pb"
new_graph = "/path/to/my_new_saved_model.pb"

graph = load_graph(old_graph)
with tf.Session(graph = graph) as sess:
    with graph.as_default():
        layers = [n.name for n in graph.as_graph_def().node]
        output_node_name = layers.pop() + ":0"
        input_node_name = layers.pop(0) + ":0"
    output_node = tf.get_default_graph().get_tensor_by_name(output_node_name)
    input_node = tf.get_default_graph().get_tensor_by_name(input_node_name)

    inputs = {input_node.name : input_node}
    outputs = {output_node.name : output_node}
    simple_save(sess, new_graph, inputs, outputs)

And here's a code snippet to check whether your export worked. If the ops that pass the state forward haven't been removed, this will raise an error like ValueError: Input 0 of node bilm/Assign was passed float from bilm/Variable:0 incompatible with expected float_ref.

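import tensorflow as tf
frozen_graph = "/path/to/my_new_saved_model.pb"
with tf.gfile.GFile(frozen_graph, "rb") as f:
    restored_graph_def = tf.GraphDef()
    restored_graph_def.ParseFromString(f.read())
# Assumed completion (the snippet appears cut off): importing the GraphDef is the step that validates node inputs.
with tf.Graph().as_default():
    tf.import_graph_def(restored_graph_def, name="")
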
mohammedayub44 commented 4 years ago

@carolmanderson Thanks for the detailed answer. I'm trying to open the model.py link above, but it doesn't seem to work. If that's a direct clone of this repo, are you referring to these lines in model.py?

with tf.control_dependencies([layer_output]):
    # update the initial states
    for i in range(2):
        new_state = tf.concat(
            [final_state[i][:batch_size, :],
             init_states[i][batch_size:, :]], axis=0)
        state_update_op = tf.assign(init_states[i], new_state)
        update_ops.append(state_update_op)

Thanks !
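For readers following along, the quoted update overwrites only the first batch_size rows of each stored state and keeps the remaining rows, so the stored state has the same number of rows before and after the update. A minimal sketch of the same slicing with plain Python lists; `update_state` is a hypothetical name, and the real code operates on tensors with tf.concat and tf.assign:

```python
def update_state(init_state, final_state, batch_size):
    """Rows [0, batch_size) come from final_state; the rest keep init_state."""
    return final_state[:batch_size] + init_state[batch_size:]


# Stored state with room for 3 sequences, updated by a batch of 2.
stored = [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]]
batch_final = [[1.0, 1.0], [2.0, 2.0]]
stored = update_state(stored, batch_final, batch_size=2)
```

This write back into a variable is the tf.assign that the float_ref error earlier in the thread refers to.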

carolmanderson commented 4 years ago

Yes, sorry, I was signed into two different GitHub accounts at once and got confused. I've updated it above.

mohammedayub44 commented 4 years ago

No problem. :) A couple of thoughts: 1) I got the ValueError, as you said. I therefore commented out the statefulness part (I had to comment out line 394 as well) and then ran Step 1, which ran fine. I checked the frozen graph exported by Step 1 with the check code, and it worked fine.

2) In Step 2, the simple_save() function doesn't write anything to the variables folder. I'm guessing all the variable data is already in the frozen graph, so a variables file is not required. Checking the graph exported by Step 2 gives me an error (tried in both TF1 and TF2).

However, in TF2 I tried your suggestion from #238 of using tf.saved_model.load(), and it works fine 👍

Cheers !