pytorch / serve

Serve, optimize and scale PyTorch models in production
https://pytorch.org/serve/
Apache License 2.0

Method `preprocess` is not called for custom handler #626

Closed BorisPolonsky closed 4 years ago

BorisPolonsky commented 4 years ago

Expected Behavior

The custom preprocess method should be called on the custom handler when the model server receives a request.

Current Behavior

The preprocess method in my handler is never called. To verify this, I added sys.stdout.write and exit(-1) calls to preprocess and handle. Only the message written to stdout inside handle made its way to the terminal; the one in preprocess did not. The same goes for the exit statement, which only takes effect inside handle. This suggests that preprocess is never called.

Steps to Reproduce

  1. Create a .mar file from the serialized TorchScript module:
     torch-model-archiver --model-name <model_name> --version 0.1 --serialized-file ./script-module-20200817-172517.pt --handler ./ts_handler.py
  2. Serve the model in the official Docker image:
     docker run --rm -it -e LANG=C.UTF-8 --name torchserve -v /directory/containing/the/mar/archive/:/home/model-server/model-store:ro -p 8080:8080 -p 8081:8081 --gpus all pytorch/torchserve:0.2.0-cuda10.1-cudnn7-runtime torchserve --start --ts-config /home/model-server/config.properties --models msra_ner=/home/model-server/model-store/<model-name>.mar
  3. Send a request:
     curl localhost:8080/predictions/<model_name> -T test.txt

The custom handler ts_handler.py is defined as follows:

from abc import ABC
import torch
import os
import sys
from ts.torch_handler.base_handler import BaseHandler

class NamedEntityRecognitionHandler(BaseHandler, ABC):
    """
    A custom model handler implementation.
    """

    def __init__(self):
        self._context = None
        self.initialized = False
        self.model = None
        self.device = None

    def initialize(self, context):
        """
        Invoked by TorchServe to load the model.
        :param context: context contains model server system properties
        :return:
        """

        #  load the model
        self.manifest = context.manifest

        properties = context.system_properties
        model_dir = properties.get("model_dir")
        self.device = torch.device("cuda:" + str(properties.get("gpu_id")) if torch.cuda.is_available() else "cpu")
        # self.device = torch.device("cuda:0")
        # Read model serialize/pt file
        serialized_file = self.manifest['model']['serializedFile']
        model_pt_path = os.path.join(model_dir, serialized_file)
        if not os.path.isfile(model_pt_path):
            raise RuntimeError("Missing the model.pt file")

        self.model = torch.jit.load(model_pt_path)

        self.initialized = True

    def handle(self, data, context):
        """
        Invoked by TorchServe for a prediction request.
        Does pre-processing of the data, prediction using the model, and post-processing of the prediction output.
        :param data: Input data for prediction
        :param context: Initial context contains model server system properties.
        :return: prediction output
        """
        sys.stdout.write("=====stdout.write works fine=========")
        # exit(-1) # Server will return "worker died" if not commented out
        pred_out = self.model.forward(data)
        return pred_out

    def preprocess(self, data):
        """
        Transform raw input into model input data.
        :param data: list of raw requests, should match batch size
        :return: list of preprocessed model input data
        """
        # Take the input data and make it inference ready
        sys.stdout.write("===========Preprocessing==========") # Nothing in stdout
        exit(-1) # Trying to exit the program, no luck
        preprocessed_data = data[0].get("data")
        if preprocessed_data is None:
            preprocessed_data = data[0].get("body")
        preprocessed_data = preprocessed_data.decode('utf-8')
        return preprocessed_data

Failure Logs [if any]

Note that the ===========Preprocessing========== message never made its way to the log:

