yuenshome / yuenshome.github.io

https://yuenshome.github.io
MIT License
83 stars 15 forks source link

实时视频风格迁移项目:fritz-style-transfer #8

Open ysh329 opened 5 years ago

ysh329 commented 5 years ago
ysh329 commented 5 years ago

blank

ysh329 commented 5 years ago

fritzlabs/fritz-style-transfer

最关键的三个步骤，以及它们所在的两个包（fritz-style-transfer 与 tensorflow）：

  1. models.SmallStyleTransferNetwork(位于fritz-style-transfer):from style_transfer import models
  2. _freeze_graph(位于tensorflow):from tensorflow.python.tools import freeze_graph
  3. _optimize_graph(位于tensorflow):from tensorflow.python.tools import optimize_for_inference_lib
ysh329 commented 5 years ago

1. models.SmallStyleTransferNetwork

class SmallStyleTransferNetwork(StyleTransferNetwork):
    """A lighter-weight variant of StyleTransferNetwork.

    Per build()'s description, this architecture removes some blocks of
    layers and reduces the size of convolutions to save on computation.
    _residual_block and _upsample are not defined in this class; they are
    presumably inherited from StyleTransferNetwork (confirm).
    """

    @classmethod
    def build(cls, image_size, alpha=1.0, input_tensor=None, checkpoint_file=None):
        """Build a Smaller Transfer Network Model using keras' functional API.

        This architecture removes some blocks of layers and reduces the size
        of convolutions to save on computation.

        Args:
            image_size - the size of the input and output image (H, W).
            alpha - a width multiplier applied to every layer's channel
                    count (each count is truncated with int()).
            input_tensor - optional existing tensor to use as the model
                           input; forwarded to keras.layers.Input(tensor=...).
            checkpoint_file - optional weights file to load by layer name.
                              Paths starting with 'gs://' are first copied
                              locally via utils.copy_file_from_gcs.
        Returns:
            model: a keras model object
        """
        x = keras.layers.Input(
            shape=(image_size[0], image_size[1], 3), tensor=input_tensor)
        # Two stride-2 convolutions shrink each spatial dimension by 4x.
        out = cls._convolution(x, int(alpha * 32), 3, strides=2)
        out = cls._convolution(out, int(alpha * 64), 3, strides=2)
        out = cls._residual_block(out, int(alpha * 64))
        out = cls._residual_block(out, int(alpha * 64))
        out = cls._residual_block(out, int(alpha * 64))
        # NOTE(review): _upsample lives outside this class; presumably the two
        # calls restore the input resolution -- confirm against the base class.
        out = cls._upsample(out, int(alpha * 64), 3)
        out = cls._upsample(out, int(alpha * 32), 3, size=2)
        # Add a layer of padding to keep sizes consistent.
        # out = keras.layers.ZeroPadding2D(padding=(1, 1))(out)
        # Final projection back to 3 (RGB) channels, without ReLU.
        out = cls._convolution(out, 3, 3, relu=False, padding='same')
        # Restrict outputs of pixel values to -1 and 1.
        out = keras.layers.Activation('tanh')(out)
        # Deprocess the image into valid image data. Note we'll need to define
        # a custom layer for this in Core ML as well.
        out = layers.DeprocessStylizedImage()(out)
        model = keras.models.Model(inputs=x, outputs=out)

        # Optionally load weights from a checkpoint
        if checkpoint_file:
            # Lazy %-args: the message is only formatted if INFO is enabled.
            logger.info(
                'Loading weights from checkpoint: %s', checkpoint_file
            )
            if checkpoint_file.startswith('gs://'):
                checkpoint_file = utils.copy_file_from_gcs(checkpoint_file)
            # by_name=True matches saved weights to layers by layer name.
            model.load_weights(checkpoint_file, by_name=True)
        return model

    @classmethod
    def _convolution(
            cls, x, n_filters, kernel_size, strides=1,
            padding='same', relu=True, use_bias=False):
        """Create a convolution block.

        This block consists of a convolution layer, instance normalization,
        and an optional RELU activation.

        Args:
            x - a keras layer as input
            n_filters - the number of output dimensions
            kernel_size - an integer or tuple specifying the (width, height) of
                         the 2D convolution window
            strides - An integer or tuple/list of 2 integers, specifying the
                      strides of the convolution along the width and height.
                      Default 1.
            padding - one of "valid" or "same" (case-insensitive).
                      Default "same".
            relu - a bool specifying whether or not a RELU activation is
                   applied. Default True.
            use_bias - a bool specifying whether or not to use a bias term.
                       Default False.
        Returns:
            The output tensor of the block.
        """
        out = keras.layers.convolutional.Conv2D(
            n_filters,
            kernel_size,
            strides=strides,
            padding=padding,
            use_bias=use_bias
        )(x)

        # We are using the keras-contrib library from @farizrahman4u for
        # an implementation of Instance Normalization. Note here that we are
        # specifying the normalization axis to be -1, or the channel axis.
        # By default this is None and simple Batch Normalization is applied.
        out = keras_contrib.layers.normalization.InstanceNormalization(
            axis=-1)(out)
        if relu:
            out = keras.layers.Activation('relu')(out)
        return out
