NVIDIA / TensorRT-LLM

TensorRT-LLM provides users with an easy-to-use Python API to define Large Language Models (LLMs) and build TensorRT engines that contain state-of-the-art optimizations to perform inference efficiently on NVIDIA GPUs. TensorRT-LLM also contains components to create Python and C++ runtimes that execute those TensorRT engines.
https://nvidia.github.io/TensorRT-LLM
Apache License 2.0

Seeing GPU OOM errors when using `--paged_kv_cache` option #283

Open · cody-moveworks opened this issue 1 year ago

cody-moveworks commented 1 year ago

Opening a new issue as #237 was closed prematurely.

It seems that engines built with the --paged_kv_cache flag leak GPU memory. Below is a minimal reproducible example that triggers a GPU out-of-memory error; change the ENGINE_DIR and TOKENIZER_DIR variables as needed.

import json
import os
from pathlib import Path

import numpy as np
import torch
from transformers import LlamaTokenizer

import tensorrt_llm
from tensorrt_llm.quantization import QuantMode
from tensorrt_llm.runtime import ModelConfig, SamplingConfig
from tqdm import tqdm

#### Change these values as needed ####
ENGINE_DIR = '/tmp/CodeLlama/7B/trt_engines/fp16_smoothquant_int8_kv_cache/1-gpu'
TOKENIZER_DIR = '/tmp/CodeLlama/7B/hf'
LOG_LEVEL = 'error'
#######################################

EOS_TOKEN = 2
PAD_TOKEN = 2

MAX_CONTEXT_LENGTH = 2048
MAX_NEW_TOKENS = 512
BEAM_WIDTH = 1

def main():
    tensorrt_llm.logger.set_level(LOG_LEVEL)

    print('Loading model and tokenizer...')
    config_path = os.path.join(ENGINE_DIR, 'config.json')
    model_config, tp_size, pp_size, dtype = read_config(config_path)
    world_size = tp_size * pp_size

    runtime_rank = tensorrt_llm.mpi_rank()
    runtime_mapping = tensorrt_llm.Mapping(
        world_size,
        runtime_rank,
        tp_size=tp_size,
        pp_size=pp_size,
    )
    torch.cuda.set_device(runtime_rank % runtime_mapping.gpus_per_node)

    engine_name = get_engine_name('llama', dtype, tp_size, pp_size, runtime_rank)
    engine_path = os.path.join(ENGINE_DIR, engine_name)
    with open(engine_path, 'rb') as infile:
        engine_buffer = infile.read()

    generation_session = tensorrt_llm.runtime.GenerationSession(
        model_config,
        engine_buffer,
        runtime_mapping,
        debug_mode=False,
        debug_tensors_to_save=None,
    )

    tokenizer = LlamaTokenizer.from_pretrained(TOKENIZER_DIR, legacy=False)

    print('Creating toy inputs...')
    input_ids = [tokenizer.encode(' '.join(['the'] * MAX_CONTEXT_LENGTH), add_special_tokens=False)]
    input_lengths = torch.IntTensor([len(x) for x in input_ids]).cuda()

    if model_config.remove_input_padding:
        input_ids = torch.IntTensor(np.concatenate(input_ids)).cuda().unsqueeze(0)
    else:
        input_ids = torch.nested.to_padded_tensor(
            torch.nested.nested_tensor(
                input_ids,
                dtype=torch.int32,
            ),
            PAD_TOKEN,  # pad value used to right-pad the batch
        ).cuda()

    print('Calling the model repeatedly until GPU OOM...')
    sampling_config = SamplingConfig(
        end_id=EOS_TOKEN,
        pad_id=PAD_TOKEN,
        num_beams=BEAM_WIDTH,
    )
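    # With an engine built with --paged_kv_cache, GPU memory usage grows on
    # every setup()/decode() iteration until the process runs out of memory;
    # without the flag, memory usage stays flat across iterations.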
    for _ in tqdm(range(100)):
        generation_session.setup(
            batch_size=input_lengths.size(0),
            max_context_length=torch.max(input_lengths).item(),
            max_new_tokens=MAX_NEW_TOKENS,
            beam_width=BEAM_WIDTH,
        )
        outputs = generation_session.decode(
            input_ids,
            input_lengths,
            sampling_config,
            output_sequence_lengths=True,
            return_dict=True,
        )
        torch.cuda.synchronize()

# Copied from `examples/llama/run.py`
def read_config(config_path: Path):
    with open(config_path, 'r') as f:
        config = json.load(f)
    use_gpt_attention_plugin = config['plugin_config']['gpt_attention_plugin']
    remove_input_padding = config['plugin_config']['remove_input_padding']
    dtype = config['builder_config']['precision']
    tp_size = config['builder_config']['tensor_parallel']
    pp_size = config['builder_config']['pipeline_parallel']
    world_size = tp_size * pp_size
    assert world_size == tensorrt_llm.mpi_world_size(), \
        f'Engine world size ({world_size}) != Runtime world size ({tensorrt_llm.mpi_world_size()})'
    num_heads = config['builder_config']['num_heads'] // tp_size
    hidden_size = config['builder_config']['hidden_size'] // tp_size
    vocab_size = config['builder_config']['vocab_size']
    num_layers = config['builder_config']['num_layers']
    num_kv_heads = config['builder_config'].get('num_kv_heads', num_heads)
    paged_kv_cache = config['plugin_config']['paged_kv_cache']
    tokens_per_block = config['plugin_config']['tokens_per_block']
    quant_mode = QuantMode(config['builder_config']['quant_mode'])
    if config['builder_config'].get('multi_query_mode', False):
        tensorrt_llm.logger.warning(
            "`multi_query_mode` config is deprecated. Please rebuild the engine."
        )
        num_kv_heads = 1
    num_kv_heads = (num_kv_heads + tp_size - 1) // tp_size
    use_custom_all_reduce = config['plugin_config'].get('use_custom_all_reduce',
                                                        False)

    model_config = ModelConfig(num_heads=num_heads,
                               num_kv_heads=num_kv_heads,
                               hidden_size=hidden_size,
                               vocab_size=vocab_size,
                               num_layers=num_layers,
                               gpt_attention_plugin=use_gpt_attention_plugin,
                               paged_kv_cache=paged_kv_cache,
                               tokens_per_block=tokens_per_block,
                               remove_input_padding=remove_input_padding,
                               dtype=dtype,
                               quant_mode=quant_mode,
                               use_custom_all_reduce=use_custom_all_reduce)

    return model_config, tp_size, pp_size, dtype

# Copied from `examples/llama/build.py`
def get_engine_name(model, dtype, tp_size, pp_size, rank):
    if pp_size == 1:
        return '{}_{}_tp{}_rank{}.engine'.format(model, dtype, tp_size, rank)
    return '{}_{}_tp{}_pp{}_rank{}.engine'.format(model, dtype, tp_size,
                                                  pp_size, rank)

if __name__ == '__main__':
    main()
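
To make the growth visible before the process actually dies, the generation loop above can also report free GPU memory after every iteration. A minimal sketch, assuming the same session objects as the script above and using torch.cuda.mem_get_info:

import torch

def log_free_gpu_memory(step: int) -> None:
    # torch.cuda.mem_get_info returns (free, total) bytes for the current device.
    free_bytes, total_bytes = torch.cuda.mem_get_info()
    print(f'step {step}: {free_bytes / 1e9:.2f} GB free of {total_bytes / 1e9:.2f} GB')

Calling log_free_gpu_memory(step) right after torch.cuda.synchronize() in the loop should show free memory shrinking each iteration when the engine was built with --paged_kv_cache, and staying flat otherwise.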

I am running my tests in a Docker container built from the image produced by this repo's Dockerfile (i.e. by running make -C docker release_build), on an NVIDIA A100 40G GPU, using Meta's pre-trained CodeLlama-7b-Instruct model. The /tmp/CodeLlama/7B/hf directory contains both the Hugging Face PyTorch model and the tokenizer files for the model.

We first convert the Hugging Face PyTorch model checkpoint to a FasterTransformer checkpoint to prepare it for engine building with the SmoothQuant and int8 KV cache features enabled:

python examples/llama/hf_llama_convert.py \
    --in-file /tmp/CodeLlama/7B/hf \
    --out-dir /tmp/CodeLlama/7B/ft/fp16_smoothquant_int8_kv_cache \
    --tensor-parallelism 1 \
    --processes $(nproc) \
    --storage-type fp16 \
    --smoothquant 0.5 \
    --calibrate-kv-cache

If the engine is built with the following command, running the script above does not result in GPU OOM:

python examples/llama/build.py \
    --ft_model_dir /tmp/CodeLlama/7B/ft/fp16_smoothquant_int8_kv_cache/1-gpu/ \
    --output_dir /tmp/CodeLlama/7B/trt_engines/fp16_smoothquant_int8_kv_cache/1-gpu/ \
    --dtype float16 \
    --remove_input_padding \
    --use_gpt_attention_plugin float16 \
    --use_gemm_plugin float16 \
    --use_rmsnorm_plugin float16 \
    --enable_context_fmha \
    --int8_kv_cache \
    --use_smooth_quant \
    --per_token \
    --per_channel \
    --rotary_base 1000000 \
    --vocab_size 32016 \
    --max_input_len 8192 \
    --max_output_len 4096

If you add the --paged_kv_cache flag when building the engine, running the script above leads to GPU OOM:

python examples/llama/build.py \
    --ft_model_dir /tmp/CodeLlama/7B/ft/fp16_smoothquant_int8_kv_cache/1-gpu/ \
    --output_dir /tmp/CodeLlama/7B/trt_engines/fp16_smoothquant_int8_kv_cache_paged_kv_cache/1-gpu/ \
    --dtype float16 \
    --remove_input_padding \
    --use_gpt_attention_plugin float16 \
    --use_gemm_plugin float16 \
    --use_rmsnorm_plugin float16 \
    --enable_context_fmha \
    --paged_kv_cache \
    --int8_kv_cache \
    --use_smooth_quant \
    --per_token \
    --per_channel \
    --rotary_base 1000000 \
    --vocab_size 32016 \
    --max_input_len 8192 \
    --max_output_len 4096
Tlntin commented 1 year ago

Me too. You need to use Triton to deploy it, not the Python runtime.

juney-nvidia commented 1 year ago

@cody-moveworks

Thanks for summarizing the concrete steps of reproducing your issue.

As @Tlntin said, you can first try the C++ runtime to see whether the OOM issue still exists.

In the meantime, we will also start investigating this issue.

Thanks, June

gesanqiu commented 12 months ago

Any progress here? BTW, when will 0.6.0 be released?

daxiongshu commented 11 months ago

Me too. You need to use Triton to deploy it, not the Python runtime.

Could you please elaborate on why Triton doesn't have this issue? Does Triton use the C++ runtime? Thanks!

Tlntin commented 11 months ago

Could you please elaborate on why Triton doesn't have this issue?

I don't know why.

Does Triton use the C++ runtime?

Yes.

byshiue commented 11 months ago

Any update?

ZJLi2013 commented 3 months ago

Still seeing this issue with the Python benchmark: with paged_kv_cache disabled I can run llama2-7b with bs=32, in=128, out=2048, but enabling paged_kv_cache gives an OOM.
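
For rough context on why that configuration sits close to the capacity of a 40 GB card like the A100 used earlier in this thread, here is a back-of-the-envelope KV-cache estimate (a sketch only, assuming an fp16 KV cache and the standard llama2-7b geometry of 32 layers, 32 KV heads, head dimension 128):

# Rough KV-cache footprint for llama2-7b with bs=32, in=128, out=2048.
num_layers, num_kv_heads, head_dim = 32, 32, 128  # llama2-7b geometry
bytes_per_elem = 2                                # fp16 KV cache (assumption)
kv_bytes_per_token = 2 * num_layers * num_kv_heads * head_dim * bytes_per_elem  # K and V
seq_len = 128 + 2048                              # input + output tokens per sequence
batch_size = 32
total_gb = kv_bytes_per_token * seq_len * batch_size / 1e9
print(f'{kv_bytes_per_token / 1024:.0f} KiB per token, ~{total_gb:.1f} GB for the batch')
# -> 512 KiB per token and roughly 36 GB for the full batch, before weights
#    and activations are counted, so there is little headroom left for any
#    additional allocation.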