Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

The choices normalised logprobs calculation returns poor results due to bias for longer-token options #523

Open
AidanCooper opened this issue Jun 10, 2024 · 4 comments

Comments

@AidanCooper
Copy link

AidanCooper commented Jun 10, 2024

Problem

I've noticed that the gen(choices=[...]) functionality sometimes performs poorly, even for simple tasks. This is due to a flawed normalised logprobs calculation. The calculation biases options that comprise more tokens, where the latter tokens are highly predictable given the prior tokens.

Reproducible Example

This is most easily seen in choices with token overlap, so I've constructed a contrived example that illustrates this. The outputs are generated with llama 3 8B instruct, which should breeze through this task under normal circumstances.

import sglang as sgl
import textwrap

# Answer choices that share the "org"/"organ" prefix, each paired with the
# token sequence it maps to (assumes the llama 3 8B tokeniser).
choices_and_tokenised_forms = [
    ("organ", ["organ"]),
    ("organism", ["organ", "ism"]),
    ("organisation", ["organisation"]),
    ("organelle", ["org", "ane", "lle"]),
    ("organometallic", ["organ", "omet", "al", "lic"]),
]
choices = [choice for choice, _tokens in choices_and_tokenised_forms]


# Define the categorisation question
template = "What category does '{input}' belong to? {choices}"

# Few-shot (input, expected category) pairs: one worked example per choice.
few_shot_examples = [
    ("ribosome", "organelle"),
    ("liver", "organ"),
    ("Google", "organisation"),
    ("ferrocene", "organometallic"),
    ("human", "organism"),
]

# Generate the (optional) system prompt with few-shot examples. Each example is
# rendered as a "user:" question followed by the "assistant:" answer.
sys_prompt = "".join(
    "user:"
    + template.format(input=example_input, choices=choices)
    + f"\nassistant:{expected}\n\n"
    for example_input, expected in few_shot_examples
)


@sgl.function
def run(s, input: str, show_few_shot_examples: bool = False):
    """Build the categorisation prompt for `input` and pick an answer from `choices`.

    Args:
        s: Prompt state; each prompt segment is appended to it with `+=`.
        input: The thing to categorise; substituted into `template`.
        show_few_shot_examples: When True, a system message containing
            `sys_prompt` is added before the user question.
    """
    if show_few_shot_examples:
        s += sgl.system(f"You categorise things.\n\n ##Examples\n{sys_prompt}")
    # Only the placeholders used by `template` are passed: `str.format`
    # ignores unused keyword arguments, so sampling options do nothing here.
    s += sgl.user(template.format(input=input, choices=choices))
    # The selected option is stored under the name "answer".
    s += sgl.assistant(sgl.gen("answer", choices=choices))


def format_results(state, input):
    """Render the chosen answer plus per-option logprob details as a text table.

    Args:
        state: Result of calling `run`; read via `state['answer']` and
            `state.get_meta_info("answer")`.
        input: The query that was categorised (used in the header line).

    Returns:
        A multi-line string: a header line, then one line per entry of
        `meta['normalized_prompt_logprobs']`, labelled with the option at the
        same index in `choices_and_tokenised_forms`.
    """
    answer = f"  '{input}' categorised as: '{state['answer']}'"
    meta = state.get_meta_info("answer")
    lines = [f"{answer:<50}    normalised    prefill token logprobs"]
    for i, npl in enumerate(meta['normalized_prompt_logprobs']):
        choice, tokens = choices_and_tokenised_forms[i]
        option = f"{choice} ({tokens})"
        # The first element of each prefill entry is formatted as the logprob
        # (presumably the rest identifies the token -- TODO confirm).
        ptl = [f"{p[0]:.4f}" for p in meta['prefill_token_logprobs'][i]]
        lines.append(f"{option:<50} -> {npl:<10.4f} -> {ptl}")
    return "\n".join(lines)


# Point SGLang at the server listening on localhost port 30000.
sgl.set_default_backend(sgl.RuntimeEndpoint("http://localhost:30000"))

