NVIDIA / TensorRT-LLM

TensorRT-LLM provides users with an easy-to-use Python API to define Large Language Models (LLMs) and build TensorRT engines that contain state-of-the-art optimizations to perform inference efficiently on NVIDIA GPUs. TensorRT-LLM also contains components to create Python and C++ runtimes that execute those TensorRT engines.
https://nvidia.github.io/TensorRT-LLM
Apache License 2.0
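
For orientation, a minimal sketch of that Python API, assuming a recent release where `LLM` and `SamplingParams` are exported at the package top level (the entry points have moved between versions); the model name is only a placeholder:

from tensorrt_llm import LLM, SamplingParams

# Build (or load a cached) TensorRT engine for a Hugging Face checkpoint.
# The model name below is a placeholder; any supported checkpoint works.
llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct")

sampling = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=64)
outputs = llm.generate(["Hello, my name is"], sampling)

for out in outputs:
    print(out.outputs[0].text)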

Llama3.1 70B smoothquant convert_checkpoint.py out of memory OOM #2202

Closed · Hao-YunDeng closed this 1 day ago

Hao-YunDeng commented 2 months ago

System Info

Driver Version: 535.154.05
CUDA Version: 12.5
GPU: NVIDIA A100-PCIE-40GB x 8
tensorrt 10.2.0
tensorrt_llm 0.12.0.dev2024072301
triton 2.3.1
torch 2.3.1
transformers 4.43.3
Python 3.10.12

Who can help?

@Tracin @byshiue

Reproduction

MODEL_PATH=/Meta-Llama-3.1-70B-Instruct/
DTYPE=bfloat16
TP=8
PP=1
MAX_NUM_TOKENS=130000
MAX_BATCH_SIZE=32
MAX_SEQ_LEN=128000
MAX_INPUT_LEN=124000

if [ "$MAX_INPUT_LEN" -eq 0 ]; then
    echo "!!! Please give correct MAX_INPUT_LEN. Default value MAX_INPUT_LEN=${MAX_INPUT_LEN}"
    exit
fi

echo "Start converting transformers model..."

convert_model_path=trt_weight
rm -rf $convert_model_path
python3 tensorrt_llm/examples/llama/convert_checkpoint.py \
    --model_dir ${MODEL_PATH} \
    --output_dir ${convert_model_path} \
    --dtype ${DTYPE} \
    --smoothquant 0.5 \
    --per_token \
    --per_channel \
    --tp_size ${TP} \
    --pp_size ${PP}

Expected behavior

The checkpoint conversion completes successfully.

Actual behavior

[TensorRT-LLM] TensorRT-LLM version: 0.12.0.dev2024072301
0.12.0.dev2024072301
Loading checkpoint shards: 100%|██████████| 30/30 [00:49<00:00, 1.66s/it]
/usr/local/lib/python3.10/dist-packages/datasets/load.py:1491: FutureWarning: The repository for ccdv/cnn_dailymail contains custom code which must be executed to correctly load the dataset. You can inspect the repository content at https://hf.co/datasets/ccdv/cnn_dailymail
You can avoid this message in future by passing the argument trust_remote_code=True.
Passing trust_remote_code=True will be mandatory to load this dataset from the next major release of datasets.
  warnings.warn(
Downloading builder script: 100%|██████████| 9.27k/9.27k [00:00<00:00, 14.2MB/s]
Downloading readme: 100%|██████████| 13.9k/13.9k [00:00<00:00, 22.4MB/s]
Downloading data: 100%|██████████| 159M/159M [00:01<00:00, 109MB/s]
Downloading data: 100%|██████████| 376M/376M [00:03<00:00, 109MB/s]
Downloading data: 2.11MB [00:00, 73.9MB/s]
Downloading data: 46.4MB [00:00, 90.7MB/s]
Downloading data: 2.43MB [00:00, 72.6MB/s]
Generating train split: 287113 examples [00:50, 5707.22 examples/s]
Generating validation split: 13368 examples [00:02, 5725.56 examples/s]
Generating test split: 11490 examples [00:01, 5865.29 examples/s]
calibrating model: 0%| | 0/512 [00:00<?, ?it/s]
We detected that you are passing past_key_values as a tuple and this is deprecated and will be removed in v4.43. Please use an appropriate Cache class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)
calibrating model: 100%|██████████| 512/512 [05:36<00:00, 1.52it/s]
Traceback (most recent call last):
  File "/code/tensorrt_llm/examples/llama/convert_checkpoint.py", line 510, in <module>
    main()
  File "/code/tensorrt_llm/examples/llama/convert_checkpoint.py", line 502, in main
    convert_and_save_hf(args)
  File "/code/tensorrt_llm/examples/llama/convert_checkpoint.py", line 387, in convert_and_save_hf
    LLaMAForCausalLM.quantize(
  File "/usr/local/lib/python3.10/dist-packages/tensorrt_llm/models/llama/model.py", line 411, in quantize
    convert.quantize(hf_model_dir,
  File "/usr/local/lib/python3.10/dist-packages/tensorrt_llm/models/llama/convert.py", line 1358, in quantize
    weights = load_weights_from_hf_model(
  File "/usr/local/lib/python3.10/dist-packages/tensorrt_llm/models/llama/convert.py", line 1238, in load_weights_from_hf_model
    convert_layer(l)
  File "/usr/local/lib/python3.10/dist-packages/tensorrt_llm/models/llama/convert.py", line 1107, in convert_layer
    get_tllm_linear_sq_weight(
  File "/usr/local/lib/python3.10/dist-packages/tensorrt_llm/models/llama/convert.py", line 585, in get_tllm_linear_sq_weight
    original_weights = torch.Tensor(vals["weight.int8.col"]).cuda()
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 224.00 MiB. GPU
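
For rough scale (illustrative arithmetic only, assuming the nominal 70B parameter count; actual allocations depend on the conversion internals): the bf16 checkpoint alone is about 130 GiB and an int8 copy about 65 GiB, far beyond a single 40 GiB A100, so any step that materializes unsplit tensors on one device can exhaust it even though the individual failing allocation is small:

GIB = 2**30
params = 70e9  # nominal parameter count of Llama 3.1 70B

print(f"bf16 weights:       {params * 2 / GIB:6.0f} GiB")  # ~130 GiB
print(f"int8 weights:       {params * 1 / GIB:6.0f} GiB")  # ~65 GiB
print(f"int8 shard at TP=8: {params / 8 / GIB:6.1f} GiB")  # ~8.2 GiB per GPU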

additional notes

We also tried --load_model_on_cpu, but the conversion remained stuck even after running overnight:

[TensorRT-LLM] TensorRT-LLM version: 0.12.0.dev2024072301
0.12.0.dev2024072301
Loading checkpoint shards: 100%|██████████| 30/30 [00:16<00:00, 1.86it/s]
/usr/local/lib/python3.10/dist-packages/datasets/load.py:1491: FutureWarning: The repository for ccdv/cnn_dailymail contains custom code which must be executed to correctly load the dataset. You can inspect the repository content at https://hf.co/datasets/ccdv/cnn_dailymail
You can avoid this message in future by passing the argument trust_remote_code=True.
Passing trust_remote_code=True will be mandatory to load this dataset from the next major release of datasets.
  warnings.warn(
calibrating model: 0%| | 0/512 [00:00<?, ?it/s]
We detected that you are passing past_key_values as a tuple and this is deprecated and will be removed in v4.43. Please use an appropriate Cache class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)

Hao-YunDeng commented 2 months ago

@Tracin @byshiue is there a way to convert the checkpoint using multiple nodes? If so, how do we do that?

fan-niu commented 2 months ago

@Tracin same issue here. Can you help take a look at this problem? Thanks.

Hao-YunDeng commented 1 month ago

@kaiyux can you or anyone else please help look into this? We are blocked by this issue.

Hao-YunDeng commented 1 month ago

Update: we seem to have fixed the problem by modifying tensorrt_llm/models/llama/convert.py so that the full weight tensor stays on the CPU and only each rank's shard is moved to the GPU:

if per_token:
    if per_channel:
        # was: torch.Tensor(vals["weight.int8.col"]).cuda()
        # keep the full tensor on the CPU to avoid filling GPU memory
        original_weights = torch.Tensor(vals["weight.int8.col"])
    else:
        # was: torch.Tensor(vals["weight.int8"]).cuda()
        original_weights = torch.Tensor(vals["weight.int8"])
    local_dim = original_weights.shape[0]
    head_size = (original_weights.shape[1] - local_dim) // 2

    if multi_query_mode:
        cur_weights = multi_query_split(original_weights, local_dim,
                                        head_size, tensor_parallel, rank)
    else:
        cur_weights = torch.chunk(original_weights,
                                  tensor_parallel,
                                  dim=cat_dim)[rank]

    # added: move only this rank's shard to the GPU after splitting
    cur_weights = cur_weights.cuda()
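
The change defers the host-to-device copy until after the tensor-parallel split, so each rank allocates only its own 1/TP shard on the GPU rather than the full unsplit weight.
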
github-actions[bot] commented 2 weeks ago

This issue is stale because it has been open 30 days with no activity. Remove stale label or comment or this will be closed in 15 days.

github-actions[bot] commented 1 day ago

This issue was closed because it has been stalled for 15 days with no activity.