pytorch / serve

Serve, optimize and scale PyTorch models in production
https://pytorch.org/serve/
Apache License 2.0

Batch processing of text failed #2119

Open revantemp3 opened 1 year ago

revantemp3 commented 1 year ago

Hi,

I used this custom handler to serve Hugging Face's facebook/mbart-large-50-many-to-many-mmt.

from abc import ABC
import json
import logging
import os
import torch
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
from ts.torch_handler.base_handler import BaseHandler
logger = logging.getLogger(__name__)
class MbartHandler(BaseHandler, ABC):
    """
    Transformers translation handler class. This handler takes a text (string)
    as input and returns the translated text based on the serialized transformers checkpoint.
    """
    def __init__(self):
        super(MbartHandler, self).__init__()
        self.initialized = False

    def initialize(self, ctx):
        self.manifest = ctx.manifest
        properties = ctx.system_properties
        model_dir = properties.get("model_dir")
        self.device = torch.device("cuda:" + str(properties.get("gpu_id")) if torch.cuda.is_available() else "cpu")
        #self.device = "cuda:0"
        logger.info(f"Using {str(self.device)}")
        #self.device = torch.device("cuda:0")
        # Read model serialize/pt file
        self.model = MBartForConditionalGeneration.from_pretrained(model_dir)
        self.tokenizer = MBart50TokenizerFast.from_pretrained(model_dir)
        self.tokenizer.src_lang = "en_XX"
        self.model.to(self.device)
        self.model.eval()
        logger.debug('Transformer model from path {0} loaded successfully'.format(model_dir))
        self.initialized = True

    def preprocess(self, data):
        """ Very basic preprocessing code - only tokenizes. 
            Extend with your own preprocessing steps as needed.
        """
        text = data[0].get("data")
        if text is None:
            text = data[0].get("body")
        sentences = text
        logger.info("Received text: '%s'", sentences)

        inputs = self.tokenizer(sentences, return_tensors="pt", padding=True).to(self.device)
        return inputs

    def inference(self, inputs):
        """
        Translate the input text using the trained transformer model.
        """
        # Generate translations; forced_bos_token_id selects the target language (here zh_CN).
        translated_tokens = self.model.generate(**inputs, forced_bos_token_id=self.tokenizer.lang_code_to_id["zh_CN"])
        prediction = self.tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)
        #logger.info("Model predicted: '%s'", prediction)
        logger.info(f"From {len(inputs)} inputs, received prediction of {len(prediction)} of type {type(prediction)} and innerlist of type{type(prediction[0])} length {len(prediction[0])}")
        logger.info(prediction)
        return prediction

    def postprocess(self, inference_output):
        # TODO: Add any needed post-processing of the model predictions here
        logger.info(inference_output)
        return inference_output

I set the batch size to 32, min workers to 3 and max workers to 5, serving on two GPUs (NVIDIA RTX 3090).
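For reference, batch settings like these are normally supplied when registering the model through the TorchServe management API; the sketch below is only an illustration, and the archive name mbartmmt.mar / model name mbartmmt are assumptions rather than values from the report.

import requests

# Register the model archive with batching enabled; batch_size and
# max_batch_delay control how requests are grouped before preprocess is called.
requests.post(
    "http://localhost:8081/models",
    params={
        "url": "mbartmmt.mar",      # hypothetical archive name
        "batch_size": 32,
        "max_batch_delay": 100,     # milliseconds to wait while filling a batch
        "initial_workers": 3,
    },
)

# Adjust the worker range afterwards via the scale-workers API.
requests.put(
    "http://localhost:8081/models/mbartmmt",  # hypothetical model name
    params={"min_worker": 3, "max_worker": 5},
)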

When I send a single sentence, I get the correct result, which is a string.

When I send 3 sentences simultaneously, I get the following error:

number of batch response mismatched, expect: 3, got: 1.

Interestingly, even when I send a single sentence, TorchServe still sees 2 inputs (as shown by the logger: "From 2 inputs").

I can confirm that when not using TorchServe, the inference part returns three outputs. Please see temp.py in the attached zip file.

Full .py and log files are attached below (mbartmmt.zip). Any help is much appreciated; I have been pulling my hair out.
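The "3 sentences simultaneously" scenario above can be reproduced by firing concurrent requests at the inference API; a minimal sketch, assuming the model is registered as mbartmmt on the default inference port 8080:

import concurrent.futures
import requests

SENTENCES = ["Hello world.", "How are you?", "See you tomorrow."]

def translate(sentence):
    # One inference request per sentence; TorchServe groups concurrent
    # requests into a single batch handed to the handler.
    resp = requests.post(
        "http://localhost:8080/predictions/mbartmmt",  # hypothetical model name
        data=sentence.encode("utf-8"),
    )
    return resp.status_code, resp.text

with concurrent.futures.ThreadPoolExecutor(max_workers=3) as pool:
    for status, body in pool.map(translate, SENTENCES):
        print(status, body)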

lxning commented 1 year ago

@revantemp3 The root cause is that the preprocess does not handle batching. You can refer to this example.
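For context, TorchServe hands the handler a list with one element per request in the batch and expects one response per request back; the preprocess above only reads data[0], so a batch of 3 collapses into a single response. (The "From 2 inputs" log line most likely counts the two keys of the tokenizer's BatchEncoding, input_ids and attention_mask, not the number of requests.) A minimal batch-aware sketch of the three methods, following the names used in the code above rather than the example referenced here, could look like this:

    def preprocess(self, data):
        # Gather the text from every request in the batch, not just data[0].
        sentences = []
        for row in data:
            text = row.get("data") or row.get("body")
            if isinstance(text, (bytes, bytearray)):
                text = text.decode("utf-8")
            sentences.append(text)
        # Tokenize the whole batch at once; padding aligns the sequence lengths.
        return self.tokenizer(sentences, return_tensors="pt", padding=True).to(self.device)

    def inference(self, inputs):
        with torch.no_grad():
            translated_tokens = self.model.generate(
                **inputs,
                forced_bos_token_id=self.tokenizer.lang_code_to_id["zh_CN"],
            )
        # batch_decode returns one string per sentence in the batch.
        return self.tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)

    def postprocess(self, inference_output):
        # TorchServe expects a list with exactly one entry per request in the batch.
        return inference_output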

ericg108 commented 1 year ago

@revantemp3 The root cause is that the preprocess does not handle batching. You can refer to this example.

@lxning thanks, do we have a time expectation on this?