# Run every query twice: first without, then with the few-shot system prompt.
for include_examples in [False, True]:
    print(f"Show few-shot examples in context = {include_examples}\n")
    for query in ["heart", "nucleus", "Microsoft", "mouse", "trimethylboron"]:
        state = run(query, show_few_shot_examples=include_examples)
        print(textwrap.indent(format_results(state, query), "    "))
        print()
    print("-" * 120)

Outputs:

Show few-shot examples in context = False

      'heart' categorised as: 'organelle'                 normalised    prefill token logprobs
    organ (['organ'])                                  -> -1.6190    -> ['-0.1265', '-3.1116']
    organism (['organ', 'ism'])                        -> -1.7443    -> ['-0.1265', '-3.1116', '-1.9949']
    organisation (['organisation'])                    -> -3.8885    -> ['-0.1265', '-7.6506']
    organelle (['org', 'ane', 'lle'])                  -> -1.3777    -> ['-0.1265', '-5.3772', '-0.0048', '-0.0023']
    organometallic (['organ', 'omet', 'al', 'lic'])    -> -1.3915    -> ['-0.1265', '-3.1116', '-3.7136', '-0.0034', '-0.0023']

      'nucleus' categorised as: 'organometallic'          normalised    prefill token logprobs
    organ (['organ'])                                  -> -1.8324    -> ['-0.2145', '-3.4502']
    organism (['organ', 'ism'])                        -> -1.8675    -> ['-0.2145', '-3.4502', '-1.9378']
    organisation (['organisation'])                    -> -3.1800    -> ['-0.2145', '-6.1456']
    organelle (['org', 'ane', 'lle'])                  -> -1.1103    -> ['-0.2145', '-4.2237', '-0.0013', '-0.0017']
    organometallic (['organ', 'omet', 'al', 'lic'])    -> -1.0997    -> ['-0.2145', '-3.4502', '-1.8284', '-0.0029', '-0.0022']

      'Microsoft' categorised as: 'organometallic'        normalised    prefill token logprobs
    organ (['organ'])                                  -> -1.5901    -> ['-0.1446', '-3.0355']
    organism (['organ', 'ism'])                        -> -1.6397    -> ['-0.1446', '-3.0355', '-1.7391']
    organisation (['organisation'])                    -> -2.9416    -> ['-0.1446', '-5.7387']
    organelle (['org', 'ane', 'lle'])                  -> -1.4376    -> ['-0.1446', '-5.5746', '-0.0283', '-0.0029']
    organometallic (['organ', 'omet', 'al', 'lic'])    -> -1.1792    -> ['-0.1446', '-3.0355', '-2.7079', '-0.0052', '-0.0028']

      'mouse' categorised as: 'organelle'                 normalised    prefill token logprobs
    organ (['organ'])                                  -> -1.7110    -> ['-0.1392', '-3.2829']
    organism (['organ', 'ism'])                        -> -1.5566    -> ['-0.1392', '-3.2829', '-1.2477']
    organisation (['organisation'])                    -> -3.9181    -> ['-0.1392', '-7.6969']
    organelle (['org', 'ane', 'lle'])                  -> -1.3491    -> ['-0.1392', '-5.2516', '-0.0041', '-0.0015']
    organometallic (['organ', 'omet', 'al', 'lic'])    -> -1.4992    -> ['-0.1392', '-3.2829', '-4.0680', '-0.0033', '-0.0028']

      'trimethylboron' categorised as: 'organometallic'    normalised    prefill token logprobs
    organ (['organ'])                                  -> -1.4093    -> ['-0.1379', '-2.6806']
    organism (['organ', 'ism'])                        -> -2.7661    -> ['-0.1379', '-2.6806', '-5.4796']
    organisation (['organisation'])                    -> -3.9659    -> ['-0.1379', '-7.7939']
    organelle (['org', 'ane', 'lle'])                  -> -1.3317    -> ['-0.1379', '-5.1338', '-0.0527', '-0.0023']
    organometallic (['organ', 'omet', 'al', 'lic'])    -> -0.5933    -> ['-0.1379', '-2.6806', '-0.1436', '-0.0034', '-0.0008']

