Open rangehow opened 1 month ago
Hi,
I am also seeing different results for the same prompt even though temperature is set to 0. Complete sampling parameter is:
SamplingParams(n=1, best_of=1, presence_penalty=0.0, frequency_penalty=0.0, repetition_penalty=1.0, temperature=0.0, top_p=0, top_k=1, min_p=0.0, seed=None, use_beam_search=False, length_penalty=1.0, early_stopping=False, stop=[], stop_token_ids=[], include_stop_str_in_output=False, ignore_eos=False, max_tokens=128, min_tokens=0, logprobs=1, prompt_logprobs=1, skip_special_tokens=True, spaces_between_special_tokens=True, truncate_prompt_tokens=None)
Version was just updated to v0.4.3.
I'm investigating the issue. I verified the bug by running examples/offline_inference.py with:
sampling_params = SamplingParams(temperature=0.0, max_tokens=10)
llm = LLM(model="meta-llama/Meta-Llama-3-8B")
However, the bug is present only when adding/removing prompts to/from the input batch. The same behavior is seen across older versions (v0.3.3, v0.4.2, v0.4.3).
Selected output for reference:
Prompt from first batch: 'Hello, my name is', Generated text: " and I'm writing you today to learn more about"
Prompt from first batch: 'The capital of France is', Generated text: ' Paris, which is located in the north of the'
VS
Prompt from second batch: 'Hello, my name is', Generated text: " and I'm writing you today to learn more about"
Prompt from second batch: 'The capital of France is', Generated text: ' Paris. It is located in the north of the'
Prompt from second batch: 'The future of AI is', Generated text: ' here, and it's already changing the way we'
first_sampling_result = [
[([323], [0]), ([12366], [0])],
[([358], [0]), ([11], [0])],
[([2846], [0]), ([902], [0])],
[([4477], [0]), ([374], [0])],
[([499], [0]), ([7559], [0])],
[([3432], [0]), ([304], [0])],
[([311], [0]), ([279], [0])],
[([4048], [0]), ([10411], [0])],
[([810], [0]), ([315], [0])],
[([922], [0]), ([279], [0])]
]
second_sampling_result = [
[([323], [0]), ([12366], [0]), ([1618], [0])],
[([358], [0]), ([13], [0]), ([11], [0])],
[([2846], [0]), ([1102], [0]), ([323], [0])],
[([4477], [0]), ([374], [0]), ([433], [0])],
[([499], [0]), ([7559], [0]), ([753], [0])],
[([3432], [0]), ([304], [0]), ([2736], [0])],
[([311], [0]), ([279], [0]), ([10223], [0])],
[([4048], [0]), ([10411], [0]), ([279], [0])],
[([810], [0]), ([315], [0]), ([1648], [0])],
[([922], [0]), ([279], [0]), ([584], [0])]
]
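To make the divergence point explicit, here is a small pure-Python check (token-id lists copied from the logs above) that finds the first decoding step at which the two shared sequences differ between the runs:

```python
# Token ids copied from first_sampling_result / second_sampling_result above,
# flattened to [seq0, seq1(, seq2)] per decoding step.
first = [
    [323, 12366], [358, 11], [2846, 902], [4477, 374], [499, 7559],
    [3432, 304], [311, 279], [4048, 10411], [810, 315], [922, 279],
]
second = [
    [323, 12366, 1618], [358, 13, 11], [2846, 1102, 323], [4477, 374, 433],
    [499, 7559, 753], [3432, 304, 2736], [311, 279, 10223],
    [4048, 10411, 279], [810, 315, 1648], [922, 279, 584],
]

def first_divergence(a, b, col):
    """Return (step, token_a, token_b) at the first mismatch, or None."""
    for step, (ra, rb) in enumerate(zip(a, b)):
        if ra[col] != rb[col]:
            return step, ra[col], rb[col]
    return None

print(first_divergence(first, second, 0))  # None: 'Hello, my name is' never diverges
print(first_divergence(first, second, 1))  # (1, 11, 13): 'The capital of France is' diverges at step 1
```

So the first mismatch is at the second output token of the France prompt (token 11 "," vs token 13 "."), consistent with the generated-text diff above.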
All fields looked as expected when I stepped into the Sampler code and examined sampling_metadata. Will further investigate model output before the sampling stage.
# === Sampling Metadata when generating second output token in previous example ===
SamplingMetadata(seq_groups=[
SequenceGroupToSample(seq_ids=[0], sampling_params=SamplingParams(n=1, best_of=1, presence_penalty=0.0, frequency_penalty=0.0, repetition_penalty=1.0, temperature=0.0, top_p=1.0, top_k=-1, min_p=0.0, seed=None, use_beam_search=False, length_penalty=1.0, early_stopping=False, stop=[], stop_token_ids=[128001], include_stop_str_in_output=False, ignore_eos=False, max_tokens=10, min_tokens=0, logprobs=None, prompt_logprobs=None, skip_special_tokens=True, spaces_between_special_tokens=True, truncate_prompt_tokens=None), seq_data={0: SequenceData(prompt_token_ids=[128000, 9906, 11, 856, 836, 374], output_token_ids=[323], cumulative_logprob=-4.27239990234375)}, seq_len=None, query_len=None, generator=None, is_prompt=False, prompt_logprob_indices=[], sample_indices=[0]),
SequenceGroupToSample(seq_ids=[1], sampling_params=SamplingParams(n=1, best_of=1, presence_penalty=0.0, frequency_penalty=0.0, repetition_penalty=1.0, temperature=0.0, top_p=1.0, top_k=-1, min_p=0.0, seed=None, use_beam_search=False, length_penalty=1.0, early_stopping=False, stop=[], stop_token_ids=[128001], include_stop_str_in_output=False, ignore_eos=False, max_tokens=10, min_tokens=0, logprobs=None, prompt_logprobs=None, skip_special_tokens=True, spaces_between_special_tokens=True, truncate_prompt_tokens=None), seq_data={1: SequenceData(prompt_token_ids=[128000, 791, 6864, 315, 9822, 374], output_token_ids=[12366], cumulative_logprob=-1.4869756698608398)}, seq_len=None, query_len=None, generator=None, is_prompt=False, prompt_logprob_indices=[], sample_indices=[1])], selected_token_indices=tensor([0, 1], device='cuda:0'), categorized_sample_indices={<SamplingType.GREEDY: 0>: tensor([[0, 0],
[1, 1]], device='cuda:0', dtype=torch.int32), <SamplingType.RANDOM: 1>: tensor([], device='cuda:0', size=(0, 2), dtype=torch.int32), <SamplingType.RANDOM_SEED: 2>: tensor([], device='cuda:0', size=(0, 2), dtype=torch.int32), <SamplingType.BEAM: 3>: tensor([], device='cuda:0', size=(0, 2), dtype=torch.int32)}),
Sampling results:
[([358], [0]), ([11], [0])]
SamplingMetadata(seq_groups=[
SequenceGroupToSample(seq_ids=[2], sampling_params=SamplingParams(n=1, best_of=1, presence_penalty=0.0, frequency_penalty=0.0, repetition_penalty=1.0, temperature=0.0, top_p=1.0, top_k=-1, min_p=0.0, seed=None, use_beam_search=False, length_penalty=1.0, early_stopping=False, stop=[], stop_token_ids=[128001], include_stop_str_in_output=False, ignore_eos=False, max_tokens=10, min_tokens=0, logprobs=None, prompt_logprobs=None, skip_special_tokens=True, spaces_between_special_tokens=True, truncate_prompt_tokens=None), seq_data={2: SequenceData(prompt_token_ids=[128000, 9906, 11, 856, 836, 374], output_token_ids=[323], cumulative_logprob=-4.276355743408203)}, seq_len=None, query_len=None, generator=None, is_prompt=False, prompt_logprob_indices=[], sample_indices=[0]),
SequenceGroupToSample(seq_ids=[3], sampling_params=SamplingParams(n=1, best_of=1, presence_penalty=0.0, frequency_penalty=0.0, repetition_penalty=1.0, temperature=0.0, top_p=1.0, top_k=-1, min_p=0.0, seed=None, use_beam_search=False, length_penalty=1.0, early_stopping=False, stop=[], stop_token_ids=[128001], include_stop_str_in_output=False, ignore_eos=False, max_tokens=10, min_tokens=0, logprobs=None, prompt_logprobs=None, skip_special_tokens=True, spaces_between_special_tokens=True, truncate_prompt_tokens=None), seq_data={3: SequenceData(prompt_token_ids=[128000, 791, 6864, 315, 9822, 374], output_token_ids=[12366], cumulative_logprob=-1.4816458225250244)}, seq_len=None, query_len=None, generator=None, is_prompt=False, prompt_logprob_indices=[], sample_indices=[1]),
SequenceGroupToSample(seq_ids=[4], sampling_params=SamplingParams(n=1, best_of=1, presence_penalty=0.0, frequency_penalty=0.0, repetition_penalty=1.0, temperature=0.0, top_p=1.0, top_k=-1, min_p=0.0, seed=None, use_beam_search=False, length_penalty=1.0, early_stopping=False, stop=[], stop_token_ids=[128001], include_stop_str_in_output=False, ignore_eos=False, max_tokens=10, min_tokens=0, logprobs=None, prompt_logprobs=None, skip_special_tokens=True, spaces_between_special_tokens=True, truncate_prompt_tokens=None), seq_data={4: SequenceData(prompt_token_ids=[128000, 791, 3938, 315, 15592, 374], output_token_ids=[1618], cumulative_logprob=-2.299220085144043)}, seq_len=None, query_len=None, generator=None, is_prompt=False, prompt_logprob_indices=[], sample_indices=[2])], selected_token_indices=tensor([0, 1, 2], device='cuda:0'), categorized_sample_indices={<SamplingType.GREEDY: 0>: tensor([[0, 0],
[1, 1],
[2, 2]], device='cuda:0', dtype=torch.int32), <SamplingType.RANDOM: 1>: tensor([], device='cuda:0', size=(0, 2), dtype=torch.int32), <SamplingType.RANDOM_SEED: 2>: tensor([], device='cuda:0', size=(0, 2), dtype=torch.int32), <SamplingType.BEAM: 3>: tensor([], device='cuda:0', size=(0, 2), dtype=torch.int32)}),
Sampling results:
[([358], [0]), ([13], [0]), ([11], [0])]
Is there any script that I can use to reproduce this issue?
I've been looking into #5607, which appears related, but after some digging, that bug seems to be related to the presence of repetition_penalty on some requests but not others. That doesn't seem to be the case here.
I think #5607 fixed a different issue. After comparing logits before and after temperature scaling, I realized the zero temperature is erroneously reassigned to 1.0. It should be temperature = _SAMPLING_EPS instead.
https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/sampling_metadata.py#L359-L363
I think these lines of code are likely related to the problem, but whether the temperature should be set to _SAMPLING_EPS remains to be sorted out. I quickly tested this modification and, unfortunately, found that the decoding result turned into nonsense output.
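One plausible explanation for the nonsense output, sketched below with a crude float16 stand-in (the saturation helper and the eps value 1e-5 are illustrative assumptions, not vLLM's actual code): dividing logits by 1.0 preserves the argmax, while dividing by a tiny epsilon overflows half-precision logits to infinity and destroys the ranking.

```python
# Crude stand-in for a float16 cast: values beyond the representable range
# saturate to +/-inf, the way an overflowing downcast would.
FP16_MAX = 65504.0  # largest finite float16 value

def fp16_like(x):
    if x > FP16_MAX:
        return float("inf")
    if x < -FP16_MAX:
        return float("-inf")
    return x

logits = [10.0, 9.9, -3.0]
eps = 1e-5  # stand-in for _SAMPLING_EPS; the real value may differ

# Temperature reassigned to 1.0 (current behavior): a no-op, argmax preserved.
by_one = [fp16_like(l / 1.0) for l in logits]

# Temperature = eps instead: every large logit overflows to inf, so the top
# two tokens become indistinguishable and sampling degenerates.
by_eps = [fp16_like(l / eps) for l in logits]

print(by_one)  # [10.0, 9.9, -3.0]
print(by_eps)  # [inf, inf, -inf]
```

This would explain why reassigning temperature to 1.0 is intentional for the greedy path: the greedy sampler ignores temperature anyway, so a no-op divide is the safe choice.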
🐛 Describe the bug
When using different generation configurations, such as top_k=1 or temperature=0 (while keeping other settings unchanged), why do the generated results change? Both should correspond to deterministic greedy decoding. vLLM 0.4.3
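As a sanity check on that claim, here is a minimal pure-Python sketch (not vLLM's actual implementation) showing that the temperature-to-0 limit and top_k=1 both reduce to the same argmax pick over a fixed logit vector:

```python
# Greedy decoding two ways over a toy logit vector: the temperature->0
# limit and top_k=1 must select the same token index.
logits = [2.5, 7.1, -0.3, 7.0]

# temperature -> 0: softmax collapses onto the largest logit, i.e. argmax.
greedy_by_temp = max(range(len(logits)), key=lambda i: logits[i])

# top_k=1: keep only the single highest logit, then "sample" from it.
greedy_by_topk = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[0]

print(greedy_by_temp, greedy_by_topk)  # 1 1
```

Given identical logits, the two settings cannot disagree; any observed difference must therefore come from the logits themselves or from how the sampler handles the parameters.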
Supplement:
The main issue here is that the results generated by setting the temperature to 0 versus setting top_k to 1 are different. I understand that, due to operator optimizations and the lack of associativity in floating-point arithmetic, matrix operations carry a certain nondeterminism. However, the sampling process occurs after the hidden_state is produced, at which point no further matrix calculations are involved, so the sampling results under the two parameter settings should be identical.
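On the floating-point point above: IEEE-754 addition is not associative, so a reduction whose order changes with batch size can legitimately produce different bits even for identical inputs. A minimal demonstration:

```python
# IEEE-754 double addition is not associative: the grouping of the same
# three operands changes the rounded result.
a, b, c = 0.1, 0.2, 0.3
left = (a + b) + c   # 0.6000000000000001
right = a + (b + c)  # 0.6
print(left == right)  # False
```

This is why changing the batch composition can perturb the forward pass, but it cannot explain two different sampling decisions made from the very same hidden_state.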
Hello, @rangehow, may I ask which model you are using to produce this bug?
Lately, I encountered the same inconsistent behavior when setting top_k=1 (or temperature=0) for a GPTQ-quantized model. I dug into the intermediate outputs and found that the issue has nothing to do with the sampling_metadata but with the hidden_state: the hidden_state inputs to the logits_processor are already slightly different for identical prompts.
Yet I am not able to reproduce this bug with a non-quantized fp16 (bf16) model.
gemma-2b