tensorflow / serving

A flexible, high-performance serving system for machine learning models
https://www.tensorflow.org/serving
Apache License 2.0

TensorFlow Serving: question about the "--enable_batching" flag #1682

Closed: baopanda closed this issue 4 years ago

baopanda commented 4 years ago

Hi all, I am using TensorFlow Serving to serve models. My client script works, but when I run the same script after starting the server with the --enable_batching flag, the server returns this error: "Batching session Run() input tensors must have equal 0th-dimension size". Can anyone help? Thank you in advance <3

rmothukuru commented 4 years ago

@baopanda, can you please confirm whether you have configured a batching parameters file? Please see this TF Serving link for more details. If you have configured it as shown there and the issue still persists, please provide a code snippet that reproduces the issue so we can speed up troubleshooting, and share your model's SignatureDefs as well. Thanks!
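For reference, the batching parameters file is a text-format protobuf that `tensorflow_model_server` reads via `--batching_parameters_file` when `--enable_batching` is set. The sketch below writes such a file and assembles a launch command; it is a minimal illustration under stated assumptions, not an official recipe. It reuses the two values posted later in this thread (`max_batch_size` 32, `batch_timeout_micros` 5000), the model name `segment` and gRPC port 8500 from the client code; the file name and the default `model_base_path` are hypothetical placeholders.

```python
from pathlib import Path
from typing import List

# The two batching parameters used in this thread, in protobuf text format.
BATCHING_CONFIG = """\
max_batch_size { value: 32 }
batch_timeout_micros { value: 5000 }
"""


def write_batching_config(path: Path) -> Path:
    """Write the file that --batching_parameters_file will point to."""
    path.write_text(BATCHING_CONFIG)
    return path


def server_command(config_path: Path,
                   model_name: str = "segment",
                   model_base_path: str = "/models/segment") -> List[str]:
    """Build a tensorflow_model_server command line with batching enabled.

    model_base_path is a hypothetical location; point it at your export directory.
    """
    return [
        "tensorflow_model_server",
        "--port=8500",  # gRPC port used by the client code in this thread
        f"--model_name={model_name}",
        f"--model_base_path={model_base_path}",
        "--enable_batching=true",
        f"--batching_parameters_file={config_path}",
    ]
```

`shlex.join(server_command(write_batching_config(Path("batching_parameters.txt"))))` from the standard library turns the result into a shell-ready string.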

baopanda commented 4 years ago

@rmothukuru Yes, I have configured the batching parameters:

    max_batch_size { value: 32 }
    batch_timeout_micros { value: 5000 }

and here is the code that sends the request to the gRPC server:


    import time

    import grpc
    import numpy as np
    import tensorflow as tf
    from tensorflow_serving.apis import predict_pb2, prediction_service_pb2_grpc


    def predict(f_f, f_p, b_f, b_p, w_f, timeout=10.0):
        # Function name, argument list and timeout default are illustrative;
        # the original snippet showed only the function body.
        host = "localhost"
        model_spec_name = "segment"
        port = 8500
        model_sig_name = "serving_default"

        host = host.replace("http://", "").replace("https://", "")
        channel = grpc.insecure_channel("{}:{}".format(host, port))
        stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)

        # Create the PredictRequest protobuf
        request = predict_pb2.PredictRequest()
        request.model_spec.name = model_spec_name
        request.model_spec.signature_name = model_sig_name

        # Keys must match the input names in the SignatureDef.
        # With --enable_batching, every input must have the same 0th-dimension size.
        inputs = {"f_f:0": f_f, "f_p:0": f_p, "b_f:0": b_f, "b_p:0": b_p, "w_f:0": w_f}
        for name, values in inputs.items():
            array = np.asarray(values, dtype=np.int64)
            request.inputs[name].CopyFrom(
                tf.make_tensor_proto(array, dtype=np.int64, shape=array.shape)
            )

        start = time.time()
        # Call the TF Serving Predict API
        predict_response = stub.Predict(request, timeout=timeout)
        print(">>> Inference time: {:.3f} s".format(time.time() - start))

        return predict_response

and my SignatureDefs:

    MetaGraphDef with tag-set: 'serve' contains the following SignatureDefs:
    signature_def['serving_default']:
      The given SavedModel SignatureDef contains the following input(s):
        inputs['b_f:0'] tensor_info:
            dtype: DT_INT64
            shape: (-1, -1)
            name: b_f:0
        inputs['b_p:0'] tensor_info:
            dtype: DT_INT64
            shape: (-1, -1)
            name: b_p:0
        inputs['f_f:0'] tensor_info:
            dtype: DT_INT64
            shape: (-1, -1)
            name: f_f:0
        inputs['f_p:0'] tensor_info:
            dtype: DT_INT64
            shape: (-1, -1)
            name: f_p:0
        inputs['w_f:0'] tensor_info:
            dtype: DT_INT64
            shape: (-1, -1)
            name: w_f:0
      The given SavedModel SignatureDef contains the following output(s):
        outputs['test_output:0'] tensor_info:
            dtype: DT_FLOAT
            shape: (-1, -1, -1, -1)
            name: test_output:0
      Method name is: tensorflow/serving/predict

But I think my error may be related to the batching flag: with batching enabled, the server treats the 0th dimension of every input as the batch size. Can you help me fix this bug? I would really appreciate it.

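That hypothesis matches how request batching works in principle: the batcher treats dimension 0 of every input as the batch dimension, concatenates queued requests along it, runs the model once, and slices each output back along dimension 0 according to each request's size. Below is a pure-Python sketch of that merge/split step, an illustration only and not TF Serving's actual implementation. Tensors are modeled as lists of rows, and the input names are borrowed from the SignatureDef above. Real concatenation also requires the remaining dimensions to line up across requests; this sketch does not model that.

```python
from typing import Dict, List, Tuple

Tensor = List[List[int]]     # a rank-2 tensor stored as a list of rows
Request = Dict[str, Tensor]  # input name -> tensor, like PredictRequest.inputs


def is_batchable(request: Request) -> bool:
    """True if every input in the request has the same 0th-dimension size."""
    return len({len(t) for t in request.values()}) == 1


def merge(requests: List[Request]) -> Tuple[Request, List[int]]:
    """Concatenate each input along dimension 0 across requests.

    Returns the merged inputs plus each request's batch size, which is
    needed later to split the outputs back apart.
    """
    for r in requests:
        if not is_batchable(r):
            raise ValueError("input tensors must have equal 0th-dimension size")
    sizes = [len(next(iter(r.values()))) for r in requests]
    merged = {name: [row for r in requests for row in r[name]]
              for name in requests[0]}
    return merged, sizes


def split(output: Tensor, sizes: List[int]) -> List[Tensor]:
    """Slice a batched output back into one piece per original request."""
    pieces, start = [], 0
    for n in sizes:
        pieces.append(output[start:start + n])
        start += n
    return pieces
```

A request where `f_f:0` has 1 row but `w_f:0` has 3 rows has no single batch size, so there is no way to decide which rows of the batched output belong to it; that is the situation the error message rejects.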
baopanda commented 4 years ago

@shadowdragon89 @rmothukuru do you have any ideas?

shadowdragon89 commented 4 years ago

To be batchable, all input tensors in a request need to have the same first-dimension size. That is what the server is complaining about: the input tensors it received do not share the same first-dimension size, so they cannot be combined into a batch. Do you have steps to reproduce this?
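Following that explanation, one way to catch the problem before the server does is to compare the shapes about to be sent against the SignatureDef. Here is a minimal client-side sketch, assuming each input's shape is available as a tuple (for NumPy arrays, `arr.shape`). `batching_problems` is a hypothetical helper, not part of TF Serving, and the expected shapes are copied from the `serving_default` signature posted above, where -1 means any size.

```python
from typing import Dict, List, Tuple

Shape = Tuple[int, ...]

# Input shapes from the serving_default SignatureDef in this thread (-1 = any size).
SIGNATURE_INPUTS: Dict[str, Shape] = {
    "b_f:0": (-1, -1),
    "b_p:0": (-1, -1),
    "f_f:0": (-1, -1),
    "f_p:0": (-1, -1),
    "w_f:0": (-1, -1),
}


def batching_problems(shapes: Dict[str, Shape],
                      signature: Dict[str, Shape] = SIGNATURE_INPUTS) -> List[str]:
    """List reasons a request fails the checks below; an empty list means it passes.

    Checks: every signature input is present, ranks match, fixed dimensions
    match, and all inputs share the same 0th-dimension size.
    """
    problems: List[str] = []
    for name, expected in signature.items():
        actual = shapes.get(name)
        if actual is None:
            problems.append(f"missing input {name}")
        elif len(actual) != len(expected):
            problems.append(f"{name}: rank {len(actual)}, expected {len(expected)}")
        else:
            for dim, (a, e) in enumerate(zip(actual, expected)):
                if e != -1 and a != e:
                    problems.append(f"{name}: dim {dim} is {a}, expected {e}")
    # The rule behind "input tensors must have equal 0th-dimension size".
    first_dims = {name: s[0] for name, s in shapes.items() if len(s) > 0}
    if len(set(first_dims.values())) > 1:
        problems.append(f"unequal 0th-dimension sizes: {first_dims}")
    return problems
```

For the client code in this thread, calling it with `{"f_f:0": f_f.shape, "f_p:0": f_p.shape, "b_f:0": b_f.shape, "b_p:0": b_p.shape, "w_f:0": w_f.shape}` before `stub.Predict` would report any mismatch in the first dimension across the five inputs.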