------------------------------------------------------------------------------------------------------------------------
Show few-shot examples in context = True

      'heart' categorised as: 'organ'                     normalised    prefill token logprobs
    organ (['organ'])                                  -> -0.2509    -> ['-0.0799', '-0.4219']
    organism (['organ', 'ism'])                        -> -2.0750    -> ['-0.0799', '-0.4219', '-5.7232']
    organisation (['organisation'])                    -> -3.7431    -> ['-0.0799', '-7.4063']
    organelle (['org', 'ane', 'lle'])                  -> -0.9032    -> ['-0.0799', '-3.5000', '-0.0298', '-0.0031']
    organometallic (['organ', 'omet', 'al', 'lic'])    -> -1.7599    -> ['-0.0799', '-0.4219', '-8.2857', '-0.0087', '-0.0034']

      'nucleus' categorised as: 'organelle'               normalised    prefill token logprobs
    organ (['organ'])                                  -> -1.7653    -> ['-0.1489', '-3.3817']
    organism (['organ', 'ism'])                        -> -1.8995    -> ['-0.1489', '-3.3817', '-2.1678']
    organisation (['organisation'])                    -> -3.7379    -> ['-0.1489', '-7.3270']
    organelle (['org', 'ane', 'lle'])                  -> -0.0921    -> ['-0.1489', '-0.2176', '-0.0006', '-0.0011']
    organometallic (['organ', 'omet', 'al', 'lic'])    -> -1.9658    -> ['-0.1489', '-3.3817', '-6.2928', '-0.0040', '-0.0017']

      'Microsoft' categorised as: 'organisation'          normalised    prefill token logprobs
    organ (['organ'])                                  -> -0.8883    -> ['-0.1198', '-1.6569']
    organism (['organ', 'ism'])                        -> -1.1325    -> ['-0.1198', '-1.6569', '-1.6208']
    organisation (['organisation'])                    -> -0.6383    -> ['-0.1198', '-1.1569']
    organelle (['org', 'ane', 'lle'])                  -> -1.2105    -> ['-0.1198', '-4.5866', '-0.1336', '-0.0021']
    organometallic (['organ', 'omet', 'al', 'lic'])    -> -0.7088    -> ['-0.1198', '-1.6569', '-1.7615', '-0.0043', '-0.0017']

      'mouse' categorised as: 'organism'                  normalised    prefill token logprobs
    organ (['organ'])                                  -> -0.1719    -> ['-0.1273', '-0.2166']
    organism (['organ', 'ism'])                        -> -0.1188    -> ['-0.1273', '-0.2166', '-0.0127']
    organisation (['organisation'])                    -> -2.9610    -> ['-0.1273', '-5.7947']
    organelle (['org', 'ane', 'lle'])                  -> -1.0844    -> ['-0.1273', '-3.9744', '-0.2330', '-0.0030']
    organometallic (['organ', 'omet', 'al', 'lic'])    -> -2.2812    -> ['-0.1273', '-0.2166', '-11.0517', '-0.0086', '-0.0020']

      'trimethylboron' categorised as: 'organometallic'    normalised    prefill token logprobs
    organ (['organ'])                                  -> -0.3231    -> ['-0.0992', '-0.5471']
    organism (['organ', 'ism'])                        -> -3.2023    -> ['-0.0992', '-0.5471', '-8.9607']
    organisation (['organisation'])                    -> -3.1551    -> ['-0.0992', '-6.2111']
    organelle (['org', 'ane', 'lle'])                  -> -0.7889    -> ['-0.0992', '-2.9299', '-0.1246', '-0.0018']
    organometallic (['organ', 'omet', 'al', 'lic'])    -> -0.1314    -> ['-0.0992', '-0.5471', '-0.0076', '-0.0025', '-0.0007']

------------------------------------------------------------------------------------------------------------------------

The second set of results yields the expected categorisations.

Explanation

We see that only 1/5 answers are correct in the first set of results. Not coincidentally, the only correctly answered question ('trimethylboron' categorised as: 'organometallic') is the one where the correct answer has the most tokens.

The first prefill token, common across all options, is ":". I'm not actually sure why this is present — something to do with token healing? Is this coming from the assistant: role prefix? Regardless, it's not important as it's consistent across all options, and it's not responsible for the poor performance (although it does skew the logprobs calculations in unpredictable ways).

