deepjavalibrary / djl-demo

Demo applications showcasing DJL
https://demo.djl.ai
Apache License 2.0

Issue with batch sizes when performing inference #355

Closed yudhiesh closed 1 year ago

yudhiesh commented 1 year ago

Problem

I am deploying Llama2-7b on AWS SageMaker Endpoints with DJL-Serving as the backend. It works as expected without issue for a batch_size of 1, but when I increase batch_size to a value > 1 I start getting a lot of errors:

ai.djl.translate.TranslateException: Batch output size mismatch, expected: 4, actual: 1

I think the current inference code does not support batch_size > 1:

import os
from typing import Optional

import deepspeed
import torch
from djl_python import Input, Output
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

predictor = None

def get_model(properties):
    model_name = properties["model_id"]
    local_rank = int(os.getenv("LOCAL_RANK", "0"))
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        low_cpu_mem_usage=True,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
        use_safetensors=True,
    )
    model = deepspeed.init_inference(
        model, mp_size=properties["tensor_parallel_degree"]
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    generator = pipeline(
        task="text-generation", model=model, tokenizer=tokenizer, device=local_rank
    )
    return generator

def handle(inputs: Input) -> Optional[Output]:
    global predictor
    if not predictor:
        predictor = get_model(inputs.get_properties())

    if inputs.is_empty():
        # Model server makes an empty call to warm-up the model on startup
        return None
    data = inputs.get_as_json()
    text = data["text"]
    text_length = data["text_length"]
    min_length = data.get("min_length", 0)
    result = predictor(
        text, do_sample=True, min_length=min_length, max_length=text_length
    )
    return Output().add(result)

Is there an example that supports both batch_size = 1 and batch_size > 1 for inference?

I have noticed that there is an example using batch predict here, but the links to model.py and the DeepSpeed handler do not work.

frankfliu commented 1 year ago

DJLServing supports both client-side batching and server-side batching (also called dynamic batching).

Here is a simple example that handles server side batch: https://github.com/deepjavalibrary/djl-serving/blob/master/engines/python/src/test/resources/resnet18/model.py#L89
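
For reference, a minimal sketch of that pattern (the predict_fn placeholder and the "text" field are illustrative assumptions; the Input/Output handling follows the same shape as the linked example):

from typing import List, Optional

from djl_python import Input, Output

def predict_fn(texts: List[str]) -> List[str]:
    # Placeholder for the real model call (e.g. a transformers pipeline).
    return [t.upper() for t in texts]

def handle(inputs: Input) -> Optional[Output]:
    if inputs.is_empty():
        # The model server makes an empty call to warm up the model on startup.
        return None

    # With server-side (dynamic) batching, the frontend may group several client
    # requests into one call; get_batches() returns one Input per request.
    batches = inputs.get_batches()
    texts = [item.get_as_json()["text"] for item in batches]

    results = predict_fn(texts)

    # Return exactly one result per original request, keyed back to that request.
    outputs = Output()
    for i in range(inputs.get_batch_size()):
        outputs.add(results[i], key=inputs.get_content().key_at(i))
    outputs.add_property("content-type", "application/json")
    return outputs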

For text_generation tasks, you can send a JSON input with multiple queries:

{
    "inputs": [
        "The new movie that got Oscar this year",
        "The Large Language Model is"
    ],
    "parameters": {
        "max_new_tokens":256,
        "do_sample":true
    }
}

The built-in DeepSpeed handler does support both client-side and server-side batching.
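
For client-side batching on SageMaker, the whole JSON payload above can be sent in a single call, e.g. with boto3 (the endpoint name below is a placeholder):

import json

import boto3

runtime = boto3.client("sagemaker-runtime")

payload = {
    "inputs": [
        "The new movie that got Oscar this year",
        "The Large Language Model is",
    ],
    "parameters": {"max_new_tokens": 256, "do_sample": True},
}

response = runtime.invoke_endpoint(
    EndpointName="my-djl-endpoint",  # placeholder endpoint name
    ContentType="application/json",
    Body=json.dumps(payload),
)
print(response["Body"].read().decode("utf-8"))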

yudhiesh commented 1 year ago

@frankfliu thanks, I managed to get it to work, but I noticed that inference is still done sequentially when I pass in multiple inputs as per the example you shared.

I would expect inference to be done in parallel, so that latency stays roughly constant as the number of inputs grows (assuming the inputs/outputs fit into memory), but instead it scales linearly.

I am not sure what settings I have to change, as I am already altering the default values for batch_size and max_batch_delay. I am running inference on an AWS EC2 g5.2xlarge instance:

serving.properties

engine=DeepSpeed
option.tensor_parallel_degree=1
option.parallel_loading=true
option.s3url=s3://BUCKET
# maximum number of requests the frontend will group into one server-side batch
batch_size=8
# maximum time (in milliseconds) to wait for a full batch before dispatching
max_batch_delay=1000

model.py

import os
from typing import Any, Dict, List, Optional, Tuple, Union

import deepspeed
import torch
from djl_python import Input, Output
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    pipeline,
    Pipeline,
    PreTrainedTokenizerBase,
)

predictor = None

def get_inference_toolkit(properties: Dict[str, Any]) -> Pipeline:
    model_name = properties["model_id"]
    local_rank = int(os.getenv("LOCAL_RANK", "0"))
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        low_cpu_mem_usage=True,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
        use_safetensors=True,
    )
    model = deepspeed.init_inference(
        model, mp_size=properties["tensor_parallel_degree"]
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    generator = pipeline(
        task="text-generation", model=model, tokenizer=tokenizer, device=local_rank
    )
    return generator

def format_input_for_task(input_values: Union[List[str], str]) -> List[List[str]]:
    if not isinstance(input_values, list):
        input_values = [input_values]

    batch_inputs = []
    for values in input_values:
        batch_inputs += [values]
    return batch_inputs

def handle(inputs: Input) -> Output:
    global predictor
    if not predictor:
        predictor = get_inference_toolkit(inputs.get_properties())

    if inputs.is_empty():
        # Model server makes an empty call to warm-up the model on startup
        return None
    content_type = inputs.get_property("Content-Type")
    input_data: List[str] = []
    input_size: List[int] = []
    model_kwargs: Dict[str, Any] = {}
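    # With server-side (dynamic) batching, get_batches() returns one Input per grouped client request.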
    batches = inputs.get_batches()
    device = int(os.getenv("LOCAL_RANK", "0"))
    if content_type is not None and content_type.startswith("application/json"):
        first = True
        for item in batches:
            json_input = item.get_as_json()
            if isinstance(json_input, dict):
                input_size.append(len(json_input.get("inputs")))
                input_data.extend(format_input_for_task(json_input.pop("inputs")))
                if first:
                    model_kwargs = json_input.pop("parameters", {})
                    first = False
                else:
                    if model_kwargs != json_input.pop("parameters", {}):
                        return Output().error(
                            "In order to enable dynamic batching, all input batches must have the same parameters"
                        )
            else:
                input_size.append(len(json_input))
                input_data.extend(json_input)
    else:
        for item in batches:
            input_size.append(1)
            input_data.append(item.get_as_string())  # append the whole string, not its characters

    outputs = Output()
    offset = 0
    results = predictor(input_data, **model_kwargs)
    batch_size = inputs.get_batch_size()
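    # Split the flat result list back into one slice per original request.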
    for i in range(batch_size):
        result = results[offset : offset + input_size[i]]
        outputs.add(result, key=inputs.get_content().key_at(i))
        offset += input_size[i]
    outputs.add_property("content-type", "application/json")
    return outputs

Request Body

{
    "inputs": [
        "What is the purpose of life?",
        "What do you think lies ahead for humanity?",
        "What do you think lies ahead for humanity?"
    ],
    "parameters": {
        "max_length": 100
    }
}

Inference time for different input lengths:

Number of inputs: 1
CPU times: user 6.11 ms, sys: 435 µs, total: 6.55 ms
Wall time: 4.22 s

Number of inputs: 2
CPU times: user 6.04 ms, sys: 265 µs, total: 6.3 ms
Wall time: 7.31 s

Number of inputs: 3
CPU times: user 2.81 ms, sys: 3.74 ms, total: 6.55 ms
Wall time: 10.4 s

yudhiesh commented 1 year ago

Managed to fix it by specifying the batch_size in the pipeline:

results = predictor(input_data, batch_size=len(input_data), **model_kwargs)

One thing I am wondering about is how batching works when multiple requests come in. Let's say 3 clients each make a request with 3 inputs at the same time, and batch_size in serving.properties is set to 8.

Would 8 of those 9 inputs get grouped into a single prediction that runs in parallel? And would the remainder either be batched up with other requests or processed on its own once the max_batch_delay value is reached?

Another question I have is about the batch_size in the pipeline: would I need to set it dynamically, based on how much memory is currently free (including a buffer), to prevent OOM issues?
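
For illustration, a fixed cap tuned to the instance's GPU memory would be one simpler alternative to a fully dynamic value; MAX_PIPELINE_BATCH below is a hypothetical constant, not a DJLServing setting:

# Hypothetical cap on the pipeline's micro-batch size; tune empirically for the GPU.
MAX_PIPELINE_BATCH = 8

results = predictor(
    input_data,
    batch_size=min(len(input_data), MAX_PIPELINE_BATCH),
    **model_kwargs,
)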

frankfliu commented 1 year ago

@yudhiesh For dynamic batching, the DJLServing frontend has a workload manager that puts all HTTP requests into a job queue. It tries to group up to batch_size requests into a batch and sends them to the backend to execute inference. If max_batch_delay elapses before enough requests have been collected, DJLServing simply sends whatever requests are available to the backend.
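
Roughly, with the serving.properties above (batch_size=8, max_batch_delay=1000): if the 3 requests from your example all arrive within the same delay window, they are grouped into one batch of 3 requests (batch_size is only an upper bound), so the handler sees inputs.get_batch_size() == 3 and passes all 9 texts to the predictor in one call; a request arriving after the window has closed goes into the next batch, on its own if nothing else is queued by then.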