mlc-ai / mlc-llm

Universal LLM Deployment Engine with ML Compilation
https://llm.mlc.ai/
Apache License 2.0

[Bug] Speculative decoding results in huge memory usage #1827

Closed srikanthsrnvs closed 8 months ago

srikanthsrnvs commented 8 months ago

🐛 Bug

When testing speculative decoding using the predefined speculative decoding test, memory usage is extremely high, which results in an OOM on my device.

To Reproduce

Steps to reproduce the behavior:

  1. Run the speculative decoding test file with two models: Llama 7B as the SSM (draft model) and Llama 13B as the LLM (target model).
import asyncio
from typing import List

from mlc_chat.serve import (
    AsyncThreadedEngine,
    EngineMode,
    GenerationConfig,
    KVCacheConfig,
)
from mlc_chat.serve.engine import ModelInfo

prompts = [
    "What is the meaning of life?",
    "Introduce the history of Pittsburgh to me. Please elaborate in detail.",
    "Write a three-day Seattle travel plan. Please elaborate in detail.",
    "What is Alaska famous of? Please elaborate in detail.",
    "What is the difference between Lambda calculus and Turing machine? Please elaborate in detail.",
    "What are the necessary components to assemble a desktop computer? Please elaborate in detail.",
    "Why is Vitamin D important to human beings? Please elaborate in detail.",
    "Where is milk tea originated from? Please elaborate in detail.",
    "Where is the southernmost place in United States? Please elaborate in detail.",
    "Do you know AlphaGo? What capabilities does it have, and what achievements has it got? Please elaborate in detail.",
]

async def test_engine_generate():
    # Initialize model loading info and KV cache config
    ssm = ModelInfo(
        "models--meta-llama--Llama-2-7b-hf/snapshots/8cca527612d856d7d32bd94f8103728d614eb852/MLC",
        model_lib_path="/home/truffle/models/models--meta-llama--Llama-2-7b-hf/snapshots/8cca527612d856d7d32bd94f8103728d614eb852/MLC/model.so",
    )
    llm = ModelInfo(
        "models--meta-llama--Llama-2-13b-hf/snapshots/dc1d3b3bfdb69df26f8fc966c16353274b138c55/MLC",
        model_lib_path="/home/truffle/models/models--meta-llama--Llama-2-13b-hf/snapshots/dc1d3b3bfdb69df26f8fc966c16353274b138c55/MLC/model.so",
    )
    kv_cache_config = KVCacheConfig(page_size=16)
    engine_mode = EngineMode(enable_speculative=True)
    # Create engine
    async_engine = AsyncThreadedEngine([llm, ssm], kv_cache_config, engine_mode)

    num_requests = 10
    max_tokens = 256
    generation_cfg = GenerationConfig(max_tokens=max_tokens)

    outputs: List[str] = ["" for _ in range(num_requests)]

    async def generate_task(
        async_engine: AsyncThreadedEngine,
        prompt: str,
        generation_cfg: GenerationConfig,
        request_id: str,
    ):
        print(f"generate task for request {request_id}")
        rid = int(request_id)
        async for delta_text, num_delta_tokens, finish_reason in async_engine.generate(
            prompt, generation_cfg, request_id=request_id
        ):
            outputs[rid] += delta_text

    tasks = [
        asyncio.create_task(
            generate_task(async_engine, prompts[i], generation_cfg, request_id=str(i))
        )
        for i in range(num_requests)
    ]

    await asyncio.gather(*tasks)

    # Print output.
    print("All finished")
    for req_id, output in enumerate(outputs):
        print(f"Prompt {req_id}: {prompts[req_id]}")
        print(f"Output {req_id}:{output}\n")

    async_engine.terminate()
    del async_engine

if __name__ == "__main__":
    asyncio.run(test_engine_generate())

[2024-02-23 03:10:22] INFO engine.py:205: Estimated KVCacheConfig "max_total_sequence_length": 35952.
[2024-02-23 03:10:22] INFO engine.py:210: Estimated total single GPU memory usage: 61050.11 MB (Parameters: 9462.36 MB. KVCache: 50669.85 MB. Temporary buffer: 444.78 MB)
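For context, a rough back-of-envelope sketch of where that KVCache number comes from (assuming standard Llama-2 model configurations and an fp16 cache; the engine's own estimate also includes overheads, so the numbers will not match exactly):

def kv_bytes_per_token(num_layers, num_kv_heads, head_dim, dtype_bytes=2):
    # One K and one V entry per layer, per head, per token
    return 2 * num_layers * num_kv_heads * head_dim * dtype_bytes

per_token = (
    kv_bytes_per_token(32, 32, 128)    # Llama-2-7B (draft model)
    + kv_bytes_per_token(40, 40, 128)  # Llama-2-13B (target model)
)
max_total_sequence_length = 35952      # the auto-estimated value from the log above
print(f"{per_token * max_total_sequence_length / 2**20:.0f} MiB")  # roughly 46,000 MiB

Because both the draft and the target model keep a KV cache for the same max_total_sequence_length, the cache dominates the footprint once the engine auto-sizes it to fill the GPU.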

Expected behavior

I expect it to use far less memory, since the models are only 7B and 13B.

Environment

anibohara2000 commented 8 months ago

Currently, if you don't provide max_total_sequence_length in the KVCacheConfig, the serving engine estimates a value that uses all of the available GPU memory. If you want to lower the GPU memory usage, set the max_total_sequence_length argument in KVCacheConfig to an appropriately smaller value.
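For example, a minimal sketch based on the repro script above (8192 is just an illustrative cap, not a recommendation; pick a value that fits your GPU and workload):

from mlc_chat.serve import KVCacheConfig

# Cap the total KV cache length explicitly instead of letting the engine
# size it to fill all available GPU memory (8192 is an illustrative value).
kv_cache_config = KVCacheConfig(page_size=16, max_total_sequence_length=8192)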