Inspecting the prefill token logprobs for the "organometallic" responses is instructive. Even if the ["organ", "omet"] tokens are relatively disfavoured, the ["al", "lic"] tokens are essentially guaranteed once you have the "organomet-" substring. The normalised logprobs calculation is a simple average of the prefill token logprobs, which means the ["al", "lic"] tokens massively inflate the score, even if "organometallic" is obviously wrong given the prior context.

The second set of results — which provides in-context few-shot examples — does rectify this with 5/5 correct answers. It seems that showing the model expected outputs leads to tokens beyond "organ", such as "omet" , being sufficiently penalised to avoid the problem. It is surprising that the model requires this level of priming for such a simple task, however (even without the few-shot examples, the model is told the permitted options in the user prompt).

Other Observations

  • Prefixing the assistant response with "Answer: " doesn't help, but does result in prefill tokens that only correspond to the choices and nothing else (i.e. no ":" prior token, or similar). Why? The inconsistent presence/absence of prior tokens skews the scores and can lead to erratic selection behaviour when small tweaks are made to the prompt prefixes.
  • I tried running this example using regex instead (i.e. gen(regex="(" + "|".join(choices) + ")")), thinking this would resolve the issue with simple greedy token selection. But this also performs poorly (and extremely unpredictably, without temperature=0).
  • I've also explored avoiding overlapping options by wrapping each option in double quotes, but this doesn't solve the problem.

Suggestions

  • I think this is a severe enough flaw in the normalised logprobs calculation to be considered a bug. The outputs I've observed in several real-world settings are also unreasonably poor for simple tasks and capable models. I think evaluating all the options in their entirety is a good approach in theory, but a more sophisticated normalised logprobs calculation is required to adjust for bias towards options with more tokens.
  • Offering an alternative, greedy token selection choices decoding option could help. That said, I'm not sure why I still get poor outputs when I attempt to simulate this via gen(regex=...).
@AidanCooper
Copy link
Author

Tagging @merrymercy based on git history of RuntimeEndpoint.select. Do you have any thoughts on this?

@merrymercy
Copy link
Contributor

merrymercy commented Jul 9, 2024

Hi @AidanCooper Thanks for the detailed explanation and experiments.

I noticed that you used a base model together with chat template-related primitives (e.g., sgl.user). This creates multiple issues which can be the reason for the poor performance.

  1. The base model is not fine-tuned to follow instructions, so you need few-shot examples to make it work.
  2. sgl.user and sgl.assistant will apply chat templates to the prompts. Because you are using llama-3, it gets translated to things like <|start_header_id|>user<|end_header_id|>. These are special tokens not seen during the training of the base model, so they will lead to strange behavior. Also, in your system prompt, you used another custom chat template style user: ... assistant:.... All these mismatches create issues.
  3. I would suggest you use the instruct-tuned model and correctly format your few-shot example. You can print s.text() to inspect the real prompt. Then redo all of your experiments
  4. Effects like token healing do affect the logits a lot. You can play with the tokenization and white space.

The two points in your suggestions totally make sense.

  1. Do you know any better normalized logprob calculation method?
  2. Yes. regex with temperature=0 can simulate the greedy sampling. Could you describe what is wrong with the regex? cc @hnyls2002

@AidanCooper
Copy link
Author

I noticed that you used a base model together with chat template-related primitives

Nice catch, although that's actually a mistake in my original post — I am using the instruction-tuned variant. I've updated my post with the corrected link.

I would suggest you ... correctly format your few-shot example

Also a fair point, although I don't think this is proving to be an issue in practice. My examples demonstrate that the model returns incorrect answers without the few-shot examples, but correct answers with the few-shot examples (even if the formatting isn't optimal).

  1. Do you know any better normalized logprob calculation method?

No, not that I'm aware of. Is this approach to choices unique to SGLang, or is it common across constrained decoding libraries? Whilst I like the idea of evaluating options in their entirety in principle, I think it might be fundamentally flawed for the reasons demonstrated by my examples, and perhaps greedy sampling is more reliable.

I might play around with this and see if there's a better algorithm for normalising the logprobs — perhaps a heuristic that upweights earlier tokens relative to later tokens in the average calculation. But ultimately, I can't see a straightforward way to distinguish options whose tokens are favoured because of the tokens that precede them (e.g., 'organ', 'omet', 'al', 'lic') from options whose tokens are favoured because they actually represent the best response.

