SGLang is a structured generation language designed for large language models (LLMs). It makes your interaction with models faster and more controllable.
The `choices` normalised logprobs calculation returns poor results due to bias for longer-token options #523
I've noticed that the gen(choices=[...]) functionality sometimes performs poorly, even for simple tasks. This is due to a flawed normalised logprobs calculation. The calculation biases options that comprise more tokens, where the latter tokens are highly predictable given the prior tokens.
Reproducible Example
This is most easily seen in choices with token overlap, so I've constructed a contrived example that illustrates this. The outputs are generated with llama 3 8B, which should breeze through this task under normal circumstances.
```python
import sglang as sgl
import textwrap

# Define answer choices with overlapping substrings and tokenised forms
# (assumes the llama 3 8B tokeniser)
choices_and_tokenised_forms = [
    ("organ", ["organ"]),
    ("organism", ["organ", "ism"]),
    ("organisation", ["organisation"]),
    ("organelle", ["org", "ane", "lle"]),
    ("organometallic", ["organ", "omet", "al", "lic"]),
]
choices = [c for c, _ in choices_and_tokenised_forms]

# Define the categorisation question
template = "What category does '{input}' belong to? {choices}"

# Generate the (optional) system prompt with few-shot examples
sys_prompt = ""
for example in [
    ("ribosome", "organelle"),
    ("liver", "organ"),
    ("Google", "organisation"),
    ("ferrocene", "organometallic"),
    ("human", "organism"),
]:
    sys_prompt += "user:" + template.format(input=example[0], choices=choices)
    sys_prompt += f"\nassistant:{example[1]}\n\n"

@sgl.function
def run(s, input: str, show_few_shot_examples: bool = False):
    if show_few_shot_examples:
        s += sgl.system(f"You categorise things.\n\n## Examples\n{sys_prompt}")
    s += sgl.user(template.format(input=input, choices=choices))
    s += sgl.assistant(sgl.gen("answer", choices=choices, temperature=0))

def format_results(state, input):
    answer = f" '{input}' categorised as: '{state['answer']}'"
    meta = state.get_meta_info("answer")
    out = f"{answer:<50} {'normalised':<13} prefill token logprobs"
    for i in range(len(meta["normalized_prompt_logprobs"])):
        option = f"{choices_and_tokenised_forms[i][0]} ({choices_and_tokenised_forms[i][1]})"
        npl = meta["normalized_prompt_logprobs"][i]
        ptl = [f"{p[0]:.4f}" for p in meta["prefill_token_logprobs"][i]]
        out += f"\n{option:<50} -> {npl:<10.4f} -> {ptl}"
    return out

sgl.set_default_backend(sgl.RuntimeEndpoint("http://localhost:30000"))
for include_examples in [False, True]:
    print(f"Show few-shot examples in context = {include_examples}\n")
    for input in ["heart", "nucleus", "Microsoft", "mouse", "trimethylboron"]:
        state = run(input, show_few_shot_examples=include_examples)
        print(textwrap.indent(format_results(state, input), "    "))
        print()
    print("-" * 120)
```

Outputs:
The second set of results yields the expected categorisations.
Explanation
We see that only 1/5 answers are correct in the first set of results. Not coincidentally, the only correctly answered question ('trimethylboron' categorised as: 'organometallic') is the one where the correct answer has the most tokens.
The first prefill token, common across all options, is ":". I'm not actually sure why this is present — something to do with token healing? Is this coming from the assistant: role prefix? Regardless, it's not important, as it's consistent across all options and isn't responsible for the poor performance (although it does skew the logprob calculations in unpredictable ways).
Inspecting the prefill token logprobs for the "organometallic" responses is instructive. Even if the ["organ", "omet"] tokens are relatively disfavoured, the ["al", "lic"] tokens are essentially guaranteed once you have the "organomet-" substring. The normalised logprobs calculation is a simple average of the prefill token logprobs, which means the ["al", "lic"] tokens massively inflate the score, even if "organometallic" is obviously wrong given the prior context.
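The arithmetic behind the bias can be shown with a toy calculation. The per-token logprobs below are made-up numbers (not real model outputs), chosen to mirror the pattern above: a plausible single-token option versus a multi-token option with a weak prefix and near-certain suffix tokens.

```python
# Toy per-token logprobs (assumed numbers, not real model outputs) showing
# how a simple average rewards long options with near-certain suffix tokens.
options = {
    "organ":          [-1.5],                      # plausible single-token answer
    "organometallic": [-3.0, -2.0, -0.01, -0.01],  # weak prefix, guaranteed suffix
}

for name, logprobs in options.items():
    mean_lp = sum(logprobs) / len(logprobs)   # the normalised-logprobs score
    joint_lp = sum(logprobs)                  # joint log-probability of the option
    print(f"{name:<16} mean={mean_lp:.3f} joint={joint_lp:.3f}")
```

Here "organometallic" wins on the mean (-1.255 vs -1.5) even though its joint log-probability is far worse (-5.02 vs -1.5): the two near-zero suffix logprobs drag the average up.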
The second set of results — which provides in-context few-shot examples — does rectify this with 5/5 correct answers. It seems that showing the model expected outputs leads to tokens beyond "organ", such as "omet", being sufficiently penalised to avoid the problem. It is surprising that the model requires this level of priming for such a simple task, however (even without the few-shot examples, the model is told the permitted options in the user prompt).
Other Observations
Prefixing the assistant response with "Answer: " doesn't help, but does result in prefill tokens that only correspond to the choices and nothing else (i.e. no ":" prior token, or similar). Why? The inconsistent presence/absence of prior tokens skews the scores and can lead to erratic selection behaviour when small tweaks are made to the prompt prefixes.
I tried running this example using regex instead (i.e. gen(regex="(" + "|".join(choices) + ")")), thinking this would resolve the issue with simple greedy token selection. But this also performs poorly (and extremely unpredictably, without temperature=0).
I've also explored avoiding overlapping options by wrapping each option in double quotes, but this doesn't solve the problem.
Suggestions
I think this is a severe enough flaw in the normalised logprobs calculation to be considered a bug. The outputs I've observed in several real-world settings are also unreasonably poor for simple tasks and capable models. I think evaluating all the options in their entirety is a good approach in theory, but a more sophisticated normalised logprobs calculation is required to adjust for bias towards options with more tokens.
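One possibility, sketched below with the same made-up numbers as before, is a sub-linear length penalty in the spirit of beam-search length normalisation rather than a plain mean. The `alpha` value and the penalty shape here are assumptions for illustration, not a proposed final design.

```python
def length_penalised_score(logprobs, alpha=0.7):
    # Divide the joint logprob by a sub-linear function of length, so
    # near-certain trailing tokens dilute a weak prefix less than a plain
    # average does. alpha=0 recovers the raw joint log-probability; larger
    # alpha moves the score toward a per-token average.
    penalty = ((5 + len(logprobs)) / 6) ** alpha
    return sum(logprobs) / penalty

# With the toy numbers from earlier, the single-token option now wins:
print(length_penalised_score([-1.5]))                      # "organ"
print(length_penalised_score([-3.0, -2.0, -0.01, -0.01]))  # "organometallic"
```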
Offering an alternative choices decoding mode based on greedy token selection could also help. That said, I'm not sure why I still get poor outputs when I attempt to simulate this via gen(regex=...).
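The greedy mode I have in mind could look roughly like this. This is a hypothetical sketch, not an SGLang API: `next_logprob(prefix, token)` stands in for a model call, and the toy logprob table in the usage example is invented.

```python
# Hypothetical greedy selection over a fixed choice set: at each step, keep
# only the options consistent with the tokens emitted so far and take the
# single most likely next token among them. A `None` candidate means "stop
# here" for an option that is already complete.
def greedy_choice(tokenised_options, next_logprob):
    prefix, live = [], list(tokenised_options)
    while len(live) > 1:
        candidates = {opt[len(prefix)] if len(opt) > len(prefix) else None
                      for opt in live}
        best = max(candidates, key=lambda t: next_logprob(prefix, t))
        if best is None:          # a completed option beats all continuations
            return prefix
        prefix.append(best)
        live = [o for o in live if o[:len(prefix)] == prefix]
    return live[0]

# Toy logprob table (invented numbers) mirroring the example above: "organ"
# is the likeliest first token, and stopping after it beats continuing.
fake = {
    ((), "organ"): -0.5, ((), "organisation"): -2.0, ((), "org"): -3.0,
    (("organ",), None): -0.2, (("organ",), "ism"): -1.5, (("organ",), "omet"): -4.0,
}
pick = greedy_choice(
    [["organ"], ["organ", "ism"], ["organisation"],
     ["org", "ane", "lle"], ["organ", "omet", "al", "lic"]],
    lambda prefix, tok: fake[(tuple(prefix), tok)],
)
print(pick)  # -> ['organ']
```

Because the per-step comparison never averages over future tokens, a near-certain suffix can't rescue a disfavoured prefix, which is exactly the failure mode above.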