NVIDIA / TensorRT-LLM

TensorRT-LLM provides users with an easy-to-use Python API to define Large Language Models (LLMs) and build TensorRT engines that contain state-of-the-art optimizations to perform inference efficiently on NVIDIA GPUs. TensorRT-LLM also contains components to create Python and C++ runtimes that execute those TensorRT engines.
https://nvidia.github.io/TensorRT-LLM
Apache License 2.0

There is a difference between the decoding result of Medusa and the source model #1795

Open skyCreateXian opened 1 week ago

skyCreateXian commented 1 week ago

System Info

Hardware: A10
CUDA: 12.4
Driver: 535.54.03
TensorRT-LLM: 0.10.0
Medusa base model: vicuna-7b-v1.3
Medusa head model: FasterDecoding/medusa-vicuna-7b-v1.3

Who can help?

@kaiyux @by

```
# convert medusa engine
python convert_checkpoint.py --model_dir /data/Medusa/vicuna-7b-v1.3 \
    --medusa_model_dir /data/Medusa/lm_head \
    --output_dir ./tllm_checkpoint_1gpu_medusa \
    --dtype float16 \
    --fixed_num_medusa_heads 4

trtllm-build --checkpoint_dir ./tllm_checkpoint_1gpu_medusa \
    --output_dir ./vicuna-medusa \
    --gemm_plugin float16 \
    --speculative_decoding_mode medusa \
    --max_batch_size 8
```

```
# convert source model
python ../llama/convert_checkpoint.py --model_dir /data/Medusa/vicuna-7b-v1.3 \
    --output_dir ./tllm_checkpoint_1gpu_medusa \
    --dtype float16

trtllm-build --checkpoint_dir ./tllm_checkpoint_1gpu_medusa \
    --output_dir ./vicuna-source \
    --gemm_plugin float16 \
    --max_batch_size 8
```

With prompt="hello", the outputs diverge at index=31:

index | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- medusa | 29992 | 1552 | 29883 | 12896 | 3626 | 29886 | 29889 | 510 | 13 | 13 | 1576 | 14062 | 323 | 6472 | 338 | 263 | 13436 | 5745 | 5001 | 393 | 4511 | 29879 | 9850 | 414 | 411 | 278 | 3186 | 30010 | 1 | 450 | 29871 | 29896 | 29929 | 29955 | 29945 | 6507 | 716 | 4823 | 525 | 6816 | 669 | 887 | 323 | 12966 | 838 | 650 | 29915 | 13 | 1576 | 29871 | 29896 | 29929 | 29955 | 29945 | 505 | 5492 | 263 | 716 | 4823 | 2000 | 525 | 6816 | 669 | 887 | 323 | 12966 | 838 | 650 | 4286 | 13 | 1576 | 5702 | 338 | 278 | 9281 | 304 | 367 | 26239 | 515 | 278 source | 29992 | 1552 | 29883 | 12896 | 3626 | 29886 | 29889 | 510 | 13 | 13 | 1576 | 14062 | 323 | 6472 | 338 | 263 | 13436 | 5745 | 5001 | 393 | 4511 | 29879 | 9850 | 414 | 411 | 278 | 3186 | 30010 | 1 | 450 | 937 | 931 | 306 | 4446 | 278 | 14064 | 376 | 1576 | 10213 | 4634 | 310 | 10705 | 24532 | 29891 | 1699 | 306 | 471 | 15469 | 491 | 278 | 325 | 3640 | 322 | 6382 | 262 | 1230 | 3186 | 393 | 278 | 15572 | 391 | 29892 | 10705 | 24532 | 29891 | 29892 | 14117 | 1573 | 29889 | 450 | 14064 | 29892 | 10624 | 491 | 322 | 380 | 23693 | 4111 | 624 | 5495

```python
# medusa and source engine decode sample
import ast
import tensorrt_llm
from tensorrt_llm.runtime import ModelRunnerCpp
import torch
import numpy
from transformers import AutoTokenizer

engine_dir = "source"
tokenizer = AutoTokenizer.from_pretrained("/data/Medusa/vicuna-7b-v1.3")
medusa_choices = None
temperature = 1.0
top_k = 1
top_p = 1.0
num_beams = 1
max_new_tokens = 80

prompt = "hello"
pad_id = tokenizer.pad_token_id
end_id = tokenizer.eos_token_id

runtime_rank = tensorrt_llm.mpi_rank()

runner_kwargs = dict(engine_dir=engine_dir,
                     lora_dir=None,
                     rank=runtime_rank,
                     debug_mode=True,
                     lora_ckpt_source="hf",
                     gpu_weights_percent=1)

if medusa_choices is not None:
    medusa_choices = ast.literal_eval(medusa_choices)
    assert temperature == 1.0, "Medusa should use temperature == 1.0"
    assert num_beams == 1, "Medusa should use num_beams == 1"
    runner_kwargs.update(medusa_choices=medusa_choices)

input_tensor = tokenizer.encode(prompt)
batch_input_ids = [torch.tensor(input_tensor, dtype=torch.int32)]
input_lengths = [x.size(0) for x in batch_input_ids]
if True:
    runner_kwargs.update(max_batch_size=1,
                         max_input_len=max(input_lengths),
                         max_output_len=max_new_tokens,
                         max_beam_width=num_beams,
                         max_attention_window_size=None,
                         sink_token_length=None)
runner = ModelRunnerCpp.from_dir(**runner_kwargs)

with torch.no_grad():
    outputs = runner.generate(batch_input_ids,
                              max_new_tokens=max_new_tokens,
                              max_attention_window_size=None,
                              sink_token_length=None,
                              end_id=end_id,
                              pad_id=pad_id,
                              temperature=temperature,
                              top_k=top_k,
                              top_p=top_p,
                              num_beams=num_beams,
                              length_penalty=1.0,
                              early_stopping=1,
                              repetition_penalty=1.0,
                              presence_penalty=0.0,
                              frequency_penalty=0.0,
                              stop_words_list=None,
                              bad_words_list=None,
                              output_cum_log_probs=False,
                              output_log_probs=False,
                              lora_uids=None,
                              prompt_table=None,
                              prompt_tasks=None,
                              streaming=False,
                              output_sequence_lengths=True,
                              return_dict=True,
                              medusa_choices=medusa_choices)
    torch.cuda.synchronize()

# Keep only the newly generated tokens of the first beam of the first request.
out_ids = outputs["output_ids"]
out_ids = out_ids.tolist()[0][0][input_lengths[0]:]

print(out_ids)
```

Reproduction

  1. Compile source engine and medusa engine
  2. Run source engine and medusa engine separately using the scripts in the documentation
  3. Compare the differences in output tokens
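
Step 3 can be automated with a small helper. The following is a minimal sketch and not part of TensorRT-LLM: `first_diff_index` is a hypothetical name, and it assumes the two `out_ids` lists printed by the decode sample have already been collected, one from each engine. It reports the same 1-based index used in the tables of this issue.

```python
from typing import Optional, Sequence


def first_diff_index(medusa_ids: Sequence[int],
                     source_ids: Sequence[int]) -> Optional[int]:
    """Return the 1-based index of the first differing token, or None if equal."""
    for i, (m, s) in enumerate(zip(medusa_ids, source_ids), start=1):
        if m != s:
            return i
    # One output is a strict prefix of the other: they diverge right after it.
    if len(medusa_ids) != len(source_ids):
        return min(len(medusa_ids), len(source_ids)) + 1
    return None


# First five tokens of a case reported later in this thread.
medusa_head = [29973, 13, 13, 29902, 864]
source_head = [29973, 13, 13, 22550, 29901]
print(first_diff_index(medusa_head, source_head))  # 4
```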

Expected behavior

The output tokens should match those of the source engine; speculative decoding should not introduce accuracy loss.

Actual behavior

The Medusa output differs from the output of the source model.

Additional notes

While debugging with gptManagerBenchmark, I found that the diff may originate in acceptDraftTokensByIdsWithPaths in Tensorrtllm/cpp/Tensorrtllm/kernels/decodedKernels.cu. When the drafts from all Medusa heads are accepted, the next token may be predicted incorrectly.
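
The expectation that greedy speculative decoding is lossless follows from how draft tokens are verified. The sketch below illustrates the idea in plain Python. It is not the `acceptDraftTokensByIdsWithPaths` kernel, it handles a single path rather than a Medusa tree, and `accept_draft_tokens` and its arguments are hypothetical names chosen for illustration.

```python
from typing import List, Sequence, Tuple


def accept_draft_tokens(draft_tokens: Sequence[int],
                        target_tokens: Sequence[int]) -> Tuple[List[int], int]:
    """Greedy verification of one draft path (illustrative assumption).

    draft_tokens[i]  : i-th token proposed by the draft heads.
    target_tokens[i] : the base model's greedy token at that position, given
                       the prompt plus draft_tokens[:i]. It has one more entry
                       than draft_tokens: the last entry is the token that
                       follows when every draft is accepted.
    """
    if len(target_tokens) != len(draft_tokens) + 1:
        raise ValueError("target_tokens must have len(draft_tokens) + 1 entries")

    accepted: List[int] = []
    for i, draft in enumerate(draft_tokens):
        if draft != target_tokens[i]:
            break  # first mismatch: discard this draft and everything after it
        accepted.append(draft)

    # The next token always comes from the base model, conditioned on the
    # accepted prefix, so the final sequence equals plain greedy decoding.
    next_token = target_tokens[len(accepted)]
    return accepted, next_token


print(accept_draft_tokens([5, 6, 7], [5, 6, 9, 4]))  # ([5, 6], 9)
print(accept_draft_tokens([5, 6, 7], [5, 6, 7, 4]))  # ([5, 6, 7], 4)
```

The second call is the case described above: all drafts are accepted, and the following token must still be the base model's prediction for that position.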

skyCreateXian commented 1 week ago

prompt="How is the weather today?"

Diff starts at index=10

index | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- medusa | 13 | 5618 | 338 | 278 | 14826 | 763 | 9826 | 29973 | 13 | 13 | 29896 | 29889 | 1724 | 338 | 278 | 14826 | 763 | 9826 | 29973 | 13 | 29906 | 29889 | 1128 | 338 | 278 | 14826 | 9826 | 29973 | 13 | 29941 | 29889 | 1724 | 338 | 278 | 14826 | 763 | 9826 | 29973 | 13 | 29946 | 29889 | 1128 | 338 | 278 | 14826 | 9826 | 29973 | 13 | 29945 | 29889 | 1724 | 338 | 278 | 14826 | 763 | 9826 | 29973 | 13 | 29953 | 29889 | 1128 | 338 | 278 | 14826 | 9826 | 29973 | 13 | 29955 | 29889 | 1724 | 338 | 278 | 14826 | 763 | 9826 | 29973 | 13 | 29947 | 29947 | 29889 source | 13 | 5618 | 338 | 278 | 14826 | 763 | 9826 | 29973 | 13 | 5618 | 338 | 278 | 14826 | 763 | 9826 | 29973 | 13 | 5618 | 338 | 278 | 14826 | 763 | 9826 | 29973 | 13 | 5618 | 338 | 278 | 14826 | 763 | 9826 | 29973 | 13 | 5618 | 338 | 278 | 14826 | 763 | 9826 | 29973 | 13 | 5618 | 338 | 278 | 14826 | 763 | 9826 | 29973 | 13 | 5618 | 338 | 278 | 14826 | 763 | 9826 | 29973 | 13 | 5618 | 338 | 278 | 14826 | 763 | 9826 | 29973 | 13 | 5618 | 338 | 278 | 14826 | 763 | 9826 | 29973 | 13 | 5618 | 338 | 278 | 14826 | 763 | 9826 | 29973
nv-guomingz commented 1 week ago

@skyCreateXian Thanks for reporting this issue. We'll take a look internally and get back to you ASAP.

aspctu commented 1 week ago

+1, seeing this as well

nv-guomingz commented 3 days ago

Hi @skyCreateXian, would you please try the latest code base to see if the issue still exists?

@dongxuy04 and I can't reproduce your issue on our side.

skyCreateXian commented 3 days ago

@nv-guomingz I verified with the latest code: the misaligned results with medusa top=1 have been resolved. Thank you for your support; I will close the issue.

skyCreateXian commented 2 days ago

@nv-guomingz The issue has not been completely resolved; there are still differences between the Medusa and base model outputs. I randomly selected 200 items from the ShareGPT.V4.3_unfiltered_cleaned_split.json dataset:

  1. On the v0.10.0 release, the full alignment rate is 30%
  2. On the latest branch, the full alignment rate is 75%, so 25% of the outputs are still not aligned

Here is a diff case; the diff starts at index=32.

prompt="Summarize the main ideas of Brendon Burchard's Experts Academy into bullet points as it pertains to a growth marketing agency implementing these strategies and tactics for their clients..."

index | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- medusa | 13 | 13 | 29930 | 450 | 13500 | 310 | 4969 | 263 | 7333 | 14982 | 322 | 10127 | 292 | 6743 | 761 | 408 | 385 | 17924 | 297 | 263 | 2702 | 302 | 4070 | 29889 | 13 | 29930 | 450 | 817 | 304 | 8569 | 373 | 263 | 2702 | 302 | 4070 | 322 | 4953 | 278 | 748 | 29899 | 517 | 17924 | 297 | 393 | 4038 | 29889 | 13 | 29930 | 450 | 13500 | 310 | 4969 | 21114 | 2793 | 322 | 5214 | 263 | 7881 | 2820 | 372 | 29889 | 13 source | 13 | 13 | 29930 | 450 | 13500 | 310 | 4969 | 263 | 7333 | 14982 | 322 | 10127 | 292 | 6743 | 761 | 408 | 385 | 17924 | 297 | 263 | 2702 | 302 | 4070 | 29889 | 13 | 29930 | 450 | 817 | 304 | 8569 | 373 | 16330 | 697 | 29915 | 29879 | 20026 | 322 | 13138 | 995 | 1549 | 2793 | 9999 | 292 | 322 | 5264 | 5745 | 29889 | 13 | 29930 | 450 | 26002 | 310 | 5214 | 263 | 7881 | 322 | 9926 | 3241 | 21702 | 411 | 697 | 29915 match | 1 | 1 | 1 | 1 | 1 | 1 | 1 | 1 | 1 | 1 | 1 | 1 | 1 | 1 | 1 | 1 | 1 | 1 | 1 | 1 | 1 | 1 | 1 | 1 | 1 | 1 | 1 | 1 | 1 | 1 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0
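
The alignment rates quoted in this comment (30% and 75% over 200 sampled prompts) can be computed along the lines of the sketch below. It assumes that "full alignment" means the Medusa and source outputs for a prompt are identical token for token; `full_alignment_rate` is a hypothetical helper, not a TensorRT-LLM function, and the sample data is made up for illustration.

```python
from typing import Sequence, Tuple


def full_alignment_rate(
        pairs: Sequence[Tuple[Sequence[int], Sequence[int]]]) -> float:
    """Fraction of prompts whose Medusa and source outputs are identical.

    Each element of `pairs` is (medusa_ids, source_ids) for one prompt.
    """
    if not pairs:
        raise ValueError("no samples to compare")
    aligned = sum(1 for medusa_ids, source_ids in pairs
                  if list(medusa_ids) == list(source_ids))
    return aligned / len(pairs)


# Toy data: three of four prompts produce identical outputs.
samples = [
    ([13, 5618, 338], [13, 5618, 338]),
    ([13, 13, 29896], [13, 5618, 338]),
    ([450, 278], [450, 278]),
    ([29973], [29973]),
]
print(full_alignment_rate(samples))  # 0.75
```
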
skyCreateXian commented 2 days ago

@nv-guomingz This case is convenient for debugging because the diff starts at index=4.

prompt="How can I make InitResponse available in an url in the code you showed above"

index | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- medusa | 29973 | 13 | 13 | 29902 | 864 | 304 | 671 | 278 | 10886 | 5103 | 297 | 385 | 3142 | 763 | 445 | 29901 | 13 | 13 | 29966 | 29874 | 2822 | 543 | 2344 | 29918 | 5327 | 29889 | 1420 | 1013 | 6644 | 6646 | 829 | 29874 | 29958 | 13 | 13 | 392 | 769 | 671 | 278 | 10886 | 5103 | 297 | 278 | 4544 | 934 | 29889 | 13 | 13 | 22550 | 29901 | 887 | 508 | 671 | 278 | 421 | 690 | 29889 | 9482 | 29952 | 740 | 304 | 4050 | 278 | 421 | 6644 | 5103 | 29952 | 1776 | 411 | 278 | 421 | 6644 | 5103 | 29952 | 1203 | 408 | 263 | 3030 | 2286 | 29889 source | 29973 | 13 | 13 | 22550 | 29901 | 887 | 508 | 671 | 278 | 421 | 2344 | 5103 | 29952 | 1203 | 297 | 278 | 3988 | 491 | 7797 | 5281 | 372 | 304 | 4663 | 322 | 769 | 9348 | 372 | 408 | 263 | 2346 | 3443 | 297 | 278 | 3988 | 29889 | 2266 | 29915 | 29879 | 385 | 1342 | 310 | 920 | 366 | 508 | 6623 | 278 | 421 | 2344 | 5103 | 29952 | 740 | 304 | 6176 | 445 | 29901 | 13 | 28956 | 7729 | 13 | 2220 | 2069 | 5103 | 29898 | 5327 | 29897 | 426 | 13 | 29871 | 1040 | 848 | 353 | 2933 | 29889 | 1272 | 29936 | 13 | 29871 | 1040 | 2069 | 5103 match | 1 | 1 | 1 | 0 | 0 |   |   |   |   |   |   |   |   |   |   |   |   |   |   |   |   |   |   |   |   |   |   |   |  
 |   |   |   |   |   |   |   |   |   |   |   |   |   |   |   |   |   |   |   |   |   |   |   |   |   |   |   |   |   |   |   |   |   |   |   |   |   |   |   |   |   |   |   |   |   |   |   |   |   |   |  
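
Comparison tables like the ones in this thread can be generated from the two `out_ids` lists rather than assembled by hand. This is a minimal sketch under two assumptions: the `match` row is a position-wise equality flag, and the pipe-separated layout mirrors the tables posted here. `render_diff_table` is a hypothetical helper name.

```python
from typing import Sequence


def render_diff_table(medusa_ids: Sequence[int],
                      source_ids: Sequence[int]) -> str:
    """Render index/medusa/source/match rows as a pipe-separated table."""
    n = min(len(medusa_ids), len(source_ids))  # compare the common length only
    medusa, source = medusa_ids[:n], source_ids[:n]
    rows = [
        ["index"] + [str(i) for i in range(1, n + 1)],
        ["--"] * (n + 1),
        ["medusa"] + [str(t) for t in medusa],
        ["source"] + [str(t) for t in source],
        # 1 where the tokens at a position are equal, 0 otherwise.
        ["match"] + [str(int(m == s)) for m, s in zip(medusa, source)],
    ]
    return "\n".join(" | ".join(row) for row in rows)


print(render_diff_table([29973, 13, 13, 29902, 864],
                        [29973, 13, 13, 22550, 29901]))
```
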
nv-guomingz commented 1 day ago

Got it. Just want to double-check: which trt-llm version are you using now?

skyCreateXian commented 1 day ago

@nv-guomingz [TensorRT-LLM] TensorRT-LLM version: 0.11.0.dev2024062500