NVIDIA / TensorRT-LLM

TensorRT-LLM provides users with an easy-to-use Python API to define Large Language Models (LLMs) and build TensorRT engines that contain state-of-the-art optimizations to perform inference efficiently on NVIDIA GPUs. TensorRT-LLM also contains components to create Python and C++ runtimes that execute those TensorRT engines.
https://nvidia.github.io/TensorRT-LLM
Apache License 2.0

Broken output for int4 weight-only quantized version of merged Llama2 70b model with more layers #1191

Closed: aikitoria closed this issue 8 months ago

aikitoria commented 8 months ago

System Info

Who can help?

@Tracin

Reproduction

1) Launch the nvcr.io/nvidia/tritonserver:24.01-trtllm-python-py3 container image
2) Install tensorrt-llm according to the readme:
apt update
apt install openmpi-bin libopenmpi-dev
pip3 install tensorrt_llm -U --pre --extra-index-url https://pypi.nvidia.com
3) Also clone the repo so we can use the scripts:
git clone https://github.com/NVIDIA/TensorRT-LLM
4) Download the model from Hugging Face:
huggingface-cli download wolfram/miquliz-120b-v2.0 --local-dir /workspace/miquliz
5) Install the dependencies for the checkpoint conversion script:
cd TensorRT-LLM/examples/llama
pip install -r requirements.txt
6) Run the checkpoint conversion script as follows:
python3 convert_checkpoint.py --model_dir /workspace/miquliz/ --output_dir /workspace/miquliz-quantized/ --tp_size 4 --dtype float16 --use_weight_only --weight_only_precision int4 --fp8_kv_cache --enable_fp8
7) Copy the quantized model to the inference server
8) Build the engine as follows:
trtllm-build --checkpoint_dir /workspace/miquliz-quantized/ --output_dir /workspace/miquliz-engine/ --max_batch_size 1 --max_output_len 256 --weight_only_precision int4 --gemm_plugin float16 --paged_kv_cache enable --use_custom_all_reduce disable --multi_block_mode enable
9) Run the engine as follows:
mpirun --allow-run-as-root -n 4 python3 ../run.py --max_output_len 256 --tokenizer_dir /workspace/miquliz/ --engine_dir /workspace/miquliz-engine/

Expected behavior

The engine builds without warning messages and generates sensible output.

actual behavior

The engine builds, but many warning messages of this format are printed for every tensor in every layer:

Running it produces garbage output:

Input [Text 0]: "<s> Born in north-east France, Soyer trained as a"
Output [Text 0 Beam 0]: "给给给给给给给给给给给给给给给给给给给给给给给给给给给给给给给给给给给给给给 (...)

Build engine log: log (1).txt
Run engine log: log2.txt

additional notes

Since the int4_awq quant format did not work at all, I am trying the basic 4-bit quant instead. I am still experimenting with the other options to see whether one of the settings is the cause, but iterating with the 120B model is extremely slow.

If I use this smaller model instead, the warning messages are still generated, but the engine does not seem to be broken and generates reasonable output.

Tracin commented 8 months ago

Could you try with weight-only int8 to see if outputs are reasonable?

aikitoria commented 8 months ago

Yes, though it will not fit on the 4x 4090s, so I will have to run that one on a bigger server instead. I'm going to try 4x H100 (so FP8 works).
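The arithmetic behind "it will not fit" can be sketched as a rough, weights-only estimate. It assumes a nominal 120 billion parameters (taken from the model name), weights split evenly across 4 GPUs by tensor parallelism, and the 24564 MiB per RTX 4090 that nvidia-smi reports later in this thread. It ignores embeddings, quantization scales, the KV cache and activation buffers, all of which need extra room.

```python
# Rough weights-only memory estimate per GPU under tensor parallelism.
# Assumptions: ~120e9 parameters (nominal, from the model name), an even split
# across GPUs, and no overhead for scales, KV cache, or activations.

NUM_PARAMS = 120e9
NUM_GPUS = 4
RTX_4090_MIB = 24564  # total memory per GPU as reported by nvidia-smi


def weight_mib_per_gpu(num_params: float, bits_per_weight: int, num_gpus: int) -> float:
    """Memory needed for the weights alone on each GPU, in MiB."""
    total_bytes = num_params * bits_per_weight / 8
    return total_bytes / num_gpus / 2**20


for name, bits in (("fp16", 16), ("int8", 8), ("int4", 4)):
    mib = weight_mib_per_gpu(NUM_PARAMS, bits, NUM_GPUS)
    verdict = "fits" if mib < RTX_4090_MIB else "does not fit"
    print(f"{name}: {mib:,.0f} MiB per GPU -> {verdict} in {RTX_4090_MIB} MiB")
```

Under these assumptions only int4 fits, leaving roughly 10 GiB per card for the KV cache and runtime buffers, which is consistent with the requirement for a 4-bit quant stated later in the thread.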

aikitoria commented 8 months ago

An engine built with the following commands works correctly on 4x H100:

python convert_checkpoint.py --model_dir /workspace/miquliz/ --output_dir /workspace/miquliz-int8/ --tp_size 4 --dtype float16 --use_weight_only --weight_only_precision int8

trtllm-build --checkpoint_dir /workspace/miquliz-int8/ --output_dir /workspace/miquliz-engine-int8/ --max_batch_size 1 --max_output_len 256 --weight_only_precision int8 --gemm_plugin float16 --paged_kv_cache enable --use_custom_all_reduce disable --multi_block_mode enable

mpirun --allow-run-as-root -n 4 python3 ../run.py --max_output_len 256 --tokenizer_dir /workspace/miquliz/ --engine_dir /workspace/miquliz-engine-int8/

So that's the first test. Now to see which of the options I had above is breaking it.

aikitoria commented 8 months ago

An engine built with the following commands works correctly on 4x H100:

python convert_checkpoint.py --model_dir /workspace/miquliz/ --output_dir /workspace/miquliz-int4/ --tp_size 4 --dtype float16 --use_weight_only --weight_only_precision int4

trtllm-build --checkpoint_dir /workspace/miquliz-int4/ --output_dir /workspace/miquliz-engine-int4/ --max_batch_size 1 --max_output_len 256 --weight_only_precision int4 --gemm_plugin float16 --paged_kv_cache enable --use_custom_all_reduce disable --multi_block_mode enable

mpirun --allow-run-as-root -n 4 python3 ../run.py --max_output_len 256 --tokenizer_dir /workspace/miquliz/ --engine_dir /workspace/miquliz-engine-int4/

So it seems the problem might not be the int4 quant but one of the other options?
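For readers following along, the int4 weight-only scheme being tested here stores each weight as a 4-bit integer plus a floating-point scale, while activations stay in fp16 and weights are effectively dequantized inside the GEMM. The sketch below is a simplified, pure-Python illustration that assumes one symmetric scale per output row; it is not TensorRT-LLM's kernel code, memory layout, or exact rounding.

```python
# Simplified symmetric int4 weight-only quantization, one scale per output row.
# Illustration only: not TensorRT-LLM's actual kernels, layout, or rounding.
from typing import List, Tuple

INT4_MIN, INT4_MAX = -8, 7


def quantize_row(row: List[float]) -> Tuple[List[int], float]:
    """Map one row of float weights to int4 values sharing a single scale."""
    max_abs = max(abs(w) for w in row)
    scale = max_abs / INT4_MAX if max_abs > 0 else 1.0
    q = [max(INT4_MIN, min(INT4_MAX, round(w / scale))) for w in row]
    return q, scale


def quantize_matrix(weights: List[List[float]]) -> List[Tuple[List[int], float]]:
    return [quantize_row(row) for row in weights]


def matvec_weight_only(qweights, x: List[float]) -> List[float]:
    """y = W x with W kept as int4 + scale; x stays in floating point."""
    # Rescaling the integer dot product once per row is equivalent to
    # dequantizing every weight before multiplying.
    return [scale * sum(qv * xv for qv, xv in zip(q, x)) for q, scale in qweights]


W = [[0.5, -1.0, 0.25],
     [2.0, 0.0, -2.0]]
x = [1.0, 2.0, 3.0]
qW = quantize_matrix(W)
print(qW[0][0])                   # [4, -7, 2]
print(matvec_weight_only(qW, x))  # approximates the exact result [-0.75, -4.0]
```

The first row shows the typical cost of 4-bit rounding (-0.57 instead of -0.75), while the second row happens to be exactly representable.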

phind-justin commented 8 months ago

> Engine built with following commands works correctly on 4x H100:
>
> python convert_checkpoint.py --model_dir /workspace/miquliz/ --output_dir /workspace/miquliz-int4/ --tp_size 4 --dtype float16 --use_weight_only --weight_only_precision int4
>
> trtllm-build --checkpoint_dir /workspace/miquliz-int4/ --output_dir /workspace/miquliz-engine-int4/ --max_batch_size 1 --max_output_len 256 --weight_only_precision int8 --gemm_plugin float16 --paged_kv_cache enable --use_custom_all_reduce disable --multi_block_mode enable
>
> mpirun --allow-run-as-root -n 4 python3 ../run.py --max_output_len 256 --tokenizer_dir /workspace/miquliz/ --engine_dir /workspace/miquliz-engine-int4/
>
> So it seems the problem might not be the int4 quant but one of the other options?

Hey, it seems like you ran convert_checkpoint.py with int4 but built with int8. Is that on purpose? How does that work?

aikitoria commented 8 months ago

Sorry, that must be a copy-paste error. I was building with int4 for that one. Will fix.

aikitoria commented 8 months ago

I also tried building the same engine that I had originally tried to run on 4x 4090. On 4x H100, it sometimes generates garbage (though different garbage from the 4090), and sometimes throws a CUDA runtime error instead. Does H100 have more advanced error detection? Anyway, here goes:

python3 convert_checkpoint.py --model_dir /workspace/miquliz/ --output_dir /workspace/miquliz-int4-fp8/ --tp_size 4 --dtype float16 --use_weight_only --weight_only_precision int4 --fp8_kv_cache --enable_fp8

trtllm-build --checkpoint_dir /workspace/miquliz-int4-fp8/ --output_dir /workspace/miquliz-engine-int4-fp8/ --max_batch_size 1 --max_output_len 256 --weight_only_precision int4 --gemm_plugin float16 --paged_kv_cache enable --use_custom_all_reduce disable --multi_block_mode enable

mpirun --allow-run-as-root -n 4 python3 ../run.py --max_output_len 256 --tokenizer_dir /workspace/miquliz/ --engine_dir /workspace/miquliz-engine-int4-fp8/

Possible result:

Input [Text 0]: "<s> Born in north-east France, Soyer trained as a"
Output [Text 0 Beam 0]: "cabinetinet ebenistes ebenistes ebenistes ebenistes ebenistes ebenistes (...)"

Possible result:

transport.cc:174 NCCL WARN Cuda failure 'invalid argument'

Full log for the error (NCCL debug already enabled): log3.txt

Perhaps I'm passing the wrong options and I'm not supposed to combine --enable_fp8 with --weight_only_precision int4? I thought it would be required for --fp8_kv_cache, but maybe it's not. Let me try again.

aikitoria commented 8 months ago

Hmm. Getting the same CUDA errors without --enable_fp8.

aikitoria commented 8 months ago

Trying int4 + int8 kv cache gave perhaps the most nonsensical error yet.

python3 convert_checkpoint.py --model_dir /workspace/miquliz/ --output_dir /workspace/miquliz-int4-int8/ --tp_size 4 --dtype float16 --use_weight_only --weight_only_precision int4 --int8_kv_cache

trtllm-build --checkpoint_dir /workspace/miquliz-int4-int8/ --output_dir /workspace/miquliz-engine-int4-int8/ --max_batch_size 1 --max_output_len 256 --weight_only_precision int4 --gemm_plugin float16 --paged_kv_cache enable --use_custom_all_reduce disable --multi_block_mode enable

mpirun --allow-run-as-root -n 4 python3 ../run.py --max_output_len 256 --tokenizer_dir /workspace/miquliz/ --engine_dir /workspace/miquliz-engine-int4-int8/

[TensorRT-LLM][INFO] Max tokens in paged KV cache: 856576. Allocating 61399367680 bytes.
[TensorRT-LLM][INFO] Max KV cache pages per sequence: 3
[TensorRT-LLM][INFO] Max tokens in paged KV cache: 856576. Allocating 61399367680 bytes.
[TensorRT-LLM][INFO] Max KV cache pages per sequence: 3
[TensorRT-LLM][INFO] Max tokens in paged KV cache: 856576. Allocating 61399367680 bytes.
[TensorRT-LLM][INFO] Max KV cache pages per sequence: 3
[TensorRT-LLM][INFO] Max tokens in paged KV cache: 856576. Allocating 61399367680 bytes.
[TensorRT-LLM][INFO] Max KV cache pages per sequence: 3
Traceback (most recent call last):
  File "/workspace/TensorRT-LLM/examples/llama/../run.py", line 567, in <module>
    main(args)
  File "/workspace/TensorRT-LLM/examples/llama/../run.py", line 419, in main
    outputs = runner.generate(
  File "/workspace/venv/lib/python3.10/site-packages/tensorrt_llm/runtime/model_runner_cpp.py", line 342, in generate
    self.session.generate(generation_output, generation_input,
RuntimeError: [TensorRT-LLM][ERROR] CUDA runtime error in ::cudaEventSynchronize(get()): misaligned address (/home/jenkins/agent/workspace/LLM/main/L0_MergeRequest/tensorrt_llm/cpp/include/tensorrt_llm/runtime/cudaEvent.h:66)
1       0x7f27a98c5f7d /workspace/venv/lib/python3.10/site-packages/tensorrt_llm/libs/libtensorrt_llm.so(+0xc5f7d) [0x7f27a98c5f7d]
2       0x7f27a99ac4fa tensorrt_llm::runtime::GptSession::shouldStopSync(int, int, int) + 410
3       0x7f27a99ad976 tensorrt_llm::runtime::GptSession::executeGenerationStep(int, std::vector<tensorrt_llm::runtime::GenerationInput, std::allocator<tensorrt_llm::runtime::GenerationInput> > const&, std::vector<tensorrt_llm::runtime::GenerationOutput, std::allocator<tensorrt_llm::runtime::GenerationOutput> >&, std::vector<int, std::allocator<int> > const&, tensorrt_llm::batch_manager::kv_cache_manager::KVCacheManager*, std::vector<bool, std::allocator<bool> >&) + 1062
4       0x7f27a99b1dc4 tensorrt_llm::runtime::GptSession::generateBatched(std::vector<tensorrt_llm::runtime::GenerationOutput, std::allocator<tensorrt_llm::runtime::GenerationOutput> >&, std::vector<tensorrt_llm::runtime::GenerationInput, std::allocator<tensorrt_llm::runtime::GenerationInput> > const&, tensorrt_llm::runtime::SamplingConfig const&, std::function<void (int, bool)> const&) + 3620
5       0x7f27a99b3230 tensorrt_llm::runtime::GptSession::generate(tensorrt_llm::runtime::GenerationOutput&, tensorrt_llm::runtime::GenerationInput const&, tensorrt_llm::runtime::SamplingConfig const&) + 2096
6       0x7f2813997f69 /workspace/venv/lib/python3.10/site-packages/tensorrt_llm/bindings.cpython-310-x86_64-linux-gnu.so(+0x44f69) [0x7f2813997f69]
7       0x7f28139817f0 /workspace/venv/lib/python3.10/site-packages/tensorrt_llm/bindings.cpython-310-x86_64-linux-gnu.so(+0x2e7f0) [0x7f28139817f0]
8       0x564577c3410e python3(+0x15a10e) [0x564577c3410e]
9       0x564577c2aa7b _PyObject_MakeTpCall + 603
10      0x564577c42acb python3(+0x168acb) [0x564577c42acb]
11      0x564577c22cfa _PyEval_EvalFrameDefault + 24906
12      0x564577c427f1 python3(+0x1687f1) [0x564577c427f1]
13      0x564577c43492 PyObject_Call + 290
14      0x564577c1f5d7 _PyEval_EvalFrameDefault + 10791
15      0x564577c349fc _PyFunction_Vectorcall + 124
16      0x564577c1d26d _PyEval_EvalFrameDefault + 1725
17      0x564577c199c6 python3(+0x13f9c6) [0x564577c199c6]
18      0x564577d0f256 PyEval_EvalCode + 134
19      0x564577d3a108 python3(+0x260108) [0x564577d3a108]
20      0x564577d339cb python3(+0x2599cb) [0x564577d339cb]
21      0x564577d39e55 python3(+0x25fe55) [0x564577d39e55]
22      0x564577d39338 _PyRun_SimpleFileObject + 424
23      0x564577d38f83 _PyRun_AnyFileObject + 67
24      0x564577d2ba5e Py_RunMain + 702
25      0x564577d0202d Py_BytesMain + 45
26      0x7f29ac080d90 /lib/x86_64-linux-gnu/libc.so.6(+0x29d90) [0x7f29ac080d90]
27      0x7f29ac080e40 __libc_start_main + 128
28      0x564577d01f25 _start + 37

Except then I ran it again, and got a different result:

Input [Text 0]: "<s> Born in north-east France, Soyer trained as a"
Output [Text 0 Beam 0]: "给给给给给给给给给给给给给给给给给给给给给给给给 (...)

It's almost like it's trolling me.
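The "Allocating 61399367680 bytes" line in the log above can be checked by hand. The sketch assumes a Llama-2-70B-style grouped-query attention layout (8 KV heads, head dimension 128), 140 decoder layers for this merged model, 4-way tensor parallelism (so 2 KV heads per rank), and 1 byte per element for the int8 KV cache used in this run. The layer count is assumed; it is the value that makes the logged numbers work out exactly.

```python
# Back-of-the-envelope check of the paged KV cache allocation logged per rank.
# Assumptions (not read from the engine): 140 layers, 8 KV heads, head_dim 128,
# tensor parallelism 4, and int8 (1-byte) KV cache elements.


def kv_cache_bytes_per_token(num_layers: int, num_kv_heads: int, head_dim: int,
                             tp_size: int, bytes_per_elem: int) -> int:
    """Bytes of K and V cache that one token occupies on a single TP rank."""
    kv_heads_per_rank = num_kv_heads // tp_size
    # Factor 2: each layer stores one K and one V vector per KV head.
    return 2 * num_layers * kv_heads_per_rank * head_dim * bytes_per_elem


per_token = kv_cache_bytes_per_token(num_layers=140, num_kv_heads=8, head_dim=128,
                                     tp_size=4, bytes_per_elem=1)
max_tokens = 856576  # "Max tokens in paged KV cache" from the log
total = per_token * max_tokens

print(per_token)  # 71680
print(total)      # 61399367680, matching the "Allocating ... bytes" line
print(f"{total / 2**30:.1f} GiB per rank")
```

With an fp16 KV cache, bytes_per_elem would be 2, doubling the per-token cost and halving the number of tokens that fit in the same memory.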

aikitoria commented 8 months ago

It sure would be nice if there were more documentation on the situations in which I would want to use each of these arguments, because at this point I am just guessing and trying permutations until something works. The --help output doesn't give any useful information on any of these:

convert_checkpoint.py

trtllm-build

My end goal is still just to run Miquliz 120B as fast as possible on 4x 4090 GPUs.

aikitoria commented 8 months ago

Hmm. The CUDA errors are now also happening for my initial int4 quant that worked correctly before. Maybe my server broke somehow.

Tracin commented 8 months ago

Hi @aikitoria, please do not add --enable_fp8 and --fp8_kv_cache if you are using int4_weight_only, since they cannot be used at the same time. We can try to fix the accuracy problem first. It looks like int8_weight_only produces correct results, right? So please try int4_awq with the command:

python ../quantization/quantize.py --model_dir ... --qformat int4_awq --output_dir ...

As for the optimal build options, @kaiyux, can you help with this?

aikitoria commented 8 months ago

I wanted to use int4 AWQ originally, but it failed because of this issue: https://github.com/NVIDIA/TensorRT-LLM/issues/1172

So I am trying the simpler int4 quant in the meantime. It has to be 4-bit because otherwise the model will not fit in the target system with 4x 4090.

Yeah, int8 + f16 KV cache produced a working engine, and int4 + f16 KV cache did too (though I have yet to test it on the 4090 system). I was now experimenting to see whether one of the quantized KV cache formats works with it.

I will stop trying to use fp8_kv_cache then. Would int8_kv_cache be expected to work with int4 quants? Could you make the scripts themselves report errors when incompatible options are chosen, or at least document the working combinations somewhere?
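The validation asked for above could look roughly like the following. This is a hypothetical pre-flight check, not part of TensorRT-LLM: it re-declares only a handful of the flags used in this thread with argparse, ignores everything else, and encodes just the rule Tracin stated (no --enable_fp8 or --fp8_kv_cache together with int4 weight-only), plus a check that the build precision matches the conversion precision, the mismatch spotted earlier in the thread. Whether --int8_kv_cache works with int4 weight-only was not confirmed here, so it is deliberately not flagged.

```python
# Hypothetical pre-flight check for the flag combinations discussed in this issue.
# Not part of TensorRT-LLM; only the flags below are recognised, others are ignored.
import argparse
from typing import List, Optional


def build_parser() -> argparse.ArgumentParser:
    p = argparse.ArgumentParser(description="flag sanity check (sketch)")
    p.add_argument("--use_weight_only", action="store_true")
    p.add_argument("--weight_only_precision", choices=["int8", "int4"], default=None)
    p.add_argument("--enable_fp8", action="store_true")
    p.add_argument("--fp8_kv_cache", action="store_true")
    p.add_argument("--int8_kv_cache", action="store_true")
    return p


def find_conflicts(argv: List[str]) -> List[str]:
    """Return human-readable errors for known-bad convert_checkpoint.py combinations."""
    args, _unknown = build_parser().parse_known_args(argv)
    errors = []
    int4_weight_only = args.use_weight_only and args.weight_only_precision == "int4"
    for flag in ("enable_fp8", "fp8_kv_cache"):
        if int4_weight_only and getattr(args, flag):
            errors.append(f"--{flag} cannot be combined with --weight_only_precision int4")
    return errors


def precision_mismatch(convert_precision: str, build_precision: str) -> Optional[str]:
    """Flag a trtllm-build precision that differs from the converted checkpoint."""
    if convert_precision != build_precision:
        return (f"checkpoint was converted with {convert_precision} "
                f"but the build requests {build_precision}")
    return None


cmd = ("--model_dir /workspace/miquliz/ --tp_size 4 --dtype float16 --use_weight_only "
       "--weight_only_precision int4 --fp8_kv_cache --enable_fp8").split()
for err in find_conflicts(cmd):
    print("error:", err)
```

parse_known_args is used so that the full original command line can be passed in unchanged; flags the sketch does not know about are simply returned as extras.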

aikitoria commented 8 months ago

Does the engine have to be built on the same GPU that will run it? I thought I could copy it to the server with 4x 4090 for a quick test, but this does not appear to be the case, as it prints the following when I try to run it:

Traceback (most recent call last):
  File "/workspace/TensorRT-LLM/examples/llama/../run.py", line 567, in <module>
    main(args)
  File "/workspace/TensorRT-LLM/examples/llama/../run.py", line 416, in main
    runner = runner_cls.from_dir(**runner_kwargs)
  File "/usr/local/lib/python3.10/dist-packages/tensorrt_llm/runtime/model_runner_cpp.py", line 173, in from_dir
    session = GptSession(config=session_config,
RuntimeError: [TensorRT-LLM][ERROR] Assertion failed: Failed to deserialize cuda engine (/home/jenkins/agent/workspace/LLM/main/L0_MergeRequest/tensorrt_llm/cpp/tensorrt_llm/runtime/tllmRuntime.cpp:68)
1       0x7fcbe320876b tensorrt_llm::common::throwRuntimeError(char const*, int, std::string const&) + 82
2       0x7fcbe321ac9a /usr/local/lib/python3.10/dist-packages/tensorrt_llm/libs/libtensorrt_llm.so(+0xc6c9a) [0x7fcbe321ac9a]
3       0x7fcbe33032e0 tensorrt_llm::runtime::GptSession::GptSession(tensorrt_llm::runtime::GptSession::Config const&, tensorrt_llm::runtime::GptModelConfig const&, tensorrt_llm::runtime::WorldConfig const&, void const*, unsigned long, std::shared_ptr<nvinfer1::ILogger>) + 720
4       0x7fcc27498ffb /usr/local/lib/python3.10/dist-packages/tensorrt_llm/bindings.cpython-310-x86_64-linux-gnu.so(+0x71ffb) [0x7fcc27498ffb]
5       0x7fcc2746cd59 /usr/local/lib/python3.10/dist-packages/tensorrt_llm/bindings.cpython-310-x86_64-linux-gnu.so(+0x45d59) [0x7fcc2746cd59]
6       0x7fcc274557f0 /usr/local/lib/python3.10/dist-packages/tensorrt_llm/bindings.cpython-310-x86_64-linux-gnu.so(+0x2e7f0) [0x7fcc274557f0]
7       0x562bc623010e python3(+0x15a10e) [0x562bc623010e]
8       0x562bc6226a7b _PyObject_MakeTpCall + 603
9       0x562bc623eacb python3(+0x168acb) [0x562bc623eacb]
10      0x562bc623f635 _PyObject_Call + 277
11      0x562bc623b087 python3(+0x165087) [0x562bc623b087]
12      0x562bc6226e2b python3(+0x150e2b) [0x562bc6226e2b]
13      0x7fcc27454ea9 /usr/local/lib/python3.10/dist-packages/tensorrt_llm/bindings.cpython-310-x86_64-linux-gnu.so(+0x2dea9) [0x7fcc27454ea9]
14      0x562bc6226a7b _PyObject_MakeTpCall + 603
15      0x562bc6220150 _PyEval_EvalFrameDefault + 30112
16      0x562bc623e7f1 python3(+0x1687f1) [0x562bc623e7f1]
17      0x562bc623f492 PyObject_Call + 290
18      0x562bc621b5d7 _PyEval_EvalFrameDefault + 10791
19      0x562bc62309fc _PyFunction_Vectorcall + 124
20      0x562bc621926d _PyEval_EvalFrameDefault + 1725
21      0x562bc62159c6 python3(+0x13f9c6) [0x562bc62159c6]
22      0x562bc630b256 PyEval_EvalCode + 134
23      0x562bc6336108 python3(+0x260108) [0x562bc6336108]
24      0x562bc632f9cb python3(+0x2599cb) [0x562bc632f9cb]
25      0x562bc6335e55 python3(+0x25fe55) [0x562bc6335e55]
26      0x562bc6335338 _PyRun_SimpleFileObject + 424
27      0x562bc6334f83 _PyRun_AnyFileObject + 67
28      0x562bc6327a5e Py_RunMain + 702
29      0x562bc62fe02d Py_BytesMain + 45
30      0x7fcd49f2bd90 /lib/x86_64-linux-gnu/libc.so.6(+0x29d90) [0x7fcd49f2bd90]
31      0x7fcd49f2be40 __libc_start_main + 128
32      0x562bc62fdf25 _start + 37
aikitoria commented 8 months ago

Hmm. Very strange result:

The engine built with

python convert_checkpoint.py --model_dir /workspace/miquliz/ --output_dir /workspace/miquliz-int4/ --tp_size 4 --dtype float16 --use_weight_only --weight_only_precision int4

trtllm-build --checkpoint_dir /workspace/miquliz-int4/ --output_dir /workspace/miquliz-engine-int4/ --max_batch_size 1 --max_output_len 256 --weight_only_precision int4 --gemm_plugin float16 --paged_kv_cache enable --use_custom_all_reduce disable --multi_block_mode enable

mpirun --allow-run-as-root -n 4 python3 ../run.py --max_output_len 256 --tokenizer_dir /workspace/miquliz/ --engine_dir /workspace/miquliz-engine-int4/

(last 2 steps done on the server running it)

Tracin commented 8 months ago

Please try to build with trtllm-build --checkpoint_dir /workspace/miquliz-int4/ --output_dir /workspace/miquliz-engine-int4/ --max_batch_size 1 --max_output_len 256 --gpt_attention_plugin float16

aikitoria commented 8 months ago

I grabbed a different H100 server in case the sudden CUDA errors are caused by the host somehow. Setting it back up now.

aikitoria commented 8 months ago

@Tracin, that immediately crashes with [TensorRT-LLM][ERROR] CUDA runtime error in error: peer access is not supported between these two devices. It seems we need at least --use_custom_all_reduce disable for it to work at all.
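The error message suggests that the custom all-reduce path relies on direct GPU peer-to-peer access, which the GPUs on that machine did not offer. A quick check before building could be sketched as follows. The decision helper is hypothetical; the probe assumes a CUDA-enabled PyTorch install and uses torch.cuda.can_device_access_peer.

```python
# Sketch: decide whether to pass --use_custom_all_reduce disable to trtllm-build,
# based on whether every pair of GPUs can access each other's memory directly.
from typing import List


def needs_custom_all_reduce_disabled(peer_ok: List[List[bool]]) -> bool:
    """True if any pair of distinct GPUs lacks peer access."""
    n = len(peer_ok)
    return any(not peer_ok[i][j] for i in range(n) for j in range(n) if i != j)


def probe_peer_access(num_gpus: int) -> List[List[bool]]:
    """Build the peer-access matrix on the current machine (needs CUDA PyTorch)."""
    import torch  # imported lazily so the helper above works without a GPU

    return [[i == j or torch.cuda.can_device_access_peer(i, j) for j in range(num_gpus)]
            for i in range(num_gpus)]
```

On the target machine, needs_custom_all_reduce_disabled(probe_peer_access(4)) returning True would indicate that the flag is needed.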

aikitoria commented 8 months ago

Rebuilt with that added. It now generates output on the 4x H100:

Input [Text 0]: "<s> Born in north-east France, Soyer trained as a"
Output [Text 0 Beam 0]: "chef in his home country before moving to London in 1998.

He worked at the two-Michelin-starred restaurant at Claridge’s hotel in Mayfair, before moving to the three-Michelin-starred Waterside Inn in Bray, Berkshire, in 2000.

Soyer then moved to the two-Michelin-starred Restaurant Gordon Ramsay in Chelsea, London, in 2002, where he worked as sous chef for four years.

In 2006, Soyer was appointed head chef at the two-Michelin-starred Le Gavroche in Mayfair, London, where he worked alongside Ramsay’s protégé, Michel Roux Jr.

Soyer left Le Gavroche in 2008 to take up the position of executive chef at the five-star, 107-bedroom Lanesborough hotel in Knightsbridge, London.

In 2010, Soyer was appointed executive chef at the five-star, 103-bedroom Berkeley hotel in Knightsbridge, London"

However, it does not work on the 4x 4090:

Input [Text 0]: "<s> Born in north-east France, Soyer trained as a"
Output [Text 0 Beam 0]: "_______________________________________________ (...)
aikitoria commented 8 months ago

log_h100.txt log_4090.txt

aikitoria commented 8 months ago

To make sure the library works at all on the 4090, I ran the same commands for a much smaller model. I quantized it on 4x H100 as before, copied the quantized model over, and built the engine locally on each server.

It also works on 4x H100:

Input [Text 0]: "<s> Born in north-east France, Soyer trained as a"
Output [Text 0 Beam 0]: "chef in Paris, and then in London, where he became a naturalised British citizen. He was a pioneer of the "restaurant" concept, a place where people could go to eat a set meal, rather than a tavern or a private club.

Soyer's first restaurant, Soho's "Royal Etruria", was a success, and he went on to open a series of other establishments, including the "Café-Restaurant" in 1839, and the "Café-Restaurant-Concert" in 1840.

Soyer's fame spread, and he was invited to the United States, where he opened a series of restaurants, including the "Café-Restaurant-Concert" in New York.

Soyer was a prolific writer, and his cookbooks, such as "Shilling Cookery for the People" (1840) and "The Gastronomic Regenerator" (1841), were widely read.

Soyer's innovations in the kitchen, such as the use of the "Soy"

But it doesn't work on the 4x 4090 either:

Input [Text 0]: "<s> Born in north-east France, Soyer trained as a"
Output [Text 0 Beam 0]: ""

Are int4 quants broken on 4090?

aikitoria commented 8 months ago

I tried an int8 quant of that model too, but it still produces garbage on the 4090.

Input [Text 0]: "<s> Born in north-east France, Soyer trained as a"
Output [Text 0 Beam 0]: "chef in Paris before moving to London in 1828. He opened his first restaurant, Soyer’s Universal Cookery Establishment, in 1839, and went on to open several more, including the famous Reform Club in London.

Soyer was a pioneer in the field of culinary education, and in 1851 he opened the first cooking school in London, the Soyer’s Culinary School. He also wrote several cookbooks, including “The Modern Housewife, or, Menual of Economical Cookery” (1849) and “A Shilling Cookery for the People” (1855).

Soyer’s culinary innovations included the use of pressure cookers, which allowed for faster and more efficient cooking, and the introduction of new ingredients and techniques to the British palate. He is credited with popularizing dishes such as bouillon, consommé, and fricandeau, and is often referred to as the “father of modern cookery.”

In addition to his culinary achievements, Soyer was also a philanthropist, donating large sums of money to"
Input [Text 0]: "<s> Born in north-east France, Soyer trained as a"
Output [Text 0 Beam 0]: "/******/ /******/ /******/ /******/ /******/ /******/ /******/ /******/ /******/ /******/ /******/ /******/ /******/ /******/ /******/ /******/ /******/ /******/ /******/ /******/ /******/ /******/ /******/ /******/ /******/ /******/ /******/ /******/ /******/ /******/ /******/ /******/ /******/ /******/ /******/ /******/ /******/ /******/ /******/ /******/ /******/ /******/ /******/ /******/ /******/ /******/ /******/ /******/ /******/ /******/ /******/ /******/ /******/ /******/ citiz /******/ /******/ /******/plaat<|im_start|><|im_start|>plaat /******/ /******/<|im_start|>ACHE driedimil_mvisitvisitimil_mvisitvisit_mintervalvisitvisit citiz /******/ /******/VBBYvisit XX /******/intervalhreIM_riACHEACHEvisitOps_rivisitimilimilBYBYBYBYvisitvisitvisitvisitvisitvisitvisitvisitclock Fer_rivisitvisitienceimilvisitclockimilimilclockimilimilimil driedhreclockocon dried driedimilimilimilimilimilimilimil_m_m_m_m_m_m_m_m_m_m_m_m_mACHEACHEACHEACHEvisitvisit_mvisit_m_m_m_mvisit_mvisi_visit_mvisit_mvisit_m_m_m_m_m_mimil_m_m_m_m_m__visitvisi_visit_m_mvisitimilimil_mvisitvisitvisitvisitvisitachen_m_mvisi_visitvisi_visit_mvisi_visi_visi__m_m_m___mvisitachenvisitachenvisitachenvisitvisitvisitvisitvisitvisitvisi_visitvisitvisitvisit__visitvisi_"
aikitoria commented 8 months ago

This time, I both quantized the model locally on the 4x 4090 server and built the engine there as before. Now it works??

Input [Text 0]: "<s> Born in north-east France, Soyer trained as a"
Output [Text 0 Beam 0]: "chef in Paris before moving to London in 1828. He opened his first restaurant, Soyer_s Universal Cookery Establishment, in 1839, and went on to open several more, including the famous Reform Club in London.

Soyer was a pioneer in the field of culinary education, and in 1851 he opened the first cooking school in London, the Soyer_s Culinary School. He also wrote several cookbooks, including _The Modern Housewife, or, Menual of Economical Cookery_ (1849) and _A Shilling Cookery for the People_ (1855).

Soyer_s culinary innovations included the use of pressure cookers, which allowed for faster and more efficient cooking, and the introduction of new ingredients and techniques to the British palate. He is credited with popularizing dishes such as bouillon, consomm_, and fricandeau, and is often referred to as the _father of modern cookery._

In addition to his culinary achievements, Soyer was also a philanthropist, donating large sums of money to"

Are the rankN.safetensors files from the conversion step not portable between devices?

aikitoria commented 8 months ago

But if that's the case, how can I convert the checkpoint for miquliz, given fp16 does not fit in memory on that server?

aikitoria commented 8 months ago

Could it be something like... the other server is confused about which rank should be which GPU, so it computes total garbage? I wonder because of these messages from the mpirun command:

[LOG_CAT_ML] You must specify a valid HCA device by setting:
-x HCOLL_MAIN_IB=<dev_name:port> or -x UCX_NET_DEVICES=<dev_name:port>.
If no device was specified for HCOLL (or the calling library), automatic device detection will be run.
In case of unfounded HCA device please contact your system administrator.
[LOG_CAT_ML] You must specify a valid HCA device by setting:
-x HCOLL_MAIN_IB=<dev_name:port> or -x UCX_NET_DEVICES=<dev_name:port>.
If no device was specified for HCOLL (or the calling library), automatic device detection will be run.
In case of unfounded HCA device please contact your system administrator.
[f0ba662bb3aa:11745] Error: coll_hcoll_module.c:310 - mca_coll_hcoll_comm_query() Hcol library init failed
[f0ba662bb3aa:11744] Error: coll_hcoll_module.c:310 - mca_coll_hcoll_comm_query() Hcol library init failed
[LOG_CAT_ML] You must specify a valid HCA device by setting:
-x HCOLL_MAIN_IB=<dev_name:port> or -x UCX_NET_DEVICES=<dev_name:port>.
If no device was specified for HCOLL (or the calling library), automatic device detection will be run.
In case of unfounded HCA device please contact your system administrator.
[f0ba662bb3aa:11743] Error: coll_hcoll_module.c:310 - mca_coll_hcoll_comm_query() Hcol library init failed
[LOG_CAT_ML] You must specify a valid HCA device by setting:
-x HCOLL_MAIN_IB=<dev_name:port> or -x UCX_NET_DEVICES=<dev_name:port>.
If no device was specified for HCOLL (or the calling library), automatic device detection will be run.
In case of unfounded HCA device please contact your system administrator.
[f0ba662bb3aa:11742] Error: coll_hcoll_module.c:310 - mca_coll_hcoll_comm_query() Hcol library init failed
Tracin commented 8 months ago

@aikitoria It looks like you can now produce the right outputs. Could you please summarize the situation for me? Thanks.

aikitoria commented 8 months ago

At this point, my current understanding of the issue is:

This works: Quantize checkpoint locally -> build engine locally -> run

This does not work: Quantize checkpoint on large server -> copy quantized to small server -> build engine locally -> run

In all previous tests, I used the second approach, because the small server does not have enough memory to load miquliz in fp16 for quantizing.

Tracin commented 8 months ago

> At this point, my current understanding of the issue is:
>
> This works: Quantize checkpoint locally -> build engine locally -> run
>
> This does not work: Quantize checkpoint on large server -> copy quantized to small server -> build engine locally -> run
>
> In all previous tests, I did the second way, because the small server does not have enough memory to load miquliz in fp16 for quantizing

What are the exact GPUs in the server and the local machine?

aikitoria commented 8 months ago

The large server:

root@904b72ce3f5b:/workspace/dolphin-int4# nvidia-smi
Fri Mar  1 08:50:27 2024
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.154.05             Driver Version: 535.154.05   CUDA Version: 12.3     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA H100 PCIe               On  | 00000000:00:09.0 Off |                    0 |
| N/A   29C    P0              48W / 310W |      0MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   1  NVIDIA H100 PCIe               On  | 00000000:00:0A.0 Off |                    0 |
| N/A   34C    P0              48W / 310W |      0MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   2  NVIDIA H100 PCIe               On  | 00000000:00:0B.0 Off |                    0 |
| N/A   28C    P0              46W / 310W |      0MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   3  NVIDIA H100 PCIe               On  | 00000000:00:0C.0 Off |                    0 |
| N/A   28C    P0              46W / 310W |      0MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+

The small server:

root@C.9883904:/workspace/dolphin-int4$ nvidia-smi
Fri Mar  1 08:50:48 2024
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.113.01             Driver Version: 535.113.01   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA GeForce RTX 4090        On  | 00000000:A1:00.0 Off |                  Off |
| 39%   32C    P8              59W / 450W |      2MiB / 24564MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
|   1  NVIDIA GeForce RTX 4090        On  | 00000000:C1:00.0 Off |                  Off |
| 30%   28C    P8              61W / 450W |      2MiB / 24564MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
|   2  NVIDIA GeForce RTX 4090        On  | 00000000:C2:00.0 Off |                  Off |
| 30%   25C    P8              58W / 450W |      2MiB / 24564MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
|   3  NVIDIA GeForce RTX 4090        On  | 00000000:E1:00.0 Off |                  Off |
| 30%   25C    P8              57W / 450W |      2MiB / 24564MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
aikitoria commented 8 months ago

Hmm, I just discovered the option --load_by_shard; perhaps it can quantize it after all. Trying this.

Tracin commented 8 months ago

That makes sense: different SM versions require different weight interleaving strategies, so you have to run the conversion under the same SM version as the GPUs you deploy on.

aikitoria commented 8 months ago

It would be useful if it warned the user when this happens, rather than computing nonsense.
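A minimal sketch of the kind of guard suggested here, assuming the converter recorded the GPU's compute capability in the checkpoint config at conversion time. The config key `converted_on_sm` and the helpers `record_sm` / `check_sm` are hypothetical and not part of TensorRT-LLM; the `(major, minor)` tuple is the shape returned by `torch.cuda.get_device_capability()` (for example `(9, 0)` on H100 and `(8, 9)` on RTX 4090).

```python
from typing import Optional, Tuple

Capability = Tuple[int, int]  # (major, minor), e.g. from torch.cuda.get_device_capability()


def sm_version(capability: Capability) -> int:
    """Turn a compute capability such as (8, 9) into an SM number such as 89."""
    major, minor = capability
    return major * 10 + minor


def record_sm(config: dict, capability: Capability) -> dict:
    """Return a copy of the checkpoint config tagged with the converting GPU's SM.

    "converted_on_sm" is a hypothetical key used only for this illustration.
    """
    return {**config, "converted_on_sm": sm_version(capability)}


def check_sm(config: dict, capability: Capability) -> None:
    """Raise instead of silently producing garbage when SM versions differ."""
    built: Optional[int] = config.get("converted_on_sm")
    here = sm_version(capability)
    if built is not None and built != here:
        raise RuntimeError(
            f"int4 weight-only checkpoint was converted on SM{built}, but this GPU is "
            f"SM{here}; the weight interleaving differs, so re-run the conversion "
            f"on an SM{here} GPU."
        )
```

With this in place, a checkpoint tagged by `record_sm(cfg, (9, 0))` on the H100 box would fail loudly in `check_sm(cfg, (8, 9))` on the 4090 box instead of generating broken output.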

aikitoria commented 8 months ago

Hmm. It looks like --load_by_shard does not work at all.

(venv) root@085728a23e72:/workspace/TensorRT-LLM/examples/llama# python convert_checkpoint.py --model_dir /workspace/miquliz/ --output_dir /workspace/miquliz-int4/ --tp_size 4 --dtype float16 --use_weight_only --weight_only_precision int4 --load_by_shard
[1709286155.573235] [085728a23e72:11191:0]            sock.c:502  UCX  WARN  unable to read somaxconn value from /proc/sys/net/core/somaxconn file
[TensorRT-LLM] TensorRT-LLM version: 0.9.0.dev2024022700
0.9.0.dev2024022700
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 25/25 [00:15<00:00,  1.59it/s]
[03/01/2024-09:42:56] Some parameters are on the meta device device because they were offloaded to the cpu.
Traceback (most recent call last):
  File "/workspace/TensorRT-LLM/examples/llama/convert_checkpoint.py", line 1532, in <module>
    main()
  File "/workspace/TensorRT-LLM/examples/llama/convert_checkpoint.py", line 1508, in main
    covert_and_save(rank, convert_args)
  File "/workspace/TensorRT-LLM/examples/llama/convert_checkpoint.py", line 1474, in covert_and_save
    weights = load_from_hf_checkpoint(
  File "/workspace/venv/lib/python3.10/site-packages/tensorrt_llm/models/llama/weight.py", line 398, in load_from_hf_checkpoint
    split_v.transpose(), plugin_weight_only_quant_type)
TypeError: transpose() received an invalid combination of arguments - got (), but expected one of:
 * (int dim0, int dim1)
 * (name dim0, name dim1)
aikitoria commented 8 months ago

However, --load_model_on_cpu does! Quantization now succeeds, and the engine works on 4x 4090 :D

Any chance of being able to combine it with an int8 or fp8 KV cache in the future?

aikitoria commented 8 months ago

@Tracin I'm curious why you say not to add enable_fp8 and fp8_kv_cache together with int4. I've done further tests with the dolphin model, and combining int4 with either of these appears to work fine:

int4 + int8 KV cache:

python convert_checkpoint.py --model_dir /workspace/dolphin/ --output_dir /workspace/dolphin-int4-int8/ --tp_size 4 --dtype float16 --use_weight_only --weight_only_precision int4 --int8_kv_cache
trtllm-build --checkpoint_dir /workspace/dolphin-int4-int8/ --output_dir /workspace/dolphin-engine-int4-int8/ --max_beam_width 1 --max_batch_size 1 --max_output_len 256 --max_input_len 32512 --gpt_attention_plugin float16 --use_custom_all_reduce disable --multi_block_mode enable
mpirun --allow-run-as-root -n 4 python3 ../run.py --max_output_len 256 --tokenizer_dir /workspace/dolphin/ --engine_dir /workspace/dolphin-engine-int4-int8/ --run_profiling

Input [Text 0]: "<s> Born in north-east France, Soyer trained as a"
Output [Text 0 Beam 0 Len 256]: "chef in Paris. He moved to London in 1837, where he opened a series of restaurants, the most famous of which was Soyer's at the Reform Club.

Soyer's culinary skills were highly regarded, and he was the first chef to be awarded a gold medal at the 1851 Exposition Universelle in Paris.

He was also the first chef to be knighted, in 1859, for his culinary achievements.

Soyer's cookbooks, including "The Modern Housekeeper and Cook, on a Large Scale" (1845) and "The Gastronomic Regenerator" (1847), were highly influential, and his recipes and techniques were widely adopted.

He is credited with inventing the Soyer's Pudding, a type of steamed pudding, and the Soyer's Chicken, a dish of chicken in a rich, creamy sauce.

Soyer's legacy continues to be celebrated in the culinary world, and his contributions to the field of gastronomy are still studied and admired by chefs and"
batch_size: 1, avg latency of 10 iterations: : 0.8093829154968262 sec

int4 + fp8 KV cache:

python convert_checkpoint.py --model_dir /workspace/dolphin/ --output_dir /workspace/dolphin-int4-fp8/ --tp_size 4 --dtype float16 --use_weight_only --weight_only_precision int4 --fp8_kv_cache --enable_fp8
trtllm-build --checkpoint_dir /workspace/dolphin-int4-fp8/ --output_dir /workspace/dolphin-engine-int4-fp8/ --max_beam_width 1 --max_batch_size 1 --max_output_len 256 --max_input_len 32512 --gpt_attention_plugin float16 --use_custom_all_reduce disable --multi_block_mode enable
mpirun --allow-run-as-root -n 4 python3 ../run.py --max_output_len 256 --tokenizer_dir /workspace/dolphin/ --engine_dir /workspace/dolphin-engine-int4-fp8/ --run_profiling

Input [Text 0]: "<s> Born in north-east France, Soyer trained as a"
Output [Text 0 Beam 0 Len 256]: "chef in Paris, and then in London, where he became a naturalised British citizen. He was a pioneer of the "restaurant" concept, a place where people could go to eat a set meal, rather than a tavern or a private club.

Soyer's first restaurant, Soho's "Royal Etruria", was a success, and he went on to open a series of other establishments, including the "Café-Restaurant" in 1839, and the "Café-Restaurant-Concert" in 1840.

Soyer's fame spread, and he was invited to the United States, where he opened a series of restaurants, including the "Café-Restaurant-Concert" in New York.

Soyer was a prolific writer, and his cookbooks, such as "Shilling Cookery for the People" (1840) and "The Gastronomic Regenerator" (1841), were widely read.

Soyer's innovations in the kitchen, such as the use of the "Soy"
batch_size: 1, avg latency of 10 iterations: : 0.7649296760559082 sec
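The two experiments above differ only in the KV-cache flags passed to convert_checkpoint.py; the build and run steps are identical apart from paths. A small Python sketch that generates both pipelines from one template makes that explicit. The `pipeline` helper and its default paths are illustrative; the flags themselves are copied from the commands in this comment.

```python
import shlex
from typing import List

# convert_checkpoint.py flags that differ between the two experiments;
# everything else in the pipeline is identical.
KV_CACHE_FLAGS = {
    "int8": ["--int8_kv_cache"],
    "fp8": ["--fp8_kv_cache", "--enable_fp8"],
}


def pipeline(kv: str, model_dir: str = "/workspace/dolphin/",
             prefix: str = "/workspace/dolphin", tp_size: int = 4) -> List[List[str]]:
    """Return the convert, build and run commands for an int4 + <kv> KV cache engine."""
    ckpt = f"{prefix}-int4-{kv}/"
    engine = f"{prefix}-engine-int4-{kv}/"
    convert = ["python", "convert_checkpoint.py", "--model_dir", model_dir,
               "--output_dir", ckpt, "--tp_size", str(tp_size), "--dtype", "float16",
               "--use_weight_only", "--weight_only_precision", "int4",
               *KV_CACHE_FLAGS[kv]]
    build = ["trtllm-build", "--checkpoint_dir", ckpt, "--output_dir", engine,
             "--max_beam_width", "1", "--max_batch_size", "1",
             "--max_output_len", "256", "--max_input_len", "32512",
             "--gpt_attention_plugin", "float16",
             "--use_custom_all_reduce", "disable", "--multi_block_mode", "enable"]
    run = ["mpirun", "--allow-run-as-root", "-n", str(tp_size), "python3", "../run.py",
           "--max_output_len", "256", "--tokenizer_dir", model_dir,
           "--engine_dir", engine, "--run_profiling"]
    return [convert, build, run]


if __name__ == "__main__":
    # Print both variants as shell-ready command lines.
    for kv in KV_CACHE_FLAGS:
        for cmd in pipeline(kv):
            print(shlex.join(cmd))
```

Keeping the variant-specific flags in one table makes it harder to accidentally change a build or run option between experiments that are meant to differ only in the KV cache type.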

The only problem is that I cannot create this quant for miquliz: selecting an fp8 or int8 KV cache makes the script load the entire model into memory at once to "calibrate" it (and my 4x 4090 server obviously doesn't have 240GB of VRAM), --load_by_shard does not work at all, and doing the conversion on a larger server doesn't work either because of the SM version issue you mentioned. Is there any solution? Can I run the "calibrate" step on H100 and then somehow save the intermediate values to continue on the 4090s?
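The 240GB figure follows from simple arithmetic. The back-of-the-envelope sketch below assumes the nominal 120e9 parameters implied by the model name and counts weights only (no activations, KV cache, or quantization scales); the per-GPU memory is the 24564 MiB that nvidia-smi reports for the RTX 4090s above.

```python
# Back-of-the-envelope weight memory vs. available VRAM.
# Assumptions: nominal 120e9 parameters, weights only, decimal GB.

N_PARAMS = 120e9   # nominal parameter count of miquliz-120b (assumption)
GPU_MIB = 24564    # per-GPU memory reported by nvidia-smi on the RTX 4090 server
N_GPUS = 4


def weight_bytes(n_params: float, bits_per_weight: int) -> float:
    """Bytes needed to store n_params weights at the given bit width."""
    return n_params * bits_per_weight / 8


TOTAL_VRAM = N_GPUS * GPU_MIB * 2**20  # MiB -> bytes

for label, bits in [("fp16", 16), ("int4 weight-only", 4)]:
    need = weight_bytes(N_PARAMS, bits)
    print(f"{label:>16}: {need / 1e9:5.0f} GB of weights, "
          f"{TOTAL_VRAM / 1e9:.0f} GB total VRAM, fits: {need < TOTAL_VRAM}")
```

Under these assumptions the fp16 weights need about 240 GB, well above the roughly 103 GB across the four cards, while the int4 weights alone (about 60 GB) fit. That is why a step that materializes the full-precision model at once cannot run on this machine even though the final int4 engine can.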

aikitoria commented 8 months ago

Wait, the fp8 KV cache worked without a calibration step.

aikitoria commented 8 months ago

Closing this since my engine built successfully!

plt12138 commented 8 months ago

@aikitoria Hi, I got the same crash: [TensorRT-LLM][ERROR] CUDA runtime error in error: peer access is not supported between these two devices. As you said, adding --use_custom_all_reduce disable makes it work correctly. Thanks! But what does this option do? The --help output and the README don't give any useful information. Do you know exactly what it means now?

aikitoria commented 8 months ago

From what I understand, it's part of implementing tensor parallelism using NCCL's all-reduce collective: https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/collectives.html#allreduce
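To make the role of all-reduce concrete: when a linear layer's weight is split along its input dimension across tensor-parallel ranks, each rank computes only a partial output, and an all-reduce (sum) gives every rank the full result. The plain-Python sketch below shows only these semantics; NCCL and TensorRT-LLM perform the same reduction across GPUs with optimized communication, and the helper names here are illustrative.

```python
from typing import List

Vector = List[int]
Matrix = List[List[int]]


def matvec(w: Matrix, x: Vector) -> Vector:
    """Dense matrix-vector product y = W x."""
    return [sum(wij * xj for wij, xj in zip(row, x)) for row in w]


def all_reduce_sum(partials: List[Vector]) -> List[Vector]:
    """AllReduce(sum) semantics: every rank receives the element-wise sum of all inputs."""
    total = [sum(vals) for vals in zip(*partials)]
    return [list(total) for _ in partials]


def tp_matvec(w: Matrix, x: Vector, tp_size: int) -> List[Vector]:
    """Shard W and x along the input dimension, compute partial products, then all-reduce."""
    n = len(x)
    assert n % tp_size == 0, "input dimension must divide evenly across ranks"
    shard = n // tp_size
    partials = []
    for rank in range(tp_size):
        cols = slice(rank * shard, (rank + 1) * shard)
        w_shard = [row[cols] for row in w]           # this rank's slice of the weight
        partials.append(matvec(w_shard, x[cols]))    # partial output, incomplete on its own
    return all_reduce_sum(partials)                  # one copy of the full output per rank


if __name__ == "__main__":
    w = [[1, 2, 3, 4, 5, 6, 7, 8], [8, 7, 6, 5, 4, 3, 2, 1]]
    x = [1, 0, 2, 0, 3, 0, 4, 0]
    print(tp_matvec(w, x, tp_size=4))
```

Because this reduction runs for every sharded layer on every generated token, the speed of the all-reduce implementation matters, which is why a faster custom variant exists at all.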

TensorRT-LLM has a custom all-reduce implementation with better performance on systems supporting P2P (peer-to-peer) access between GPUs, such as 8x H100 clusters with NVLink. But on these consumer GPUs we have to disable it, or it will crash.
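A hedged sketch of how one might choose the flag value up front on a given machine: check whether every GPU pair reports peer (P2P) access, and disable the custom all-reduce otherwise. This heuristic is an assumption drawn from the explanation above, not TensorRT-LLM's own logic; `torch.cuda.can_device_access_peer` is a real PyTorch call, while `choose_custom_all_reduce` is a hypothetical helper.

```python
from typing import Callable


def choose_custom_all_reduce(n_gpus: int, can_access_peer: Callable[[int, int], bool]) -> str:
    """Hypothetical heuristic: return "enable" only if every GPU pair supports P2P access."""
    for i in range(n_gpus):
        for j in range(n_gpus):
            if i != j and not can_access_peer(i, j):
                return "disable"
    return "enable"


if __name__ == "__main__":
    try:
        import torch
    except ImportError:
        torch = None
    if torch is not None and torch.cuda.is_available():
        n = torch.cuda.device_count()
        # Suggested value for trtllm-build's --use_custom_all_reduce on this machine.
        print("--use_custom_all_reduce", choose_custom_all_reduce(n, torch.cuda.can_device_access_peer))
```

Passing the peer-access check in as a function keeps the decision logic testable on machines without GPUs.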

plt12138 commented 8 months ago

Thanks for your help. By the way, in your experience, what other options should I pay attention to when using consumer GPUs?