ShannonAI / service-streamer

Boosting your Web Services of Deep Learning Applications.
Apache License 2.0
1.22k stars, 187 forks

When the model's predict argument is an ndarray, it gets converted into a list #100

Open: danerlt opened this issue 1 year ago

danerlt commented 1 year ago

I have the following inference function:

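import time

import pandas as pd
from numpy import ndarray

# log_execution_time, preprocess, model and logger are defined elsewhere in the project
@log_execution_time
def batch_predict(src_data: pd.DataFrame) -> list:
    torque_angle_trace: ndarray = preprocess(src_data)
    start = time.time()
    predict_res = model.predict(torque_angle_trace)
    end = time.time()
    logger.info(f"Single prediction took: {end - start}")
    res = predict_res.tolist()
    return res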

Here model.predict takes an ndarray as its argument. This is the code after switching to ThreadedStreamer:

from service_streamer import ThreadedStreamer

stream = ThreadedStreamer(model.predict, batch_size=10, max_latency=0.1)

def batch_predict_stream(src_data: pd.DataFrame) -> list:
    start = time.time()
    torque_angle_trace: ndarray = preprocess(src_data)
    # stream.predict takes a batch (a list) of inputs, so the single ndarray is wrapped in a list
    predict_res = stream.predict([torque_angle_trace])
    end = time.time()
    logger.info(f"Stream prediction took: {end - start}")
    return predict_res

When calling stream.predict, I had already converted the data into an ndarray before passing it in. However, at runtime it fails with:

TypeError: X is not of a supported input data type.X must be in a supported mtype format for Panel, found <class 'list'>Use datatypes.check_is_mtype to check conformance with specifications.

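After reading the source, I found that the problem is in the code shown below: the items taken from the queue are put into a list, and that list is then passed to the predict function.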

[Screenshot: service-streamer source where the queued inputs are collected into a list]
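To illustrate the behaviour described above, here is a minimal, simplified sketch of one batching step. The function name run_once, the plain queue.Queue, and the item layout (client_id, task_id, request_id, model_input) are assumptions made for illustration; the layout is only inferred from the i[3] index used in the loop quoted in this thread. This is not the library's actual code. It only shows that, whatever type each request carries, the predict function receives a plain Python list.

```python
import queue

import numpy as np


def run_once(requests, predict_fn, batch_size=10):
    """Hypothetical, simplified stand-in for one batching step of a streamer worker."""
    batch = []
    # Drain up to batch_size pending requests (the real worker also enforces max_latency).
    while len(batch) < batch_size:
        try:
            batch.append(requests.get_nowait())
        except queue.Empty:
            break
    if not batch:
        return []
    # Assumed item layout: (client_id, task_id, request_id, model_input).
    # The inputs are gathered into a plain Python list, whatever their original type.
    model_inputs = [item[3] for item in batch]
    return predict_fn(model_inputs)


seen_types = []


def fake_predict(inputs):
    # Records what the "model" actually receives.
    seen_types.append(type(inputs))
    return [float(x.sum()) for x in inputs]


requests = queue.Queue()
requests.put((0, 0, 0, np.zeros(3)))
requests.put((0, 0, 1, np.ones(3)))
outputs = run_once(requests, fake_predict)
print(seen_types[0])  # <class 'list'>
print(outputs)        # [0.0, 3.0]
```

A model whose predict method insists on an ndarray therefore rejects this input, which matches the "found <class 'list'>" part of the error message.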

Am I using ThreadedStreamer incorrectly, or does the predict function simply not support ndarray arguments?

danerlt commented 1 year ago

To work around the problem above, I changed the loop at line 183 as follows, and it now runs correctly:

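        # Collect the per-request inputs and remember whether they are ndarrays
        # (requires "import numpy as np" at the top of the module).
        model_inputs = []
        is_ndarray = False
        for i in batch:
            model_input = i[3]
            if isinstance(model_input, np.ndarray):
                is_ndarray = True
            model_inputs.append(model_input)
        # Stack ndarray inputs into a single array instead of passing a list.
        if is_ndarray:
            model_inputs = np.vstack(model_inputs)
        model_outputs = self.model_predict(model_inputs)

        # Convert the ndarray result back to a plain Python list before it is split per request.
        if is_ndarray:
            model_outputs = model_outputs.tolist()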
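The same conversion could also be done without editing the installed library, by handing ThreadedStreamer a wrapper around model.predict. The sketch below works under stated assumptions: ndarray_batch_predict and fake_model_predict are hypothetical names, and stacking with np.vstack assumes each request contributes exactly one sample along axis 0 (as in the loop above), so output i lines up with request i.

```python
import numpy as np


def ndarray_batch_predict(predict_fn):
    """Hypothetical adapter: lets a streamer that batches inputs into a list
    drive a model whose predict() only accepts an ndarray."""

    def wrapped(model_inputs):
        if model_inputs and all(isinstance(x, np.ndarray) for x in model_inputs):
            # Stack per-request arrays along axis 0; assumes each request
            # contributes exactly one sample, so output i matches request i.
            batch = np.vstack(model_inputs)
            outputs = predict_fn(batch)
            # Convert back to a plain list so the streamer can split it per request.
            return np.asarray(outputs).tolist()
        # Non-ndarray inputs are passed through unchanged.
        return predict_fn(model_inputs)

    return wrapped


def fake_model_predict(x):
    # Stand-in for an ndarray-only model: rejects lists like the TypeError above.
    if not isinstance(x, np.ndarray):
        raise TypeError(f"expected ndarray, found {type(x)}")
    return x.sum(axis=tuple(range(1, x.ndim)))


predict = ndarray_batch_predict(fake_model_predict)
# Two requests, each shaped (1, n_vars, n_timepoints) = (1, 2, 3).
outputs = predict([np.ones((1, 2, 3)), np.full((1, 2, 3), 2.0)])
print(outputs)  # [6.0, 12.0]
```

With such an adapter, the streamer would be built as ThreadedStreamer(ndarray_batch_predict(model.predict), batch_size=10, max_latency=0.1). Keeping the conversion in user code means it survives library upgrades, at the cost of one extra function call per batch.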