aws / sagemaker-pytorch-inference-toolkit

Toolkit for inference and serving with PyTorch on SageMaker. Dockerfiles used for building SageMaker PyTorch containers are at https://github.com/aws/deep-learning-containers.

Batch Inference does not work when using the default handler #121

Open nikhil-sk opened 2 years ago

nikhil-sk commented 2 years ago

Describe the bug

  1. In batch inference, the model server (in this case, TorchServe) sends the handler a 'batch', i.e. a list of requests. The handler is expected to process all of them and send back a list of 'batch-size' responses.
  2. Currently, the PyTorch toolkit uses the transform() function from the base inference-toolkit to receive requests from the model server and process them by calling _transform_fn() [which in turn calls _input_fn, _predict_fn, and _output_fn].
  3. However, transform() appears to process only the 'first' request in the batch: https://github.com/aws/sagemaker-inference-toolkit/blob/master/src/sagemaker_inference/transformer.py#L114
  4. As a result, when using the default handler, every request in the batch except the first is dropped (see the sketch after this list).
  5. This makes the default handler unusable for batch inference.
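To make the failure mode concrete, here is a rough, hedged sketch (not the toolkit's actual source) of the shape of the data TorchServe hands the handler when batching is enabled, and why reading only the first element breaks the contract:

    # Illustrative sketch only; the real logic lives in
    # sagemaker_inference/transformer.py (see the link above).

    # With a batch size of 3, TorchServe aggregates requests into a list of
    # dicts, one per client request:
    data = [
        {"body": b'{"text": "request 1"}'},
        {"body": b'{"text": "request 2"}'},
        {"body": b'{"text": "request 3"}'},
    ]

    # The current flow reads only the first entry...
    input_data = data[0].get("body")  # requests 2 and 3 are never decoded
    # ...so a single response is produced where TorchServe expects three,
    # which surfaces as "number of batch response mismatched, expect: 3, got: 1".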

To reproduce

  1. Running this notebook reproduces the failure (log attached in the 'Screenshots or logs' section below): https://github.com/aws/amazon-sagemaker-examples/blob/main/sagemaker-python-sdk/pytorch_batch_inference/sagemaker_batch_inference_torchserve.ipynb [A workaround PR that uses a custom container exists for this notebook: https://github.com/aws/amazon-sagemaker-examples/pull/3395]

Expected behavior

  1. The transform() function should process all requests in a batch and return a list of responses equal in size to the input list (see the sketch after this list).
  2. Attaching an untested suggestion in the form of a PR: link
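As a rough sketch of the expected behavior (this is not the attached PR; `transform_batch` and its handler-function parameters are hypothetical names used only for illustration):

    def transform_batch(model, data, input_fn, predict_fn, output_fn,
                        content_type="application/json",
                        accept="application/json"):
        """Apply the handler chain to every record in the TorchServe batch and
        return one response per request, as the batching contract requires."""
        responses = []
        for record in data:  # one dict per client request in the batch
            payload = record.get("body") or record.get("data")
            model_input = input_fn(payload, content_type)
            prediction = predict_fn(model_input, model)
            responses.append(output_fn(prediction, accept))
        return responses  # len(responses) == len(data)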

Screenshots or logs

  1. A run of the notebook results in:
    2022-04-27T20:51:36,026 [INFO ] W-9000-model_1.0 org.pytorch.serve.wlm.WorkerThread - Flushing req. to backend at: 1651092696026
    2022-04-27T20:51:36,028 [INFO ] W-9000-model_1.0-stdout MODEL_LOG - Backend received inference at: 1651092696
    2022-04-27T20:51:36,028 [WARN ] W-9000-model_1.0-stderr MODEL_LOG - Downloading: 100%|██████████| 28.0/28.0 [00:00<00:00, 42.8kB/s]
    2022-04-27T20:51:36,028 [WARN ] W-9000-model_1.0-stderr MODEL_LOG - Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.
    2022-04-27T20:51:36,119 [INFO ] W-9000-model_1.0-stdout MODEL_LOG - INPUT1
    2022-04-27T20:51:36,120 [INFO ] W-9000-model_1.0-stdout MODEL_LOG - INPUT2
    2022-04-27T20:51:36,120 [INFO ] W-9000-model_1.0-stdout MODEL_LOG - Got input Data: {Bloomberg has decided to publish a new report on global economic situation.}
    2022-04-27T20:51:36,120 [INFO ] W-9000-model_1.0-stdout MODEL_LOG - PRED SequenceClassifierOutput(loss=None, logits=tensor([[ 0.1999, -0.2964]], grad_fn=<AddmmBackward0>), hidden_states=None, attentions=None)
    2022-04-27T20:51:36,120 [INFO ] W-9000-model_1.0-stdout MODEL_LOG - PREDICTION ['Not Accepted']
    2022-04-27T20:51:36,120 [INFO ] W-9000-model_1.0-stdout MODEL_LOG - model: model, number of batch response mismatched, expect: 3, got: 1.
    2022-04-27T20:51:36,121 [INFO ] W-9000-model_1.0 org.pytorch.serve.wlm.WorkerThread - Backend response time: 94
