Lightning-AI / LitServe

Lightning-fast serving engine for any AI model of any size. Flexible. Easy. Enterprise-scale.
https://lightning.ai/docs/litserve
Apache License 2.0

unexpected output for HF model with batching #294

Closed. Borda closed this issue 1 month ago

Borda commented 1 month ago

🐛 Bug

To Reproduce

Without batching, everything works as expected:

{'output': 'What is the capital of Greece?\n\nAthens.'}

but with batching enabled, it returns just the first character:

{'output': 'W'}

Code sample

import torch
import litserve as ls
from transformers import AutoTokenizer, AutoModelForCausalLM

class JambaLitAPI(ls.LitAPI):

    def __init__(
        self,
        model_name: str = "ai21labs/AI21-Jamba-1.5-Mini",
        max_new_tokens: int = 100
    ):
        self.model_name = model_name
        self.max_new_tokens = max_new_tokens

    def setup(self, device):
        # Load the model and tokenizer from the Hugging Face Hub
        # You can replace the model name with any causal LM from the Hub
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_name, torch_dtype=torch.bfloat16, device_map="auto", use_mamba_kernels=False
        ).eval()
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, legacy=False)

    def decode_request(self, request):
        # Extract text from request
        # This assumes the request payload is of the form: {'input': 'Your input text here'}
        return request["input"]

    def predict(self, text):
        print(text)
        # Use the loaded pipeline to perform inference
        inputs = self.tokenizer(text, return_tensors='pt')
        input_ids = inputs.to(self.model.device)["input_ids"]
        print(input_ids)
        output_ids = self.model.generate(
            input_ids,
            max_new_tokens=self.max_new_tokens
        )[0]
        print(output_ids)
        text = self.tokenizer.decode(output_ids, skip_special_tokens=True)
        print(text)
        return text

    def encode_response(self, output):
        # Format the model output to send as a response
        # This example sends back the generated text
        return {"output": output}

if __name__ == "__main__":
    # Create an instance of your API
    api = JambaLitAPI()
    # Start the server, specifying the port
    server = ls.LitServer(api, accelerator="cuda", devices=1, max_batch_size=4)
    # print("run the server...")
    server.run(port=8000)

Expected behavior

With batching enabled, each request should receive the full generated text, just as in the non-batched case.

Environment

If you published a Studio with your bug report, we can automatically get this information. Otherwise, please describe:

Additional context
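
A batch-aware predict along these lines sidesteps the problem by returning one string per request; this is only a sketch, and it assumes the tokenizer provides a pad token (e.g. set self.tokenizer.pad_token = self.tokenizer.eos_token in setup if it does not):

    def predict(self, texts):
        # With max_batch_size > 1 and no custom batch(), `texts` is expected
        # to be the collated batch, i.e. a list of strings.
        inputs = self.tokenizer(texts, return_tensors="pt", padding=True).to(self.model.device)
        output_ids = self.model.generate(
            inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_new_tokens=self.max_new_tokens,
        )
        # Return one decoded string per request so each client receives its
        # own full completion instead of a single character.
        return self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)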

aniketmaurya commented 1 month ago

Context for the issue:

With batching enabled, we expect the predict method to return a list of predictions. When the user doesn't implement LitAPI.unbatch, we wrap the output with list(output) before sending it to the encode_response method.

In this case, list(prediction_output) was called on a string, which split it into individual characters. So we need to warn users about this case.
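
Concretely, the default unbatching applied to a string output reduces to this (a minimal illustration, not the actual LitServe internals):

output = "What is the capital of Greece?\n\nAthens."
unbatched = list(output)  # iterating a string yields individual characters
print(unbatched[0])       # 'W' -- matches the observed {'output': 'W'}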

grumpyp commented 1 month ago

I'd like to fix that. @aniketmaurya

Will raise a PR. Thanks.

aniketmaurya commented 1 month ago

@grumpyp looking forward! Please let me know if you have any questions.

grumpyp commented 1 month ago

> Context for the issue:
>
> With batching enabled, we expect the predict method to return a list of predictions. When the user doesn't implement LitAPI.unbatch, we wrap the output with list(output) before sending it to the encode_response method.
>
> In this case, list(prediction_output) was called on a string, which split it into individual characters. So we need to warn users about this case.

Would you want to prevent this from happening, or can you think of cases where this behaviour is needed, so that I only add a warning when batching is enabled and the output is a string? I could additionally introduce a parameter to enforce list-like outputs.

aniketmaurya commented 1 month ago

@grumpyp let's just print a warning for now and observe any new issues on this. You can add the logic here.
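
For reference, a hypothetical version of that warning (names here are illustrative, not the actual LitServe code) could look roughly like this:

import warnings

def warn_if_string_output(output, max_batch_size):
    # Hypothetical helper: with batching enabled, a plain-string prediction
    # gets split into characters by list(output), so flag it for the user.
    if max_batch_size > 1 and isinstance(output, str):
        warnings.warn(
            "predict returned a string while batching is enabled; the default "
            "unbatching will split it into characters. Return a list with one "
            "item per request, or implement LitAPI.unbatch."
        )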