ysh329 commented 5 years ago

2. _freeze_graph

def _freeze_graph(model, basename, output_dir):
    """Freeze the current Keras session graph into a constant-only .pb file.

    Writes these files into output_dir, where <name> is basename with its
    extension removed:
      * <name>.ckpt*             - checkpoint written by tf.train.Saver
      * <name>_graph_def.pbtext  - the session's GraphDef in text format
      * <name>_frozen.pb         - the frozen graph produced by freeze_graph

    Args:
        model: accepted for interface compatibility but NOT used; the graph
            and variables come from keras.backend.get_session(). The output
            node 'deprocess_stylized_image_1/mul' is hard-coded and is
            assumed to belong to this model -- confirm.
        basename: file name whose stem names every output file.
        output_dir: directory that receives all written files.
    """
    name, _ = os.path.splitext(basename)

    saver = tf.train.Saver()

    checkpoint_filename = os.path.join(output_dir, '%s.ckpt' % name)
    graph_def_basename = '%s_graph_def.pbtext' % name
    graph_def_filename = os.path.join(output_dir, graph_def_basename)
    output_graph_filename = os.path.join(output_dir, '%s_frozen.pb' % name)

    # Deliberately not `with keras.backend.get_session() as sess:` --
    # leaving a `with` block closes the session, and this is Keras' shared
    # global session, so closing it would break later Keras/TF work in the
    # same process.
    sess = keras.backend.get_session()
    saver.save(sess, checkpoint_filename)
    tf.train.write_graph(sess.graph_def, output_dir, graph_def_basename)

    freeze_graph.freeze_graph(
        input_graph=graph_def_filename,
        input_saver='',
        # The GraphDef above was written in text form (write_graph's default).
        input_binary=False,
        input_checkpoint=checkpoint_filename,
        output_graph=output_graph_filename,
        output_node_names='deprocess_stylized_image_1/mul',
        # freeze_graph documents these two arguments as unused.
        restore_op_name="save/restore_all",
        filename_tensor_name="save/Const:0",
        clear_devices=True,
        initializer_nodes=None
    )
    logger.info('Saved frozen graph to: %s', output_graph_filename)

在 TensorFlow 源码 tensorflow/python/tools/freeze_graph.py（tensorflow/tensorflow master 分支，第 288–363 行）中可以找到 freeze_graph 的实现：

