The `choices` normalised logprobs calculation returns poor results due to bias for longer-token options #523
Comments
Tagging @merrymercy based on the git history of …
Hi @AidanCooper, thanks for the detailed explanation and experiments. I noticed that you used a base model together with chat template-related primitives (e.g., …).
The two points in your suggestions totally make sense.
Nice catch, although that's actually a mistake in my original post — I am using the instruction-tuned variant. I've updated my post with the corrected link.
Also a fair point, although I don't think this is proving to be an issue in practice. My examples demonstrate that the model returns incorrect answers without the few-shot examples, but correct answers with the few-shot examples (even if the formatting isn't optimal).
No, not that I'm aware of. I might play around with this and see if there's a better algorithm for normalising the logprobs — perhaps a heuristic that upweights earlier tokens relative to later tokens in the average calculation. But ultimately, I can't see a straightforward way to distinguish options whose tokens are favoured because of the tokens that precede them (e.g., 'organ', 'omet', 'al', 'lic') from options whose tokens are favoured because they actually represent the best response.
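For illustration, a minimal sketch of such a heuristic — geometrically decaying position weights instead of a uniform average, so earlier tokens dominate. The decay scheme and the logprob values are my own invention, not sglang's implementation:

```python
def weighted_normalised_logprob(token_logprobs: list[float], decay: float = 0.5) -> float:
    """Average token logprobs with geometrically decaying position weights,
    so earlier (less 'forced') tokens dominate the score. Illustrative only."""
    weights = [decay**i for i in range(len(token_logprobs))]
    return sum(w * lp for w, lp in zip(weights, token_logprobs)) / sum(weights)


# Toy logprobs: the trailing tokens of "organometallic" are near-certain
# given the "organomet-" prefix, so a plain mean favours it over "organ".
print(weighted_normalised_logprob([-2.0]))                      # "organ": -2.0
print(weighted_normalised_logprob([-4.0, -3.5, -0.01, -0.01]))  # "organometallic": -3.07
# Under a plain mean, the second option scores -1.88 and beats "organ";
# with decaying weights it scores -3.07 and no longer does.
```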
Looking into this again, the results when using regex are better than I originally suggested (not sure why). Reworking the example to use regex:

```python
import sglang as sgl
import textwrap

# Define answer choices with overlapping substrings and tokenised forms
# (assumes the llama 3 8B tokeniser)
choices_and_tokenised_forms = [
    ("organ", ["organ"]),
    ("organism", ["organ", "ism"]),
    ("organisation", ["organisation"]),
    ("organelle", ["org", "ane", "lle"]),
    ("organometallic", ["organ", "omet", "al", "lic"]),
]
choices = [c for c, _ in choices_and_tokenised_forms]

# Define the categorisation question
template = "What category does '{input}' belong to? {choices}"

# Generate the (optional) system prompt with few-shot examples
sys_prompt = ""
for example in [
    ("ribosome", "organelle"),
    ("liver", "organ"),
    ("Google", "organisation"),
    ("ferrocene", "organometallic"),
    ("human", "organism"),
]:
    sys_prompt += "user:" + template.format(input=example[0], choices=choices)
    sys_prompt += f"\nassistant:{example[1]}\n\n"


@sgl.function
def run(s, input: str, show_few_shot_examples: bool = False):
    if show_few_shot_examples:
        s += sgl.system(f"You categorise things.\n\n ##Examples\n{sys_prompt}")
    s += sgl.user(template.format(input=input, choices=choices))
    s += sgl.assistant(sgl.gen("answer", regex="(" + "|".join(choices) + ")", temperature=0))


def format_results(state, input):
    return f" '{input}' categorised as: '{state['answer']}'"


sgl.set_default_backend(sgl.RuntimeEndpoint("http://localhost:30000"))

for include_examples in [False, True]:
    print(f"Show few-shot examples in context = {include_examples}")
    for input in ["heart", "nucleus", "Microsoft", "mouse", "trimethylboron"]:
        state = run(input, show_few_shot_examples=include_examples)
        print(textwrap.indent(format_results(state, input), " "))
    print("-" * 120)
```

I get:
Without few-shot examples, only one answer is incorrect (…).
Problem
I've noticed that the `gen(choices=[...])` functionality sometimes performs poorly, even for simple tasks. This is due to a flawed normalised logprobs calculation: it biases towards options that comprise more tokens, where the latter tokens are highly predictable given the prior tokens.
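For reference, the score under discussion is — as I understand the implementation — the mean of the option's prefill token logprobs:

$$\text{score}(\text{option}) = \frac{1}{n} \sum_{i=1}^{n} \log P(t_i \mid \text{prompt}, t_1, \ldots, t_{i-1})$$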
Reproducible Example
This is most easily seen in choices with token overlap, so I've constructed a contrived example that illustrates this (see the sketch below). The outputs are generated with llama 3 8B instruct, which should breeze through this task under normal circumstances.
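(The snippet itself didn't survive this copy; a minimal sketch of the relevant part, assuming the same scaffold — `template`, `choices`, `sys_prompt`, backend — as the regex rework in the comment above, with `choices=` in place of `regex=`:)

```python
import sglang as sgl

# As the regex version above, but selecting among the options via the
# normalised-logprobs `choices` mechanism that this issue is about.
@sgl.function
def run(s, input: str, show_few_shot_examples: bool = False):
    if show_few_shot_examples:
        s += sgl.system(f"You categorise things.\n\n ##Examples\n{sys_prompt}")
    s += sgl.user(template.format(input=input, choices=choices))
    s += sgl.assistant(sgl.gen("answer", choices=choices))
```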
Outputs:
The second set of results yields the expected categorisations.
Explanation
We see that only 1/5 answers are correct in the first set of results. Not coincidentally, the only correctly answered question (`'trimethylboron' categorised as: 'organometallic'`) is the one where the correct answer has the most tokens.

The first prefill token, common across all options, is `":"`. I'm not actually sure why this is present — something to do with token healing? Is this coming from the `assistant:` role prefix? Regardless, it's not important as it's consistent across all options, and it's not responsible for the poor performance (although it does skew the logprobs calculations in unpredictable ways).

Inspecting the prefill token logprobs for the "organometallic" responses is instructive. Even if the `["organ", "omet"]` tokens are relatively disfavoured, the `["al", "lic"]` tokens are essentially guaranteed once you have the `"organomet-"` substring. The normalised logprobs calculation is a simple average of the prefill token logprobs, which means the `["al", "lic"]` tokens massively inflate the score, even if "organometallic" is obviously wrong given the prior context.

The second set of results — which provides in-context few-shot examples — does rectify this with 5/5 correct answers. It seems that showing the model expected outputs leads to tokens beyond `"organ"`, such as `"omet"`, being sufficiently penalised to avoid the problem. It is surprising that the model requires this level of priming for such a simple task, however (even without the few-shot examples, the model is told the permitted options in the user prompt).
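To make the averaging effect concrete, a toy calculation with made-up logprobs (illustrative numbers, not measured values):

```python
# Hypothetical prefill logprobs for two options given the same prompt.
# "organ" is contextually better, but its single token carries the full
# penalty; "organometallic"'s near-forced suffix tokens dilute its penalty.
logprobs = {
    "organ": [-2.0],
    "organometallic": [-4.0, -3.5, -0.01, -0.01],  # "al", "lic" near-certain
}
for option, lps in logprobs.items():
    print(f"{option}: {sum(lps) / len(lps):.2f}")
# organ: -2.00
# organometallic: -1.88  <- wins under a simple mean despite worse early tokens
```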
"Answer: "
doesn't help, but does result in prefill tokens that only correspond to the choices and nothing else (i.e. no":"
prior token, or similar). Why? The inconsistent presence/absence of prior tokens skews the scores and can lead to erratic selection behaviour when small tweaks are made to the prompt prefixes.gen(regex="(" + "|".join(choices) + ")")
), thinking this would resolve the issue with simple greedy token selection. But this also performs poorly (and extremely unpredictably, withouttemperature=0
).Suggestions
- An improved normalised logprobs calculation that doesn't weight all token positions uniformly could help.
- A greedy, token-by-token decoding implementation of the `choices` decoding option could help. That said, I'm not sure why I still get poor outputs when I attempt to simulate this via `gen(regex=...)` (see the sketch below).
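A rough sketch of the greedy idea, for concreteness: at each step, every still-viable option proposes its next token (or end-of-sequence if it is fully consumed), the most probable proposal wins, and options that disagree are eliminated. The `next_token_logprobs` callable and all names here are hypothetical, not sglang's API:

```python
def greedy_choice(prompt_tokens: list[int], options: list[list[int]],
                  next_token_logprobs, eos_id: int) -> int:
    """Return the index of the option picked by greedy token-by-token decoding.

    `next_token_logprobs(tokens)` is a hypothetical callable returning a
    dict mapping token id -> logprob for the next position.
    """
    alive = list(range(len(options)))
    prefix: list[int] = []
    while len(alive) > 1:
        depth = len(prefix)
        # Each viable option proposes its next token; exhausted options propose EOS.
        proposal = {i: (options[i][depth] if len(options[i]) > depth else eos_id)
                    for i in alive}
        lps = next_token_logprobs(prompt_tokens + prefix)
        best = max(set(proposal.values()), key=lambda t: lps.get(t, float("-inf")))
        if best == eos_id:
            # A fully-consumed option won this step (e.g. "organ" over "organism").
            return next(i for i in alive if len(options[i]) == depth)
        prefix.append(best)
        alive = [i for i in alive if proposal[i] == best]
    return alive[0]
```

This never averages over forced continuation tokens: once the model prefers ending at `"organ"` over continuing to `"omet"`, the longer options are eliminated regardless of how predictable their suffixes are.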