from abc import ABC
import json
import logging
import os
import torch
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
from ts.torch_handler.base_handler import BaseHandler
logger = logging.getLogger(__name__)
class MbartHandler(BaseHandler, ABC):
    """
    Transformers translation handler class. This handler takes a text (string)
    as input and returns the translated text based on the serialized transformers checkpoint.
    """

    def __init__(self):
        super(MbartHandler, self).__init__()
        self.initialized = False
    def initialize(self, ctx):
        self.manifest = ctx.manifest
        properties = ctx.system_properties
        model_dir = properties.get("model_dir")
        self.device = torch.device(
            "cuda:" + str(properties.get("gpu_id")) if torch.cuda.is_available() else "cpu"
        )
        logger.info(f"Using {str(self.device)}")

        # Read the serialized checkpoint and tokenizer from the model directory
        self.model = MBartForConditionalGeneration.from_pretrained(model_dir)
        self.tokenizer = MBart50TokenizerFast.from_pretrained(model_dir)
        self.tokenizer.src_lang = "en_XX"
        self.model.to(self.device)
        self.model.eval()

        logger.debug('Transformer model from path {0} loaded successfully'.format(model_dir))
        self.initialized = True
    def preprocess(self, data):
        """ Very basic preprocessing code - only tokenizes.
        Extend with your own preprocessing steps as needed.
        """
        text = data[0].get("data")
        if text is None:
            text = data[0].get("body")
        sentences = text
        logger.info("Received text: '%s'", sentences)
        inputs = self.tokenizer(sentences, return_tensors="pt", padding=True).to(self.device)
        return inputs
    def inference(self, inputs):
        """
        Translate the tokenized input using the trained mBART model.
        """
        # Generate the translation, forcing Chinese as the target language
        translated_tokens = self.model.generate(
            **inputs, forced_bos_token_id=self.tokenizer.lang_code_to_id["zh_CN"]
        )
        prediction = self.tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)
        logger.info(
            f"From {len(inputs)} inputs, received prediction of {len(prediction)} of type "
            f"{type(prediction)} and inner list of type {type(prediction[0])} length {len(prediction[0])}"
        )
        logger.info(prediction)
        return prediction
    def postprocess(self, inference_output):
        # TODO: Add any needed post-processing of the model predictions here
        logger.info(inference_output)
        return inference_output
Hi,

I used the custom handler above to serve Hugging Face's facebook/mbart-large-50-many-to-many-mmt.

I set the batch size to 32, min workers to 3, and max workers to 5, serving on 2 GPUs (NVIDIA RTX 3090).
When I send a single sentence, I get the correct result, which is a string.

When I send 3 sentences simultaneously, I get the following error:

number of batch response mismatched, expect: 3, got: 1.

Interestingly, even when I send a single sentence, TorchServe still sees 2 inputs (as seen from the logger message "From 2 inputs").
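(I suspect the "From 2 inputs" line may simply be len() of the tokenizer output, which would count the input_ids and attention_mask keys rather than the number of requests, but I am not sure.)

My rough understanding, which may well be wrong, is that with a batch size above 1 TorchServe hands preprocess a list containing one request dict per queued request, and expects the handler to return a list with exactly one response per request. If that is right, I assume the fix would be something along the lines of the untested sketch below (same attribute names as the handler above):

    def preprocess(self, data):
        # One entry in `data` per request that TorchServe batched together
        sentences = []
        for row in data:
            text = row.get("data")
            if text is None:
                text = row.get("body")
            if isinstance(text, (bytes, bytearray)):
                text = text.decode("utf-8")
            sentences.append(text)
        return self.tokenizer(sentences, return_tensors="pt", padding=True).to(self.device)

    def inference(self, inputs):
        translated_tokens = self.model.generate(
            **inputs, forced_bos_token_id=self.tokenizer.lang_code_to_id["zh_CN"]
        )
        # batch_decode returns one string per input sentence, i.e. one per request
        return self.tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)

    def postprocess(self, inference_output):
        # Already a list with one string per request, which is what TorchServe expects back
        return inference_output

Is that roughly what the handler needs to do, or am I misunderstanding the batching contract?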
I can confirm that when not using TorchServe, the inference part returns three outputs for the same three sentences. Please see temp.py in the attached zip file.
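For anyone who does not want to open the zip: the standalone check is essentially the same generate/batch_decode path as in the handler, run outside TorchServe. The simplified sketch below is not the attached temp.py verbatim, and the three example sentences are placeholders:

import torch
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast

model_dir = "facebook/mbart-large-50-many-to-many-mmt"  # or the local checkpoint directory
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model = MBartForConditionalGeneration.from_pretrained(model_dir).to(device).eval()
tokenizer = MBart50TokenizerFast.from_pretrained(model_dir)
tokenizer.src_lang = "en_XX"

sentences = ["Hello world.", "How are you?", "This is a test."]  # placeholder inputs
inputs = tokenizer(sentences, return_tensors="pt", padding=True).to(device)

translated_tokens = model.generate(
    **inputs, forced_bos_token_id=tokenizer.lang_code_to_id["zh_CN"]
)
prediction = tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)
print(len(prediction), prediction)  # prints 3 and a list of three translated strings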
The full .py and log files are attached below; any help is much appreciated, I have been pulling my hair out.

mbartmmt.zip