def freeze_graph(input_graph,
                 input_saver,
                 input_binary,
                 input_checkpoint,
                 output_node_names,
                 restore_op_name,
                 filename_tensor_name,
                 output_graph,
                 clear_devices,
                 initializer_nodes,
                 variable_names_whitelist="",
                 variable_names_blacklist="",
                 input_meta_graph=None,
                 input_saved_model_dir=None,
                 saved_model_tags=tag_constants.SERVING,
                 checkpoint_version=saver_pb2.SaverDef.V2):
  """Converts all variables in a graph and checkpoint into constants.

  This wrapper parses the file-based inputs into protos and delegates the
  actual work to freeze_graph_with_def_protos.

  Args:
    input_graph: A `GraphDef` file to load. Ignored when
      `input_saved_model_dir` is given.
    input_saver: A TensorFlow Saver file.
    input_binary: A Bool. True means input_graph is .pb, False indicates .pbtxt.
    input_checkpoint: The prefix of a V1 or V2 checkpoint, with V2 taking
      priority.  Typically the result of `Saver.save()` or that of
      `tf.train.latest_checkpoint()`, regardless of sharded/non-sharded or
      V1/V2.
    output_node_names: The name(s) of the output nodes, comma separated.
    restore_op_name: Unused.
    filename_tensor_name: Unused.
    output_graph: String where to write the frozen `GraphDef`.
    clear_devices: A Bool whether to remove device specifications.
    initializer_nodes: Comma separated list of initializer nodes to run before
                       freezing.
    variable_names_whitelist: The set of variable names to convert (optional, by
                              default, all variables are converted),
    variable_names_blacklist: The set of variable names to omit converting
                              to constants (optional).
    input_meta_graph: A `MetaGraphDef` file to load (optional).
    input_saved_model_dir: Path to the dir with TensorFlow 'SavedModel' file and
                           variables (optional).
    saved_model_tags: Group of comma separated tag(s) of the MetaGraphDef to
                      load, in string format.
    checkpoint_version: Tensorflow variable file format (saver_pb2.SaverDef.V1
                        or saver_pb2.SaverDef.V2).
  Returns:
    The value returned by freeze_graph_with_def_protos (presumably the frozen
    `GraphDef` -- confirm against that function). It was previously discarded,
    so callers always received None.
  """
  input_graph_def = None
  # A SavedModel directory takes priority over a standalone GraphDef file.
  if input_saved_model_dir:
    input_graph_def = saved_model_utils.get_meta_graph_def(
        input_saved_model_dir, saved_model_tags).graph_def
  elif input_graph:
    input_graph_def = _parse_input_graph_proto(input_graph, input_binary)
  input_meta_graph_def = None
  if input_meta_graph:
    input_meta_graph_def = _parse_input_meta_graph_proto(
        input_meta_graph, input_binary)
  input_saver_def = None
  if input_saver:
    input_saver_def = _parse_input_saver_proto(input_saver, input_binary)
  return freeze_graph_with_def_protos(
      input_graph_def,
      input_saver_def,
      input_checkpoint,
      output_node_names,
      restore_op_name,
      filename_tensor_name,
      output_graph,
      clear_devices,
      initializer_nodes,
      variable_names_whitelist,
      variable_names_blacklist,
      input_meta_graph_def,
      input_saved_model_dir,
      # "tag1, tag2" -> ["tag1", "tag2"]
      saved_model_tags.replace(" ", "").split(","),
      checkpoint_version=checkpoint_version)

上面函数的关键是 freeze_graph_with_def_protos 这个函数，它实现了“将图和检查点中的所有变量转换为常量”（Converts all variables in a graph and checkpoint into constants.）。由于该函数太长，这里省略了，感兴趣的读者可以点开上文的源码链接查看；不过这里给出它调用 convert_variables_to_constants 将变量转为常量的部分：

      output_graph_def = graph_util.convert_variables_to_constants(
          sess,
          input_graph_def,
          output_node_names.replace(" ", "").split(","),
          variable_names_whitelist=variable_names_whitelist,
          variable_names_blacklist=variable_names_blacklist)

graph_util.convert_variables_to_constants 在 tensorflow/python/framework/graph_util.py（tensorflow/tensorflow master 分支）中只是从 graph_util_impl 重新导出：

from tensorflow.python.framework.graph_util_impl import convert_variables_to_constants

