sgl-project / sglang

SGLang is a structured generation language designed for large language models (LLMs). It makes your interaction with models faster and more controllable.

The `choices` normalised logprobs calculation returns poor results due to bias for longer-token options #523

AidanCooper commented 3 weeks ago

Problem

I've noticed that the gen(choices=[...]) functionality sometimes performs poorly, even for simple tasks. This is due to a flawed normalised logprobs calculation, which favours options that comprise more tokens whenever the later tokens are highly predictable given the earlier ones.

Reproducible Example

This is most easily seen with choices that have token overlap, so I've constructed a contrived example to illustrate it. The outputs are generated with Llama 3 8B, which should breeze through this task under normal circumstances.

import sglang as sgl
import textwrap

# Define answer choices with overlapping substrings and tokenised forms
# assumes llama 3 8B tokeniser
choices_and_tokenised_forms = [
    ("organ", ["organ"]),
    ("organism", ["organ", "ism"]),
    ("organisation", ["organisation"]),
    ("organelle", ["org", "ane", "lle"]),
    ("organometallic", ["organ", "omet", "al", "lic"]),
]
choices = [c for c, _ in choices_and_tokenised_forms]

# Define the categorisation question
template = "What category does '{input}' belong to? {choices}"

# Generate the (optional) system prompt with few-shot examples
sys_prompt = ""
for example in [
    ("ribosome", "organelle"),
    ("liver", "organ"),
    ("Google", "organisation"),
    ("ferrocene", "organometallic"),
    ("human", "organism"),
]:
    sys_prompt += "user:" + template.format(input=example[0], choices=choices)
    sys_prompt += f"\nassisant:{example[1]}\n\n"

@sgl.function
def run(s, input: str, show_few_shot_examples: bool = False):
    if show_few_shot_examples:
        s += sgl.system(f"You categorise things.\n\n ##Examples\n{sys_prompt}")
    s += sgl.user(template.format(input=input, choices=choices))
    # temperature is a sampling parameter, so it belongs in gen(), not str.format()
    s += sgl.assistant(sgl.gen("answer", choices=choices, temperature=0))

def format_results(state, input):
    # Show each option's normalised logprob alongside the raw per-token
    # prefill logprobs returned in the meta info.
    answer = f"  '{input}' categorised as: '{state['answer']}'"
    meta = state.get_meta_info("answer")
    out = f"{answer:<50}    normalised    prefill token logprobs"
    for i in range(len(meta['normalized_prompt_logprobs'])):
        option = f"{choices_and_tokenised_forms[i][0]} ({choices_and_tokenised_forms[i][1]})"
        npl = meta['normalized_prompt_logprobs'][i]
        ptl = [f"{p[0]:.4f}" for p in meta['prefill_token_logprobs'][i]]
        out += f"\n{option:<50} -> {npl:<10.4f} -> {ptl}"
    return out

sgl.set_default_backend(sgl.RuntimeEndpoint("http://localhost:30000"))

for include_examples in [False, True]:
    print(f"Show few-shot examples in context = {include_examples}\n")
    for input in ["heart", "nucleus", "Microsoft", "mouse", "trimethylboron"]:
        state = run(input, show_few_shot_examples=include_examples)
        print(textwrap.indent(format_results(state, input), "    "))
        print()
    print("-" * 120)

Outputs:

Show few-shot examples in context = False

      'heart' categorised as: 'organelle'                 normalised    prefill token logprobs
    organ (['organ'])                                  -> -1.6190    -> ['-0.1265', '-3.1116']
    organism (['organ', 'ism'])                        -> -1.7443    -> ['-0.1265', '-3.1116', '-1.9949']
    organisation (['organisation'])                    -> -3.8885    -> ['-0.1265', '-7.6506']
    organelle (['org', 'ane', 'lle'])                  -> -1.3777    -> ['-0.1265', '-5.3772', '-0.0048', '-0.0023']
    organometallic (['organ', 'omet', 'al', 'lic'])    -> -1.3915    -> ['-0.1265', '-3.1116', '-3.7136', '-0.0034', '-0.0023']

      'nucleus' categorised as: 'organometallic'          normalised    prefill token logprobs
    organ (['organ'])                                  -> -1.8324    -> ['-0.2145', '-3.4502']
    organism (['organ', 'ism'])                        -> -1.8675    -> ['-0.2145', '-3.4502', '-1.9378']
    organisation (['organisation'])                    -> -3.1800    -> ['-0.2145', '-6.1456']
    organelle (['org', 'ane', 'lle'])                  -> -1.1103    -> ['-0.2145', '-4.2237', '-0.0013', '-0.0017']
    organometallic (['organ', 'omet', 'al', 'lic'])    -> -1.0997    -> ['-0.2145', '-3.4502', '-1.8284', '-0.0029', '-0.0022']

      'Microsoft' categorised as: 'organometallic'        normalised    prefill token logprobs
    organ (['organ'])                                  -> -1.5901    -> ['-0.1446', '-3.0355']
    organism (['organ', 'ism'])                        -> -1.6397    -> ['-0.1446', '-3.0355', '-1.7391']
    organisation (['organisation'])                    -> -2.9416    -> ['-0.1446', '-5.7387']
    organelle (['org', 'ane', 'lle'])                  -> -1.4376    -> ['-0.1446', '-5.5746', '-0.0283', '-0.0029']
    organometallic (['organ', 'omet', 'al', 'lic'])    -> -1.1792    -> ['-0.1446', '-3.0355', '-2.7079', '-0.0052', '-0.0028']

      'mouse' categorised as: 'organelle'                 normalised    prefill token logprobs
    organ (['organ'])                                  -> -1.7110    -> ['-0.1392', '-3.2829']
    organism (['organ', 'ism'])                        -> -1.5566    -> ['-0.1392', '-3.2829', '-1.2477']
    organisation (['organisation'])                    -> -3.9181    -> ['-0.1392', '-7.6969']
    organelle (['org', 'ane', 'lle'])                  -> -1.3491    -> ['-0.1392', '-5.2516', '-0.0041', '-0.0015']
    organometallic (['organ', 'omet', 'al', 'lic'])    -> -1.4992    -> ['-0.1392', '-3.2829', '-4.0680', '-0.0033', '-0.0028']

      'trimethylboron' categorised as: 'organometallic'    normalised    prefill token logprobs
    organ (['organ'])                                  -> -1.4093    -> ['-0.1379', '-2.6806']
    organism (['organ', 'ism'])                        -> -2.7661    -> ['-0.1379', '-2.6806', '-5.4796']
    organisation (['organisation'])                    -> -3.9659    -> ['-0.1379', '-7.7939']
    organelle (['org', 'ane', 'lle'])                  -> -1.3317    -> ['-0.1379', '-5.1338', '-0.0527', '-0.0023']
    organometallic (['organ', 'omet', 'al', 'lic'])    -> -0.5933    -> ['-0.1379', '-2.6806', '-0.1436', '-0.0034', '-0.0008']

------------------------------------------------------------------------------------------------------------------------
Show few-shot examples in context = True

      'heart' categorised as: 'organ'                     normalised    prefill token logprobs
    organ (['organ'])                                  -> -0.2509    -> ['-0.0799', '-0.4219']
    organism (['organ', 'ism'])                        -> -2.0750    -> ['-0.0799', '-0.4219', '-5.7232']
    organisation (['organisation'])                    -> -3.7431    -> ['-0.0799', '-7.4063']
    organelle (['org', 'ane', 'lle'])                  -> -0.9032    -> ['-0.0799', '-3.5000', '-0.0298', '-0.0031']
    organometallic (['organ', 'omet', 'al', 'lic'])    -> -1.7599    -> ['-0.0799', '-0.4219', '-8.2857', '-0.0087', '-0.0034']

      'nucleus' categorised as: 'organelle'               normalised    prefill token logprobs
    organ (['organ'])                                  -> -1.7653    -> ['-0.1489', '-3.3817']
    organism (['organ', 'ism'])                        -> -1.8995    -> ['-0.1489', '-3.3817', '-2.1678']
    organisation (['organisation'])                    -> -3.7379    -> ['-0.1489', '-7.3270']
    organelle (['org', 'ane', 'lle'])                  -> -0.0921    -> ['-0.1489', '-0.2176', '-0.0006', '-0.0011']
    organometallic (['organ', 'omet', 'al', 'lic'])    -> -1.9658    -> ['-0.1489', '-3.3817', '-6.2928', '-0.0040', '-0.0017']

      'Microsoft' categorised as: 'organisation'          normalised    prefill token logprobs
    organ (['organ'])                                  -> -0.8883    -> ['-0.1198', '-1.6569']
    organism (['organ', 'ism'])                        -> -1.1325    -> ['-0.1198', '-1.6569', '-1.6208']
    organisation (['organisation'])                    -> -0.6383    -> ['-0.1198', '-1.1569']
    organelle (['org', 'ane', 'lle'])                  -> -1.2105    -> ['-0.1198', '-4.5866', '-0.1336', '-0.0021']
    organometallic (['organ', 'omet', 'al', 'lic'])    -> -0.7088    -> ['-0.1198', '-1.6569', '-1.7615', '-0.0043', '-0.0017']

      'mouse' categorised as: 'organism'                  normalised    prefill token logprobs
    organ (['organ'])                                  -> -0.1719    -> ['-0.1273', '-0.2166']
    organism (['organ', 'ism'])                        -> -0.1188    -> ['-0.1273', '-0.2166', '-0.0127']
    organisation (['organisation'])                    -> -2.9610    -> ['-0.1273', '-5.7947']
    organelle (['org', 'ane', 'lle'])                  -> -1.0844    -> ['-0.1273', '-3.9744', '-0.2330', '-0.0030']
    organometallic (['organ', 'omet', 'al', 'lic'])    -> -2.2812    -> ['-0.1273', '-0.2166', '-11.0517', '-0.0086', '-0.0020']

      'trimethylboron' categorised as: 'organometallic'    normalised    prefill token logprobs
    organ (['organ'])                                  -> -0.3231    -> ['-0.0992', '-0.5471']
    organism (['organ', 'ism'])                        -> -3.2023    -> ['-0.0992', '-0.5471', '-8.9607']
    organisation (['organisation'])                    -> -3.1551    -> ['-0.0992', '-6.2111']
    organelle (['org', 'ane', 'lle'])                  -> -0.7889    -> ['-0.0992', '-2.9299', '-0.1246', '-0.0018']
    organometallic (['organ', 'omet', 'al', 'lic'])    -> -0.1314    -> ['-0.0992', '-0.5471', '-0.0076', '-0.0025', '-0.0007']

------------------------------------------------------------------------------------------------------------------------

The second set of results yields the expected categorisations.

Explanation

We see that only 1/5 answers are correct in the first set of results. Not coincidentally, the only correctly answered question ('trimethylboron' categorised as: 'organometallic') is the one where the correct answer has the most tokens.

The first prefill token, common across all options, is ":". I'm not actually sure why this is present; perhaps it's something to do with token healing, or is it coming from the assistant: role prefix? Regardless, it isn't important here, as it's consistent across all options and isn't responsible for the poor performance (although it does skew the logprob calculations in unpredictable ways).

Inspecting the prefill token logprobs for the "organometallic" responses is instructive. Even though the ["organ", "omet"] tokens are relatively disfavoured, the ["al", "lic"] tokens are essentially guaranteed once the "organomet-" prefix is in place. The normalised logprobs calculation is a simple average of the prefill token logprobs, so these near-certain ["al", "lic"] tokens massively inflate the score, even when "organometallic" is obviously wrong given the prior context.
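
To make the averaging concrete, here is a quick arithmetic check against the 'heart' numbers reported above (plain Python, nothing SGLang-specific; the token logprobs are copied verbatim from the first output block):

from statistics import mean

# prefill token logprobs for 'heart' (no few-shot examples), copied from above
organ = [-0.1265, -3.1116]
organometallic = [-0.1265, -3.1116, -3.7136, -0.0034, -0.0023]

print(f"{mean(organ):.4f}")           # -1.6190  (matches the 'organ' row)
print(f"{mean(organometallic):.4f}")  # -1.3915  (matches the 'organometallic' row)

# The near-certain trailing tokens drag the average up, so 'organometallic'
# outscores 'organ' even though its total logprob is far lower
# (sum of -6.9574 vs -3.2381).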

The second set of results, which provides in-context few-shot examples, does rectify this, with 5/5 correct answers. It seems that showing the model the expected outputs leads to tokens beyond "organ", such as "omet", being sufficiently penalised to avoid the problem. It is surprising that the model requires this level of priming for such a simple task, however (even without the few-shot examples, the model is told the permitted options in the user prompt).

Other Observations

Suggestions