yanwii / ChineseNER

Chinese organization-name and person-name recognition based on Bi-GRU + CRF, with support for the Google BERT model
164 stars, 41 forks

Can this project be run as a service with docker + tensorflow-serving? #16

Closed (zdx1012 closed this issue 4 years ago)

zdx1012 commented 4 years ago

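Could you tell me what the model's input and output tensors are? I have recently been looking into exposing the model as a service, and I need these parameters: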

            labeling_signature = (
                tf.saved_model.signature_def_utils.build_signature_def(
                    inputs={
                        "input_ids":
                            bert_input_ids,
                        "segment_ids":
                            bert_segment_ids,
                        "input_mask":
                            bert_input_mask,
                        "dropout":
                            bert_dropout,
                    },
                    outputs={
                        "targets":
                            bert_targets,
                    },
                    method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME))

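I rewrote it like this by following an online tutorial. Are these input and output parameters correct? Any pointers would be appreciated, thanks!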
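For reference, the keys in the signature's `inputs` dict are exactly the names a client has to send. A minimal sketch of building a TensorFlow Serving REST request body for the signature above follows. It assumes the signature is registered under the default key `serving_default`, that `0` is the padding id, and that `dropout` is a scalar placeholder. The helper name `build_predict_payload` and the token ids are hypothetical, and the value that disables dropout (`1.0` for a keep probability, `0.0` for a drop rate) depends on how the model defines that placeholder.

```python
import json


def build_predict_payload(token_ids, max_len, dropout_value,
                          signature_name="serving_default"):
    """Pad one tokenized sentence and wrap it in a TF Serving REST request body."""
    if len(token_ids) > max_len:
        raise ValueError("sentence is longer than max_len")
    pad = max_len - len(token_ids)
    input_ids = token_ids + [0] * pad              # 0 is assumed to be the padding id
    input_mask = [1] * len(token_ids) + [0] * pad  # 1 marks real tokens, 0 marks padding
    segment_ids = [0] * max_len                    # single-sentence input
    return json.dumps({
        "signature_name": signature_name,
        # Columnar ("inputs") format: each named input carries its own shape,
        # so the scalar "dropout" can sit next to the batched id tensors.
        "inputs": {
            "input_ids": [input_ids],
            "segment_ids": [segment_ids],
            "input_mask": [input_mask],
            "dropout": dropout_value,  # assumption: 1.0 means "keep everything"
        },
    })


# Hypothetical token ids for one short sentence.
body = build_predict_payload([101, 2769, 4263, 102], max_len=8, dropout_value=1.0)
print(body)
```

The resulting string would be sent as the body of a POST to `http://<host>:8501/v1/models/<model_name>:predict`, where the host and model name depend on how the serving container was started.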

zdx1012 commented 4 years ago
import os

import tensorflow as tf  # TensorFlow 1.x API (tf.Session, tf.saved_model.builder)


def restore_and_save(input_checkpoint, export_path_base):
    checkpoint_file = tf.train.latest_checkpoint(input_checkpoint)
    graph = tf.Graph()

    with graph.as_default():
        session_conf = tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)
        sess = tf.Session(config=session_conf)

        with sess.as_default():
            # Load the saved meta graph, restore its variables, and export a deployable model with SavedModelBuilder
            print("{}.meta".format(checkpoint_file))
            saver = tf.train.import_meta_graph("{}.meta".format(checkpoint_file))
            saver.restore(sess, checkpoint_file)
            print(graph.get_name_scope())

            # TensorFlow Serving loads models from numeric version subdirectories;
            # version 1 is assumed here.
            model_version = 1
            export_path = os.path.join(export_path_base, str(model_version))
            print('Exporting trained model to', export_path)
            builder = tf.saved_model.builder.SavedModelBuilder(export_path)

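            # Build the signature mapping. It must cover the graph's placeholders
            # (bert_input_ids, bert_input_mask, bert_segment_ids, bert_dropout)
            # and the results we need (logits/logits, loss_layer/transitions).
            #
            # build_tensor_info: builds a TensorInfo protocol buffer from its argument.
            #   Input: a tensor in the TensorFlow graph.
            #   Output: a TensorInfo protocol buffer describing that tensor.
            # get_operation_by_name: looks up an operation in the restored graph by name.
            #   This only works if the operation was given a name when the model was built.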
            bert_input_ids = tf.saved_model.utils.build_tensor_info(graph.get_operation_by_name("bert_input_ids").outputs[0])
            bert_input_mask = tf.saved_model.utils.build_tensor_info(graph.get_operation_by_name("bert_input_mask").outputs[0])
            bert_segment_ids = tf.saved_model.utils.build_tensor_info(graph.get_operation_by_name("bert_segment_ids").outputs[0])
            bert_dropout = tf.saved_model.utils.build_tensor_info(graph.get_operation_by_name("bert_dropout").outputs[0])

            logits = tf.saved_model.utils.build_tensor_info(graph.get_operation_by_name("logits/logits").outputs[0])

            trans = tf.saved_model.utils.build_tensor_info(graph.get_operation_by_name("loss_layer/transitions").outputs[0])
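            # signature_constants: signature constants for saving and restoring a SavedModel.
            # For this sequence-labeling task, method_name is "tensorflow/serving/predict".
            # Define the model's inputs and outputs, mapping the serving interface names to tensors.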
            labeling_signature = (
                tf.saved_model.signature_def_utils.build_signature_def(
                    inputs={
                        "input_ids":
                            bert_input_ids,
                        "segment_ids":
                            bert_segment_ids,
                        "input_mask":
                            bert_input_mask,
                        "dropout":
                            bert_dropout,
                    },
                    outputs={
                         "logits":
                             logits,
                         "trans":
                             trans,
                    },
                    method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME))

            # tf.group: creates a single op that groups several ops; running it runs all of its inputs.
            legacy_init_op = tf.group(tf.tables_initializer(), name='legacy_init_op')

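            # add_meta_graph_and_variables: creates a Saver to save the session's variables and
            # writes out the corresponding graph definition. It assumes the variables have
            # already been initialized. It must be called exactly once per SavedModelBuilder
            # to save the first meta graph; further meta graphs can be added with add_meta_graph().
            # Map the signature name to the model signature.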
            builder.add_meta_graph_and_variables(
                sess, [tf.saved_model.tag_constants.SERVING],
                # Signature key; must match request.model_spec.signature_name on the client
                signature_def_map={
                    tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
                        labeling_signature},
                legacy_init_op=legacy_init_op)

            builder.save()
            print("Build Done")
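
Because this signature exports `logits` and `trans` rather than decoded tags, a client would still have to run CRF decoding on the response. A minimal pure-Python Viterbi sketch follows. It assumes `trans[i][j]` is the score of moving from tag `i` to tag `j` (the `tf.contrib.crf` convention) and that `logits` has already been trimmed to the real tokens of one sentence using `input_mask`. The function name `viterbi_decode` and the toy scores are illustrative, not taken from the project.

```python
def viterbi_decode(logits, trans):
    """Return the highest-scoring tag sequence for one sentence.

    logits: per-token tag scores, shape [seq_len][num_tags]
    trans:  transition scores, trans[i][j] = score of moving from tag i to tag j
    """
    num_tags = len(trans)
    scores = list(logits[0])  # best score of any path ending in each tag so far
    backpointers = []
    for emission in logits[1:]:
        step_scores, step_back = [], []
        for j in range(num_tags):
            # Best previous tag for a path that ends in tag j at this step.
            best_prev = max(range(num_tags), key=lambda i: scores[i] + trans[i][j])
            step_back.append(best_prev)
            step_scores.append(scores[best_prev] + trans[best_prev][j] + emission[j])
        scores = step_scores
        backpointers.append(step_back)
    # Walk the back-pointers from the best final tag to recover the path.
    tag = max(range(num_tags), key=lambda i: scores[i])
    path = [tag]
    for step_back in reversed(backpointers):
        tag = step_back[tag]
        path.append(tag)
    path.reverse()
    return path


# Toy example with two tags: switching tags is heavily penalized,
# so the decoder stays in tag 0 despite the middle token preferring tag 1.
toy_logits = [[1.0, 0.0], [0.0, 1.0], [1.0, 0.0]]
toy_trans = [[0.0, -5.0], [-5.0, 0.0]]
print(viterbi_decode(toy_logits, toy_trans))  # → [0, 0, 0]
```

If the model's transition matrix includes extra start or end states, the indexing above would need adjusting to match.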