它实际定义在 tensorflow.python.framework.graph_util_impl 中（tensorflow/tensorflow master 分支 tensorflow/python/framework/graph_util_impl.py 第 215–302 行），convert_variables_to_constants 函数内容如下：

@deprecation.deprecated(
    date=None,
    instructions="Use tf.compat.v1.graph_util.convert_variables_to_constants")
@tf_export(v1=["graph_util.convert_variables_to_constants"])
def convert_variables_to_constants(sess,
                                   input_graph_def,
                                   output_node_names,
                                   variable_names_whitelist=None,
                                   variable_names_blacklist=None):
  """Replaces all the variables in a graph with constants of the same values.

  If you have a trained graph containing Variable ops, it can be convenient to
  convert them all to Const ops holding the same values. This makes it possible
  to describe the network fully with a single GraphDef file, and allows the
  removal of a lot of ops related to loading and saving the variables.

  Args:
    sess: Active TensorFlow session containing the variables.
    input_graph_def: GraphDef object holding the network.
    output_node_names: List of name strings for the result nodes of the graph.
    variable_names_whitelist: The set of variable names to convert (by default,
                              all variables are converted).
    variable_names_blacklist: The set of variable names to omit converting
                              to constants.
  Returns:
    GraphDef containing a simplified version of the original.
  """
  # This graph only includes the nodes needed to evaluate the output nodes, and
  # removes unneeded nodes like those involved in saving and assignment.
  inference_graph = extract_sub_graph(input_graph_def, output_node_names)

  # Select the variables to freeze. `variable_dict_names` holds node names and
  # `variable_names` the tensor names to fetch; both are appended in lockstep
  # so they can be zipped back together after sess.run.
  variable_names = []
  variable_dict_names = []
  for node in inference_graph.node:
    if node.op in ["Variable", "VariableV2", "VarHandleOp"]:
      variable_name = node.name
      if ((variable_names_whitelist is not None and
           variable_name not in variable_names_whitelist) or
          (variable_names_blacklist is not None and
           variable_name in variable_names_blacklist)):
        continue
      variable_dict_names.append(variable_name)
      if node.op == "VarHandleOp":
        # Resource variables: fetch the value through the associated read op
        # rather than "<name>:0" (presumably the handle itself, not the value).
        variable_names.append(variable_name + "/Read/ReadVariableOp:0")
      else:
        variable_names.append(variable_name + ":0")
  # Fetch all selected values in one sess.run call; skip it when empty.
  if variable_names:
    returned_variables = sess.run(variable_names)
  else:
    returned_variables = []
  found_variables = dict(zip(variable_dict_names, returned_variables))
  logging.info("Froze %d variables.", len(returned_variables))

  output_graph_def = graph_pb2.GraphDef()
  how_many_converted = 0
  for input_node in inference_graph.node:
    output_node = node_def_pb2.NodeDef()
    if input_node.name in found_variables:
      # Replace the variable node with a Const carrying the fetched value.
      output_node.op = "Const"
      output_node.name = input_node.name
      dtype = input_node.attr["dtype"]
      data = found_variables[input_node.name]
      output_node.attr["dtype"].CopyFrom(dtype)
      output_node.attr["value"].CopyFrom(
          attr_value_pb2.AttrValue(
              tensor=tensor_util.make_tensor_proto(
                  data, dtype=dtype.type, shape=data.shape)))
      how_many_converted += 1
    elif input_node.op == "ReadVariableOp" and (
        input_node.input[0] in found_variables):
      # The preceding branch converts all VarHandleOps of ResourceVariables to
      # constants, so we need to convert the associated ReadVariableOps to
      # Identity ops.
      output_node.op = "Identity"
      output_node.name = input_node.name
      output_node.input.extend([input_node.input[0]])
      output_node.attr["T"].CopyFrom(input_node.attr["dtype"])
      if "_class" in input_node.attr:
        output_node.attr["_class"].CopyFrom(input_node.attr["_class"])
    else:
      # Every other node is copied through unchanged.
      output_node.CopyFrom(input_node)
    output_graph_def.node.extend([output_node])

  output_graph_def.library.CopyFrom(inference_graph.library)
  logging.info("Converted %d variables to const ops.", how_many_converted)
  return output_graph_def