2020-08-18 10:02:31,802 [INFO ] W-9000-msra_ner_0.1 TS_METRICS - W-9000-msra_ner_0.1.ms:3254|#Level:Host|#hostname:274a635d0fe7,timestamp:1597744951
2020-08-18 10:02:31,807 [INFO ] W-9000-msra_ner_0.1-stdout org.pytorch.serve.wlm.WorkerLifeCycle - =====stdout.write works fine=========Invoking custom service failed.
2020-08-18 10:02:31,808 [INFO ] W-9000-msra_ner_0.1-stdout org.pytorch.serve.wlm.WorkerLifeCycle - Traceback (most recent call last):
2020-08-18 10:02:31,808 [INFO ] W-9000-msra_ner_0.1-stdout org.pytorch.serve.wlm.WorkerLifeCycle -   File "/home/venv/lib/python3.6/site-packages/ts/service.py", line 100, in predict
2020-08-18 10:02:31,808 [INFO ] W-9000-msra_ner_0.1-stdout org.pytorch.serve.wlm.WorkerLifeCycle -     ret = self._entry_point(input_batch, self.context)
2020-08-18 10:02:31,808 [INFO ] W-9000-msra_ner_0.1-stdout org.pytorch.serve.wlm.WorkerLifeCycle -   File "/home/model-server/tmp/models/2aa723faff5d454797674946580cdb11/ts_handler.py", line 53, in handle
2020-08-18 10:02:31,808 [INFO ] W-9000-msra_ner_0.1-stdout org.pytorch.serve.wlm.WorkerLifeCycle -     pred_out = self.model.forward(data)
2020-08-18 10:02:31,808 [INFO ] W-9000-msra_ner_0.1-stdout org.pytorch.serve.wlm.WorkerLifeCycle - RuntimeError: forward() Expected a value of type 'Tensor' for argument 'inputs' but instead found type 'list'.
2020-08-18 10:02:31,808 [INFO ] W-9000-msra_ner_0.1-stdout org.pytorch.serve.wlm.WorkerLifeCycle - Position: 1
2020-08-18 10:02:31,808 [INFO ] W-9000-msra_ner_0.1-stdout org.pytorch.serve.wlm.WorkerLifeCycle - Value: [{'body': bytearray(b'hello world\n')}]
2020-08-18 10:02:31,808 [INFO ] W-9000-msra_ner_0.1-stdout org.pytorch.serve.wlm.WorkerLifeCycle - Declaration: forward(__torch__.model.create.BiRNNCrf self, Tensor inputs, Tensor sequence_lengths) -> ((Tensor, Tensor))
2020-08-18 10:02:31,809 [INFO ] W-9000-msra_ner_0.1-stdout org.pytorch.serve.wlm.WorkerLifeCycle - Cast error details: Unable to cast Python instance to C++ type (compile in debug mode for details)
2020-08-18 10:02:31,809 [INFO ] W-9000-msra_ner_0.1 org.pytorch.serve.wlm.WorkerThread - Backend response time: 6
2020-08-18 10:02:31,817 [INFO ] W-9000-msra_ner_0.1 ACCESS_LOG - /172.17.0.1:49074 "PUT /predictions/msra_ner HTTP/1.1" 503 881
2020-08-18 10:02:31,817 [INFO ] W-9000-msra_ner_0.1 TS_METRICS - Requests5XX.Count:1|#Level:Host|#hostname:274a635d0fe7,timestamp:null
2020-08-18 10:02:31,818 [DEBUG] W-9000-msra_ner_0.1 org.pytorch.serve.wlm.Job - Waiting time ns: 861492032, Inference time ns: 876052724
harshbafna commented 4 years ago

@BorisPolonsky: TorchServe only calls the handle function, which is the default entry point for your model handler; handle in turn should call the preprocess function. Since you are overriding the default behavior of the handle function from BaseHandler, you will need to make that call yourself.

Also, it is recommended that you not override the handle function from BaseHandler unless you want to change the default handling, e.g. to add another function call to your handle pipeline.

BorisPolonsky commented 4 years ago

Thanks for the clarification. Apparently this is not a bug. I'll close this issue.