@AidanCooper
Copy link
Author

Yes. regex with temperature=0 can simulate the greedy sampling. Could you describe what is wrong with the regex? cc @hnyls2002

Looking into this again, the results when using regex are better than I originally suggested (not sure why). Reworking the example to use regex:

import sglang as sgl
import textwrap

# Answer choices that share the "org"/"organ" prefix, each paired with the
# token sequence it maps to (assumes the llama 3 8B tokeniser).
choices_and_tokenised_forms = [
    ("organ", ["organ"]),
    ("organism", ["organ", "ism"]),
    ("organisation", ["organisation"]),
    ("organelle", ["org", "ane", "lle"]),
    ("organometallic", ["organ", "omet", "al", "lic"]),
]
choices = [choice for choice, _tokens in choices_and_tokenised_forms]


# Define the categorisation question
template = "What category does '{input}' belong to? {choices}"

# Few-shot (input, expected category) pairs: one worked example per choice.
few_shot_examples = [
    ("ribosome", "organelle"),
    ("liver", "organ"),
    ("Google", "organisation"),
    ("ferrocene", "organometallic"),
    ("human", "organism"),
]

# Generate the (optional) system prompt with few-shot examples. Each example is
# rendered as a "user:" question followed by the "assistant:" answer.
sys_prompt = "".join(
    "user:"
    + template.format(input=example_input, choices=choices)
    + f"\nassistant:{expected}\n\n"
    for example_input, expected in few_shot_examples
)


@sgl.function
def run(s, input: str, show_few_shot_examples: bool = False):
    """Build the categorisation prompt for `input` and generate a regex-constrained answer.

    Args:
        s: Prompt state; each prompt segment is appended to it with `+=`.
        input: The thing to categorise; substituted into `template`.
        show_few_shot_examples: When True, a system message containing
            `sys_prompt` is added before the user question.
    """
    if show_few_shot_examples:
        s += sgl.system(f"You categorise things.\n\n ##Examples\n{sys_prompt}")
    # Only the placeholders used by `template` are passed: `str.format`
    # ignores unused keyword arguments, so `temperature` belongs on `sgl.gen`.
    s += sgl.user(template.format(input=input, choices=choices))
    # Constrain the output to an alternation of the literal choices.
    answer_pattern = "(" + "|".join(choices) + ")"
    s += sgl.assistant(sgl.gen("answer", regex=answer_pattern, temperature=0))


def format_results(state, input):
    """Return a one-line summary of which category `input` was assigned.

    `state` only needs to support `state['answer']`.
    """
    chosen = state["answer"]
    return f"  '{input}' categorised as: '{chosen}'"


# Point SGLang at the server listening on localhost port 30000.
sgl.set_default_backend(sgl.RuntimeEndpoint("http://localhost:30000"))

# Run every query twice: first without, then with the few-shot system prompt.
for include_examples in [False, True]:
    print(f"Show few-shot examples in context = {include_examples}")
    for query in ["heart", "nucleus", "Microsoft", "mouse", "trimethylboron"]:
        state = run(query, show_few_shot_examples=include_examples)
        print(textwrap.indent(format_results(state, query), "    "))
    print("-" * 120)

I get:

Show few-shot examples in context = False
      'heart' categorised as: 'organ'
      'nucleus' categorised as: 'organometallic'
      'Microsoft' categorised as: 'organisation'
      'mouse' categorised as: 'organism'
      'trimethylboron' categorised as: 'organometallic'
------------------------------------------------------------------------------------------------------------------------
Show few-shot examples in context = True
      'heart' categorised as: 'organ'
      'nucleus' categorised as: 'organelle'
      'Microsoft' categorised as: 'organisation'
      'mouse' categorised as: 'organism'
      'trimethylboron' categorised as: 'organometallic'
------------------------------------------------------------------------------------------------------------------------

Without few-shot examples, only one answer is incorrect ('nucleus' categorised as: 'organometallic'). With few-shot examples, all answers are correct. So this is consistent with what we'd expect: greedy sampling via regex yields better results than choices for this example.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

2 participants