turboderp / exllamav2

A fast inference library for running LLMs locally on modern consumer-class GPUs
MIT License

Streaming Issue with ExLlamaV2DynamicJobAsync #558

Closed remichu-ai closed 3 months ago

remichu-ai commented 3 months ago

I am having an issue where streaming the result from ExLlamaV2DynamicJobAsync slows the stream rate to about half; then, around the halfway point of the generation, all of the remaining text suddenly arrives at once. Overall generation speed is retained, so something seems to be blocking the stream rather than the generation itself. I used logger.info(f"{datetime.now()} {chunk}") to check the timing and confirmed that the second half of the generated text for the streaming version arrives all at once in the last second. Most likely the issue is something I am doing wrong, but I have not been able to figure it out for a few days. Any help is appreciated.

My current code is as follows for the non-stream and streaming versions. They look almost identical to me, but the non-stream version works perfectly fine.

Non-stream version:

    async def generate(
        self,
        prompt: str,
        temperature: float = 0.01,
        lm_enforcer_parser: TokenEnforcerTokenizerData = None,
        stop_words: Union[List[str], str] = None,
        max_tokens: int = None,
        **kwargs,
    ) -> (str, GenerationStats):

        # ensure that generator is initialized
        if self.pipeline is None:
            self.pipeline = await self._get_pipeline_async()

        logger.info("----------------------Prompt---------------\n" + prompt)
        logger.debug("----------------------temperature---------\n" + str(temperature))

        # get generation setting
        settings = self._get_exllama_gen_settings(temperature)

        # convert prompt to token id
        input_ids = self.tokenizer.encode(prompt)
        self.validate_token_length(len(input_ids[0]))

        # format enforcer
        filters = None
        if lm_enforcer_parser:
            filters = [ExLlamaV2TokenEnforcerFilter(
                lm_enforcer_parser,
                self.tokenizer)
            ]

        # find stop conditions
        if stop_words:
            if isinstance(stop_words, str):
                stop_words = [stop_words]

            if not self.eos_token_str:
                raise Exception("EOS token not set in model_config")
            stop_conditions = self.eos_token_str + stop_words  # concat the 2 list
            logger.debug("stop_words: " + str(stop_conditions))
        else:
            stop_conditions = self.eos_token_str

        job_id = uuid.uuid4().hex

        async def run_job():
            job = ExLlamaV2DynamicJobAsync(
                generator=self.pipeline.generator,
                input_ids=input_ids,
                max_new_tokens=min(self.max_tokens - len(input_ids[0]),
                                   max_tokens) if max_tokens else self.max_tokens - len(input_ids[0]),
                gen_settings=settings,
                stop_conditions=stop_conditions,  #self.eos_token_id if self.eos_token_id else None,
                decode_special_tokens=True,
                filters=filters,
                #token_healing=True,
                identifier=job_id,
                #filter_prefer_eos=True,
            )

            generate_text = ""
            gen_stats = None
            async for result in job:
                chunk = result.get("text", "")
                logger.info(f"{datetime.now()} {chunk}")
                generate_text += chunk
                if result["eos"]:
                    eos = True
                    gen_stats = GenerationStats(
                        input_tokens_count=result["prompt_tokens"],
                        output_tokens_count=result["new_tokens"],
                        time_to_first_token=result["time_prefill"],
                        time_generate=result["time_generate"],
                    )

            return generate_text, gen_stats

        generate_text, gen_stats = await run_job()

        #logger.debug("----------------------LLM Raw Response---------------\n" + result["full_completion"])

        return generate_text, gen_stats

Streaming version:

    async def generate_stream(
        self,
        prompt: str,
        temperature: float = 0.01,
        lm_enforcer_parser: TokenEnforcerTokenizerData = None,
        stop_words: Union[List[str], str] = None,
        max_tokens: int = None,
        **kwargs,
    ) -> AsyncIterator:

        logger.info("----------------------Prompt---------------\n" + prompt)
        logger.debug("----------------------temperature---------\n" + str(temperature))

        # ensure that generator is initialized
        if self.pipeline is None:
            self.pipeline = await self._get_pipeline_async()

        # get generation setting
        settings = self._get_exllama_gen_settings(temperature)

        # convert prompt to token id
        input_ids = self.tokenizer.encode(prompt)
        self.validate_token_length(len(input_ids[0]))

        # format enforcer
        filters = None
        if lm_enforcer_parser:
            filters = [ExLlamaV2TokenEnforcerFilter(
                lm_enforcer_parser,
                #self.pipeline.lm_enforcer_tokenizer_data)
                self.tokenizer)
            ]

        # find stop conditions
        if stop_words:
            if isinstance(stop_words, str):
                stop_words = [stop_words]

            if not self.eos_token_str:
                raise Exception("EOS token not set in model_config")
            stop_conditions = self.eos_token_str + stop_words  # concat the 2 list
            logger.debug("stop_words: " + str(stop_conditions))
        else:
            stop_conditions = self.eos_token_str

        job_id = uuid.uuid4().hex

        async def run_job():

            job = ExLlamaV2DynamicJobAsync(
                generator=self.pipeline.generator,
                input_ids=input_ids,
                max_new_tokens=min(self.max_tokens - len(input_ids[0]),
                                   max_tokens) if max_tokens else self.max_tokens - len(input_ids[0]),
                gen_settings=settings,
                stop_conditions=stop_conditions,  # self.eos_token_id if self.eos_token_id else None,
                decode_special_tokens=True,
                filters=filters,
                #token_healing=False,
                identifier=job_id,
                # filter_prefer_eos=True,
            )

            generate_text = ""
            gen_stats = None
            eos = False
            async for result in job:
                if eos:
                    await job.cancel()

                chunk = result.get("text", "")
                logger.info(f"{datetime.now()} {chunk}")
                generate_text += chunk
                yield chunk

                if result["eos"]:
                    eos = True
                    gen_stats = GenerationStats(
                        input_tokens_count=result["prompt_tokens"],
                        output_tokens_count=result["new_tokens"],
                        time_to_first_token=result["time_prefill"],
                        time_generate=result["time_generate"],
                    )
                    logger.debug("----------------------LLM Raw Response---------------\n" + result["full_completion"])
                    yield gen_stats

        return run_job()
turboderp commented 3 months ago

I can't seem to reproduce this, and I'm a little unsure about how you're calling those async functions. I tried to reduce it to a more minimal example:

from exllamav2 import ExLlamaV2, ExLlamaV2Config, ExLlamaV2Cache, ExLlamaV2Tokenizer
from exllamav2.generator import ExLlamaV2DynamicGeneratorAsync, ExLlamaV2DynamicJobAsync, ExLlamaV2Sampler
import asyncio
import time

async def main():

    model_dir = "/mnt/str/models/mistral-7b-exl2/4.0bpw"
    config = ExLlamaV2Config(model_dir)
    config.arch_compat_overrides()
    model = ExLlamaV2(config)
    cache = ExLlamaV2Cache(model, lazy = True)
    model.load_autosplit(cache, progress = True)
    tokenizer = ExLlamaV2Tokenizer(config)
    generator = ExLlamaV2DynamicGeneratorAsync(model, cache, tokenizer)

    async def generate_stream(
        prompt: str,
    ):
        input_ids = tokenizer.encode(prompt)

        # async def run_job():

        job = ExLlamaV2DynamicJobAsync(
            generator = generator,
            input_ids = input_ids,
            max_new_tokens = 500,
            stop_conditions = [tokenizer.eos_token_id],
            decode_special_tokens = True,
        )

        generate_text = ""
        async for result in job:
            chunk = result.get("text", "")
            generate_text += chunk
            yield chunk

            if result["eos"]:
                yield generate_text

        # await run_job()

    prompt = "Once upon a time,"

    p_time = time.time()
    for _ in range(5):
        async for x in generate_stream(prompt):
            s_time = time.time()
            latency = s_time - p_time
            p_time = s_time
            print(f"{latency:8.4f} - {repr(x)}")

if __name__ == "__main__":
    asyncio.run(main())

This isn't failing like you describe, so maybe it comes down to filters, specific stop conditions, or something? I probably need a more complete example to be able to reproduce it.

remichu-ai commented 3 months ago

Hi Turboderp,

Thank you for taking the time to look into this. I have confirmed that the issue is with how I consume the generator. I was using async for + yield to create a generator, and then another async for + yield on top of it to return the result to the client. This chaining of async for seems to have caused the issue; it is not related to the exllama engine at all.
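
To illustrate what I mean by chaining, the setup was roughly shaped like the sketch below (simplified, with illustrative names; fake_results stands in for iterating the ExLlamaV2DynamicJobAsync job):

import asyncio

async def fake_results():
    # Stand-in for iterating an ExLlamaV2DynamicJobAsync job.
    for i in range(5):
        await asyncio.sleep(0.05)
        yield {"text": f"tok{i} "}

async def inner_stream():
    # First layer: async for over the job, yielding each chunk.
    async for result in fake_results():
        yield result.get("text", "")

async def outer_stream():
    # Second layer: async for over the first generator, yielding again to the client.
    async for chunk in inner_stream():
        yield chunk

async def main():
    async for chunk in outer_stream():
        print(chunk, end="", flush=True)
    print()

if __name__ == "__main__":
    asyncio.run(main())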

I am changing my setup to be similar to Tabby, which uses a queue to store the generated text for streaming instead, and that seems to resolve the issue (a rough sketch of the pattern is below). I will close this issue soon.
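
A minimal sketch of that queue-based producer/consumer pattern, not Tabby's actual code; fake_job, producer, and consume are illustrative stand-ins, where the real producer would iterate the ExLlamaV2DynamicJobAsync job and the consumer would be the generator handed to FastAPI's StreamingResponse:

import asyncio

async def fake_job():
    # Stand-in for iterating an ExLlamaV2DynamicJobAsync job; yields result
    # dicts shaped like the ones used earlier in this thread.
    for i in range(10):
        await asyncio.sleep(0.05)   # simulate per-token latency
        yield {"text": f"tok{i} ", "eos": i == 9}

async def producer(queue: asyncio.Queue):
    # Push each chunk into the queue as soon as the job emits it,
    # then a None sentinel to mark end of stream.
    async for result in fake_job():
        queue.put_nowait(result.get("text", ""))
        if result["eos"]:
            queue.put_nowait(None)

async def consume(queue: asyncio.Queue):
    # What the streaming response generator iterates: drain the queue
    # independently of the producer's pace.
    while True:
        chunk = await queue.get()
        if chunk is None:
            break
        yield chunk

async def main():
    queue: asyncio.Queue = asyncio.Queue()
    task = asyncio.create_task(producer(queue))
    async for chunk in consume(queue):
        print(chunk, end="", flush=True)
    await task
    print()

if __name__ == "__main__":
    asyncio.run(main())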

remichu-ai commented 3 months ago

I have managed to fix the issue. The root cause is that my FastAPI streaming consumption code is slower than the rate at which generated text is put into the queue. I modified the consumer to send back everything that has accumulated in the queue in one go, instead of sending item by item, and that solved the issue.
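
One way to sketch that batching consumer, assuming an asyncio.Queue fed by a producer like the one in the sketch above with None as the end-of-stream sentinel (drain_in_batches and the fake producer are illustrative):

import asyncio

async def drain_in_batches(queue: asyncio.Queue):
    # Wait for the next chunk, then also grab whatever else has already
    # accumulated, so one send covers everything produced since the last send.
    while True:
        first = await queue.get()
        if first is None:           # end-of-stream sentinel
            return
        batch = [first]
        try:
            while True:
                item = queue.get_nowait()
                if item is None:    # sentinel can arrive mid-drain
                    yield "".join(batch)
                    return
                batch.append(item)
        except asyncio.QueueEmpty:
            pass
        yield "".join(batch)

async def main():
    queue: asyncio.Queue = asyncio.Queue()

    async def produce():
        # Fake fast producer standing in for the generation loop.
        for i in range(20):
            await asyncio.sleep(0.01)
            queue.put_nowait(f"tok{i} ")
        queue.put_nowait(None)

    task = asyncio.create_task(produce())
    async for batch in drain_in_batches(queue):
        await asyncio.sleep(0.1)    # simulate a slow consumer / network send
        print(repr(batch))
    await task

if __name__ == "__main__":
    asyncio.run(main())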