ysh329 commented 5 years ago

3. _optimize_graph

def _optimize_graph(basename, output_dir):
    """Apply TensorFlow's inference optimizations to a frozen graph.

    Reads <output_dir>/<name>_frozen.pb (the file _freeze_graph writes for
    the same basename) and writes <name>_optimized.pb in binary format into
    output_dir, where <name> is basename with its extension removed.

    Args:
        basename: file name whose stem identifies the graph files.
        output_dir: directory holding the frozen graph; the optimized graph
            is written there as well.
    """
    name, _ = os.path.splitext(basename)
    frozen_graph_filename = os.path.join(output_dir, '%s_frozen.pb' % name)
    graph_def = load_graph_def(frozen_graph_filename)

    optimized_graph = optimize_for_inference_lib.optimize_for_inference(
        input_graph_def=graph_def,
        input_node_names=['input_1'],
        placeholder_type_enum=dtypes.float32.as_datatype_enum,
        output_node_names=['deprocess_stylized_image_1/mul'],
        # Only run optimizations whose resulting ops are TOCO compatible.
        toco_compatible=True
    )

    # Build the name from the stem instead of str.replace('frozen', ...) on
    # the frozen file name: replace() would also rewrite any "frozen" inside
    # the model's own name (e.g. "frozen_lake" -> "optimized_lake_...").
    # basename() keeps the original behaviour of writing directly into
    # output_dir.
    optimized_graph_filename = os.path.basename('%s_optimized.pb' % name)
    tf.train.write_graph(
        optimized_graph, output_dir, optimized_graph_filename, as_text=False
    )
    logger.info('Saved optimized graph to: %s',
                os.path.join(output_dir, optimized_graph_filename))

optimize_for_inference 函数位于 tensorflow/tensorflow master 分支的 tensorflow/python/tools/optimize_for_inference_lib.py（第 89–119 行），实现如下：

def optimize_for_inference(input_graph_def, input_node_names, output_node_names,
                           placeholder_type_enum, toco_compatible=False):
  """Run a fixed pipeline of inference-time rewrites over a GraphDef.

  Args:
    input_graph_def: A GraphDef containing a training model.
    input_node_names: A list of names of the nodes that are fed inputs during
      inference.
    output_node_names: A list of names of the nodes that produce the final
      results.
    placeholder_type_enum: The AttrValue enum for the placeholder data type, or
      a list that specifies one value per input node name.
    toco_compatible: Boolean; when True, only optimizations whose resulting
      ops are TOCO compatible are applied (default=False).

  Returns:
    An optimized version of the input graph.
  """
  # Validate the incoming graph before any rewrite touches it.
  ensure_graph_is_valid(input_graph_def)

  # strip_unused presumably prunes nodes not on the input->output path.
  graph = strip_unused_lib.strip_unused(
      input_graph_def, input_node_names, output_node_names,
      placeholder_type_enum)
  graph = graph_util.remove_training_nodes(graph, output_node_names)
  graph = fold_batch_norms(graph)

  # fuse_resize_and_conv is the one pass excluded in TOCO-compatible mode.
  graph = (graph if toco_compatible
           else fuse_resize_and_conv(graph, output_node_names))

  # Validate the rewritten graph before handing it back.
  ensure_graph_is_valid(graph)
  return graph

其主要做了下面几件事（在这些步骤前后各调用一次 ensure_graph_is_valid 校验图的合法性）：

  1. strip_unused_lib.strip_unused
  2. graph_util.remove_training_nodes
  3. fold_batch_norms
  4. fuse_resize_and_conv（仅当 toco_compatible=False 时执行）
ysh329 commented 5 years ago

@alonechen