skyCreateXian opened this issue 1 week ago
prompt="How is the weather today?" Diff starts at index=10
index | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- medusa | 13 | 5618 | 338 | 278 | 14826 | 763 | 9826 | 29973 | 13 | 13 | 29896 | 29889 | 1724 | 338 | 278 | 14826 | 763 | 9826 | 29973 | 13 | 29906 | 29889 | 1128 | 338 | 278 | 14826 | 9826 | 29973 | 13 | 29941 | 29889 | 1724 | 338 | 278 | 14826 | 763 | 9826 | 29973 | 13 | 29946 | 29889 | 1128 | 338 | 278 | 14826 | 9826 | 29973 | 13 | 29945 | 29889 | 1724 | 338 | 278 | 14826 | 763 | 9826 | 29973 | 13 | 29953 | 29889 | 1128 | 338 | 278 | 14826 | 9826 | 29973 | 13 | 29955 | 29889 | 1724 | 338 | 278 | 14826 | 763 | 9826 | 29973 | 13 | 29947 | 29947 | 29889 source | 13 | 5618 | 338 | 278 | 14826 | 763 | 9826 | 29973 | 13 | 5618 | 338 | 278 | 14826 | 763 | 9826 | 29973 | 13 | 5618 | 338 | 278 | 14826 | 763 | 9826 | 29973 | 13 | 5618 | 338 | 278 | 14826 | 763 | 9826 | 29973 | 13 | 5618 | 338 | 278 | 14826 | 763 | 9826 | 29973 | 13 | 5618 | 338 | 278 | 14826 | 763 | 9826 | 29973 | 13 | 5618 | 338 | 278 | 14826 | 763 | 9826 | 29973 | 13 | 5618 | 338 | 278 | 14826 | 763 | 9826 | 29973 | 13 | 5618 | 338 | 278 | 14826 | 763 | 9826 | 29973 | 13 | 5618 | 338 | 278 | 14826 | 763 | 9826 | 29973@skyCreateXian Thanks for reporting this issue, we'll take a look internally and get back to u ASAP.
+1, seeing this as well
Hi @skyCreateXian, would you please try the latest code base to see if the issue still exists?
@dongxuy04 and I can't reproduce it on our side.
@nv-guomingz I verified with the latest code, and the misaligned results with Medusa at top_k=1 have been resolved. Thank you for your support. I will close the issue.
@nv-guomingz The issue has not been completely resolved. There are still differences between the Medusa and base models on 200 items randomly selected from the ShareGPT.V4.3_unfiltered_cleaned_split.json dataset.
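The comment does not say how the 200 items were drawn. The sketch below shows one way to draw them reproducibly. It assumes the usual ShareGPT layout: a JSON list of records, each holding a `conversations` list of `{"from", "value"}` turns. Using the first human turn as the prompt is also an assumption, not necessarily the reporter's procedure. The helper name `sample_sharegpt_prompts` is hypothetical.

```python
import random
from typing import Dict, List


def sample_sharegpt_prompts(records: List[Dict], n: int = 200, seed: int = 0) -> List[str]:
    """Draw n records reproducibly and return the first human turn of each.

    Assumes ShareGPT-style records: {"conversations": [{"from": ..., "value": ...}, ...]}.
    Records without a human turn are skipped, so fewer than n prompts may be returned.
    """
    rng = random.Random(seed)  # fixed seed so both engines see the same prompts
    prompts = []
    for record in rng.sample(records, k=min(n, len(records))):
        for turn in record.get("conversations", []):
            if turn.get("from") == "human":
                prompts.append(turn["value"])
                break
    return prompts
```

To use it, load the records with `json.load` from the dataset file. Then run each prompt through both the Medusa and the source engine with the same decoding settings.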
prompt="Summarize the main ideas of Brendon Burchard's Experts Academy into bullet points as it pertains to a growth marketing agency implementing these strategies and tactics for their clients..."
index | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- medusa | 13 | 13 | 29930 | 450 | 13500 | 310 | 4969 | 263 | 7333 | 14982 | 322 | 10127 | 292 | 6743 | 761 | 408 | 385 | 17924 | 297 | 263 | 2702 | 302 | 4070 | 29889 | 13 | 29930 | 450 | 817 | 304 | 8569 | 373 | 263 | 2702 | 302 | 4070 | 322 | 4953 | 278 | 748 | 29899 | 517 | 17924 | 297 | 393 | 4038 | 29889 | 13 | 29930 | 450 | 13500 | 310 | 4969 | 21114 | 2793 | 322 | 5214 | 263 | 7881 | 2820 | 372 | 29889 | 13 source | 13 | 13 | 29930 | 450 | 13500 | 310 | 4969 | 263 | 7333 | 14982 | 322 | 10127 | 292 | 6743 | 761 | 408 | 385 | 17924 | 297 | 263 | 2702 | 302 | 4070 | 29889 | 13 | 29930 | 450 | 817 | 304 | 8569 | 373 | 16330 | 697 | 29915 | 29879 | 20026 | 322 | 13138 | 995 | 1549 | 2793 | 9999 | 292 | 322 | 5264 | 5745 | 29889 | 13 | 29930 | 450 | 26002 | 310 | 5214 | 263 | 7881 | 322 | 9926 | 3241 | 21702 | 411 | 697 | 29915 match | 1 | 1 | 1 | 1 | 1 | 1 | 1 | 1 | 1 | 1 | 1 | 1 | 1 | 1 | 1 | 1 | 1 | 1 | 1 | 1 | 1 | 1 | 1 | 1 | 1 | 1 | 1 | 1 | 1 | 1 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0@nv-guomingz This case is convenient for debugging because diff start=4 prompt="How can I make InitResponse available in an url in the code you showed above"
index | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- medusa | 29973 | 13 | 13 | 29902 | 864 | 304 | 671 | 278 | 10886 | 5103 | 297 | 385 | 3142 | 763 | 445 | 29901 | 13 | 13 | 29966 | 29874 | 2822 | 543 | 2344 | 29918 | 5327 | 29889 | 1420 | 1013 | 6644 | 6646 | 829 | 29874 | 29958 | 13 | 13 | 392 | 769 | 671 | 278 | 10886 | 5103 | 297 | 278 | 4544 | 934 | 29889 | 13 | 13 | 22550 | 29901 | 887 | 508 | 671 | 278 | 421 | 690 | 29889 | 9482 | 29952 | 740 | 304 | 4050 | 278 | 421 | 6644 | 5103 | 29952 | 1776 | 411 | 278 | 421 | 6644 | 5103 | 29952 | 1203 | 408 | 263 | 3030 | 2286 | 29889 source | 29973 | 13 | 13 | 22550 | 29901 | 887 | 508 | 671 | 278 | 421 | 2344 | 5103 | 29952 | 1203 | 297 | 278 | 3988 | 491 | 7797 | 5281 | 372 | 304 | 4663 | 322 | 769 | 9348 | 372 | 408 | 263 | 2346 | 3443 | 297 | 278 | 3988 | 29889 | 2266 | 29915 | 29879 | 385 | 1342 | 310 | 920 | 366 | 508 | 6623 | 278 | 421 | 2344 | 5103 | 29952 | 740 | 304 | 6176 | 445 | 29901 | 13 | 28956 | 7729 | 13 | 2220 | 2069 | 5103 | 29898 | 5327 | 29897 | 426 | 13 | 29871 | 1040 | 848 | 353 | 2933 | 29889 | 1272 | 29936 | 13 | 29871 | 1040 | 2069 | 5103 match | 1 | 1 | 1 | 0 | 0 | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | 
| | | | | | | | | | | | | | | | | | | | | | | | | | | |Got it. Just wanna double check the trt-llm version u're using now?
@nv-guomingz [TensorRT-LLM] TensorRT-LLM version: 0.11.0.dev2024062500
System Info
- Hardware: A10
- CUDA: 12.4
- Driver: 535.54.03
- TensorRT-LLM: 0.10.0
- Medusa base model: vicuna-7b-v1.3
- Medusa head model: FasterDecoding/medusa-vicuna-7b-v1.3
Who can help?
@kaiyux @by
```bash
# convert medusa engine
python convert_checkpoint.py --model_dir /data/Medusa/vicuna-7b-v1.3 \
    --medusa_model_dir /data/Medusa/lm_head \
    --output_dir ./tllm_checkpoint_1gpu_medusa \
    --dtype float16 \
    --fixed_num_medusa_heads 4

trtllm-build --checkpoint_dir ./tllm_checkpoint_1gpu_medusa \
    --output_dir ./vicuna-medusa \
    --gemm_plugin float16 \
    --speculative_decoding_mode medusa \
    --max_batch_size 8
```
```bash
# convert source model
# (reuses the checkpoint dir of the Medusa conversion, so run this after the Medusa engine is built)
python ../llama/convert_checkpoint.py --model_dir /data/Medusa/vicuna-7b-v1.3 \
    --output_dir ./tllm_checkpoint_1gpu_medusa \
    --dtype float16

trtllm-build --checkpoint_dir ./tllm_checkpoint_1gpu_medusa \
    --output_dir ./vicuna-source \
    --gemm_plugin float16 \
    --max_batch_size 8
```

There is a diff starting at index 31:
index | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- medusa | 29992 | 1552 | 29883 | 12896 | 3626 | 29886 | 29889 | 510 | 13 | 13 | 1576 | 14062 | 323 | 6472 | 338 | 263 | 13436 | 5745 | 5001 | 393 | 4511 | 29879 | 9850 | 414 | 411 | 278 | 3186 | 30010 | 1 | 450 | 29871 | 29896 | 29929 | 29955 | 29945 | 6507 | 716 | 4823 | 525 | 6816 | 669 | 887 | 323 | 12966 | 838 | 650 | 29915 | 13 | 1576 | 29871 | 29896 | 29929 | 29955 | 29945 | 505 | 5492 | 263 | 716 | 4823 | 2000 | 525 | 6816 | 669 | 887 | 323 | 12966 | 838 | 650 | 4286 | 13 | 1576 | 5702 | 338 | 278 | 9281 | 304 | 367 | 26239 | 515 | 278 source | 29992 | 1552 | 29883 | 12896 | 3626 | 29886 | 29889 | 510 | 13 | 13 | 1576 | 14062 | 323 | 6472 | 338 | 263 | 13436 | 5745 | 5001 | 393 | 4511 | 29879 | 9850 | 414 | 411 | 278 | 3186 | 30010 | 1 | 450 | 937 | 931 | 306 | 4446 | 278 | 14064 | 376 | 1576 | 10213 | 4634 | 310 | 10705 | 24532 | 29891 | 1699 | 306 | 471 | 15469 | 491 | 278 | 325 | 3640 | 322 | 6382 | 262 | 1230 | 3186 | 393 | 278 | 15572 | 391 | 29892 | 10705 | 24532 | 29891 | 29892 | 14117 | 1573 | 29889 | 450 | 14064 | 29892 | 10624 | 491 | 322 | 380 | 23693 | 4111 | 624 | 5495`
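Each comparison in this thread reports the first (1-based) index where the two token streams diverge. Some also include a `match` row that stays 1 until that point. A small helper that computes both from two token lists might look like this; the function names are hypothetical.

```python
from typing import List, Optional, Sequence


def first_divergence(medusa_ids: Sequence[int], source_ids: Sequence[int]) -> Optional[int]:
    """Return the 1-based index of the first differing token, or None if identical.

    If one sequence is a strict prefix of the other, the position just past the
    shorter one counts as the divergence.
    """
    for i, (a, b) in enumerate(zip(medusa_ids, source_ids), start=1):
        if a != b:
            return i
    if len(medusa_ids) != len(source_ids):
        return min(len(medusa_ids), len(source_ids)) + 1
    return None


def match_row(medusa_ids: Sequence[int], source_ids: Sequence[int]) -> List[int]:
    """1 while the two sequences still agree, 0 from the first divergence on."""
    n = max(len(medusa_ids), len(source_ids))
    diff = first_divergence(medusa_ids, source_ids)
    if diff is None:
        return [1] * n
    return [1] * (diff - 1) + [0] * (n - diff + 1)
```

Applied to the first ten tokens of the "How is the weather today?" comparison, `first_divergence` returns 10, in line with the reported diff.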
Decode sample for the Medusa and source engines:

```python
import ast

import tensorrt_llm
import torch
from tensorrt_llm.runtime import ModelRunnerCpp
from transformers import AutoTokenizer

# Engine to decode. For the Medusa run, point this at the Medusa engine and set
# medusa_choices to a string literal describing the Medusa tree.
engine_dir = "source"
tokenizer = AutoTokenizer.from_pretrained("/data/Medusa/vicuna-7b-v1.3")
medusa_choices = None
temperature = 1.0
top_k = 1  # greedy decoding, so both engines are expected to emit identical tokens
top_p = 1.0
num_beams = 1
max_new_tokens = 80

prompt = "hello"
end_id = tokenizer.eos_token_id
# Fall back to EOS if the tokenizer defines no pad token.
pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else end_id

runtime_rank = tensorrt_llm.mpi_rank()

runner_kwargs = dict(engine_dir=engine_dir,
                     lora_dir=None,
                     rank=runtime_rank,
                     debug_mode=True,
                     lora_ckpt_source="hf",
                     gpu_weights_percent=1)

if medusa_choices is not None:
    # medusa_choices is given as a string literal, hence literal_eval
    medusa_choices = ast.literal_eval(medusa_choices)
    assert temperature == 1.0, "Medusa should use temperature == 1.0"
    assert num_beams == 1, "Medusa should use num_beams == 1"
    runner_kwargs.update(medusa_choices=medusa_choices)

input_tensor = tokenizer.encode(prompt)
batch_input_ids = [torch.tensor(input_tensor, dtype=torch.int32)]
input_lengths = [x.size(0) for x in batch_input_ids]

runner_kwargs.update(max_batch_size=1,
                     max_input_len=max(input_lengths),
                     max_output_len=max_new_tokens,
                     max_beam_width=num_beams,
                     max_attention_window_size=None,
                     sink_token_length=None)
runner = ModelRunnerCpp.from_dir(**runner_kwargs)

with torch.no_grad():
    outputs = runner.generate(batch_input_ids,
                              max_new_tokens=max_new_tokens,
                              max_attention_window_size=None,
                              sink_token_length=None,
                              end_id=end_id,
                              pad_id=pad_id,
                              temperature=temperature,
                              top_k=top_k,
                              top_p=top_p,
                              num_beams=num_beams,
                              length_penalty=1.0,
                              early_stopping=1,
                              repetition_penalty=1.0,
                              presence_penalty=0.0,
                              frequency_penalty=0.0,
                              stop_words_list=None,
                              bad_words_list=None,
                              output_cum_log_probs=False,
                              output_log_probs=False,
                              lora_uids=None,
                              prompt_table=None,
                              prompt_tasks=None,
                              streaming=False,
                              output_sequence_lengths=True,
                              return_dict=True,
                              medusa_choices=medusa_choices)
    torch.cuda.synchronize()

# output_ids is indexed [batch][beam][position]; keep only the generated part
out_ids = outputs["output_ids"]
out_ids = out_ids.tolist()[0][0][input_lengths[0]:]

print(out_ids)
```
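The decode sample slices `output_ids` from the input length to the end of the buffer, so any padding past the real sequence end is included in the comparison. The helper below also uses the sequence lengths, which keeps padded positions out of the diff. It assumes `outputs["sequence_lengths"]` is populated (the script passes `output_sequence_lengths=True` and `return_dict=True`) and is indexed `[batch][beam]`, with each length counting prompt plus generated tokens. The helper name `generated_tokens` is hypothetical.

```python
from typing import List, Sequence


def generated_tokens(output_ids: Sequence[Sequence[Sequence[int]]],
                     sequence_lengths: Sequence[Sequence[int]],
                     input_length: int,
                     batch_idx: int = 0,
                     beam: int = 0) -> List[int]:
    """Return only the newly generated tokens for one batch entry and beam.

    output_ids is indexed [batch][beam][position] and may be padded after the
    sequence end; sequence_lengths is indexed [batch][beam].
    """
    end = sequence_lengths[batch_idx][beam]
    return list(output_ids[batch_idx][beam][input_length:end])
```

In the script, this would replace the slicing step with `generated_tokens(outputs["output_ids"].tolist(), outputs["sequence_lengths"].tolist(), input_lengths[0])`.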
Expected behavior
The output tokens should match those of the source engine. Speculative decoding should not introduce any accuracy loss.
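This expectation follows from how greedy speculative decoding works. In exact arithmetic, every committed token is one the base model itself would have chosen, so the drafter affects speed but not output. The sketch below illustrates this with hypothetical toy functions: `toy_next` stands in for the base model and `toy_draft` for the Medusa heads. It is a simplified sequential model, not TensorRT-LLM's implementation.

```python
from typing import Callable, List

NextToken = Callable[[List[int]], int]         # base model: greedy next token for a prefix
Draft = Callable[[List[int], int], List[int]]  # drafter: proposes k tokens for a prefix


def greedy_decode(next_token: NextToken, prompt: List[int], n: int) -> List[int]:
    """Plain greedy decoding: one base-model step per generated token."""
    seq = list(prompt)
    for _ in range(n):
        seq.append(next_token(seq))
    return seq[len(prompt):]


def speculative_greedy_decode(next_token: NextToken, draft: Draft,
                              prompt: List[int], n: int, k: int = 4) -> List[int]:
    """Greedy speculative decoding with exact verification.

    A drafted token is kept only if it equals the base model's greedy choice.
    The first mismatch is replaced by the base model's token and the rest of the
    draft is dropped. A real engine verifies the draft in one forward pass; this
    sketch checks positions one by one for clarity.
    """
    seq = list(prompt)
    target = len(prompt) + n
    while len(seq) < target:
        accepted_all = True
        for tok in draft(seq, k):
            if len(seq) >= target:
                break
            expected = next_token(seq)
            seq.append(expected)  # the committed token is always the base model's
            if tok != expected:
                accepted_all = False
                break  # discard the remaining draft tokens
        if accepted_all and len(seq) < target:
            seq.append(next_token(seq))  # bonus token after a fully accepted draft
    return seq[len(prompt):]


# Hypothetical toy stand-ins: a deterministic "base model" and a drafter
# that is deliberately wrong whenever the prefix length is a multiple of 3.
def toy_next(seq: List[int]) -> int:
    return (sum(seq) * 31 + len(seq)) % 50


def toy_draft(seq: List[int], k: int) -> List[int]:
    s, out = list(seq), []
    for _ in range(k):
        t = toy_next(s)
        if len(s) % 3 == 0:
            t = (t + 1) % 50  # inject a draft error
        out.append(t)
        s.append(t)
    return out


assert speculative_greedy_decode(toy_next, toy_draft, [1, 2, 3], 40) == greedy_decode(toy_next, [1, 2, 3], 40)
```

Draft errors only cost extra verification steps; they never change the emitted sequence.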
actual behavior
The Medusa engine's output diverges from the source model's output.
additional notes
While debugging gptManagerBenchmark, I found that the diff may originate in `acceptDraftTokensByIdsWithPaths` in `TensorRT-LLM/cpp/tensorrt_llm/kernels/decodingKernels.cu`. When the results of all Medusa heads are accepted, the next token may be predicted incorrectly.
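For reference, here is a simplified pure-Python sketch of what greedy path-based acceptance is expected to compute. Several details are assumptions for illustration and do not mirror the real kernel's arguments or memory layout: the tree layout (node 0 as the root, paths padded with -1, one base-model target token per node) and the function name `accept_by_paths`.

```python
from typing import List, Sequence, Tuple


def accept_by_paths(draft_ids: Sequence[int],
                    target_ids: Sequence[int],
                    paths: Sequence[Sequence[int]]) -> Tuple[List[int], int]:
    """Greedy path-based acceptance over a Medusa-style token tree.

    draft_ids[n]:  candidate token at tree node n (node 0 is the root, i.e. the
                   last committed token; its entry is not compared).
    target_ids[n]: base model's argmax token after the prefix ending at node n,
                   all taken from one verification forward pass.
    paths:         non-empty list of root-to-leaf node indices, each starting
                   at 0 and padded with -1.

    Returns (new_tokens, best_path_index).
    """
    best_len, best_path = -1, 0
    for p, path in enumerate(paths):
        accepted = 0
        for depth in range(1, len(path)):
            node = path[depth]
            if node < 0:
                break  # padding: end of this path
            if draft_ids[node] != target_ids[path[depth - 1]]:
                break  # draft disagrees with the base model's choice
            accepted += 1
        if accepted > best_len:  # strict '>' keeps the first path on ties
            best_len, best_path = accepted, p
    path = paths[best_path]
    # Accepted draft tokens equal their parents' targets; add one bonus token.
    new_tokens = [target_ids[path[d]] for d in range(best_len + 1)]
    return new_tokens, best_path
```

In this model, the returned tokens always come from `target_ids`. Acceptance alone therefore cannot change the greedy output unless it reads targets from the wrong node.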