How to compute the perplexity only on the answer? #1370

Open
Luobots opened this issue Jan 30, 2024 · 7 comments
Labels: asking questions (For asking for clarification / support on library usage.)

Comments


Luobots commented Jan 30, 2024

I am trying to calculate the perplexity on minerva_math. Here is my task YAML config:

group:
  - math_word_problems_ppl
task: minerva_math_algebra_ppl
dataset_path: EleutherAI/hendrycks_math
process_docs: !function utils.process_docs
dataset_name: algebra
output_type: loglikelihood_rolling
training_split: train
test_split: test
doc_to_text:  !function utils.doc_to_text
doc_to_target: "{{solution}}"
process_results: !function utils.process_results
should_decontaminate: true
doc_to_decontamination_query: "{{problem}}"
metric_list:
  - metric: word_perplexity
  - metric: byte_perplexity
  - metric: bits_per_byte
metadata:
  version: 1.0

The doc_to_text and process_docs functions are the same as in minerva_math.
The process_results function is like wikitext's:

import re


def process_results(doc, results):
    (loglikelihood,) = results
    _words = len(re.split(r"\s+", doc["solution"]))
    _bytes = len(doc["solution"].encode("utf-8"))
    return {
        "word_perplexity": (loglikelihood, _words),
        "byte_perplexity": (loglikelihood, _bytes),
        "bits_per_byte": (loglikelihood, _bytes),
    }

My script is:

lm_eval --model vllm \
    --model_args pretrained=${MODEL},dtype=auto \
    --tasks math_word_problems_ppl \
    --batch_size auto \
    --log_samples \
    --output_path ${OUT}

When I dig into vllm_causallms.py, the inp in the _loglikelihood_tokens function only contains the tokenized target, and the context_enc is [2]. If I am not mistaken, the model never sees the doc_to_text and only sees the doc_to_target?

def _loglikelihood_tokens(
    self,
    requests: List[Tuple[Tuple[str, str], List[int], List[int]]],
    disable_tqdm: bool = False,
) -> List[Tuple[float, bool]]:
    res = []

    def _collate(x):
        toks = x[1] + x[2]
        return -len(toks), tuple(toks)

    # Reorder requests by length and batch
    re_ord = Collator(requests, sort_fn=_collate)
    chunks = re_ord.get_batched(
        n=int(self.batch_size) if self.batch_size != "auto" else 0, batch_fn=None
    )

    pbar = tqdm(total=len(requests), disable=disable_tqdm)
    for chunk in chunks:
        inputs = []
        ctxlens = []
        for cache_key, context_enc, continuation_enc in chunk:
            inp = (context_enc + continuation_enc)[-(self.max_length) :]
            ctxlen = len(context_enc) - max(
                0, len(context_enc) + len(continuation_enc) - (self.max_length)
            )

            inputs.append(inp)
            ctxlens.append(ctxlen)

        outputs = self._model_generate(requests=inputs, generate=False)

        for output, ctxlen, (cache_key, _, _), inp in zip(
            outputs, ctxlens, chunk, inputs
        ):
            answer = self._parse_logprobs(
                tokens=inp,
                outputs=output,
                ctxlen=ctxlen,
            )

            res.append(answer)

            # partial caching
            if cache_key is not None:
                self.cache_hook.add_partial("loglikelihood", cache_key, answer)
            pbar.update(1)
    pbar.close()
    return re_ord.get_original(res)
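
As far as I can tell, the difference between the two output types boils down to the following rough sketch (this is not the harness's actual code; the stand-in tokenizer and the prefix token id 2 are only for illustration, matching what I observed above):

def tok(s):
    # stand-in tokenizer, for illustration only
    return list(s.encode("utf-8"))

PREFIX_TOKEN = 2   # what shows up above as context_enc == [2]

def rolling_request(doc_to_text, doc_to_target):
    # output_type: loglikelihood_rolling
    # Only the target string is scored; the "context" is a lone prefix token,
    # so the prompt built by doc_to_text is never conditioned on.
    return [PREFIX_TOKEN], tok(doc_to_target)

def conditional_request(doc_to_text, doc_to_target):
    # output_type: loglikelihood
    # The target is scored conditioned on the prompt; only the target tokens
    # contribute to the returned loglikelihood.
    return tok(doc_to_text), tok(doc_to_target)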

However, I want to compute the perplexity the way an LLM's SFT procedure does:

input_ids: [2, 3, 4, ...]                              # prompt (len 512) followed by the answer tokens
labels:    [-100, -100, -100, ..., 345, 456, 567, 789] # same length as input_ids; the first 512 (prompt) positions are -100

We only compute the loss on the label (answer) tokens, just like in this tutorial: https://huggingface.co/docs/transformers/main/en/perplexity
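
For illustration, a minimal sketch of this masked-label computation in the style of the linked tutorial (the model name and the prompt/answer strings are just placeholders):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "mistralai/Mistral-7B-v0.1"  # placeholder model
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto")

prompt = "Problem:\nSolve 2x + 3 = 7.\n\nSolution:\n"   # stands in for doc_to_text
answer = "x = 2."                                        # stands in for doc_to_target

prompt_ids = tokenizer(prompt, return_tensors="pt").input_ids
answer_ids = tokenizer(answer, return_tensors="pt", add_special_tokens=False).input_ids

input_ids = torch.cat([prompt_ids, answer_ids], dim=1)
labels = input_ids.clone()
labels[:, : prompt_ids.shape[1]] = -100   # mask the prompt; loss is computed only on answer tokens

with torch.no_grad():
    loss = model(input_ids=input_ids, labels=labels).loss   # mean NLL over the answer tokens

answer_ppl = torch.exp(loss)   # token-level perplexity of the answer given the prompt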


Luobots commented Jan 30, 2024

Maybe

doc_to_text:  !function utils.doc_to_text # result is {{prompted_input}}
doc_to_target: "{{prompted_input}}\n{{solution}}"

is correct?
But I remember I came across a similar situation.
I will try later.


Luobots commented Jan 30, 2024

> Maybe
>
> doc_to_text:  !function utils.doc_to_text # result is {{prompted_input}}
> doc_to_target: "{{prompted_input}}\n{{solution}}"
>
> is correct? But I remember I came across a similar situation. I will try later.

OK, that is not correct either; the ctxlen is still 1...😭

haileyschoelkopf (Collaborator) commented:

You should be able to do this without a custom process_results(), by following Lambada: https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/tasks/lambada/lambada_openai.yaml

the relevant bits:

output_type: loglikelihood
doc_to_text: "<input string>"
doc_to_target: "<gold answer string>"
metric_list:
  - metric: perplexity
    aggregation: perplexity
    higher_is_better: false

haileyschoelkopf added the "asking questions" label Jan 30, 2024

Luobots commented Jan 30, 2024

> You should be able to do this without a custom process_results(), by following Lambada: main/lm_eval/tasks/lambada/lambada_openai.yaml
>
> the relevant bits:
>
> output_type: loglikelihood
> doc_to_text: "<input string>"
> doc_to_target: "<gold answer string>"
> metric_list:
>   - metric: perplexity
>     aggregation: perplexity
>     higher_is_better: false

Thank you for the reply. I used the config below, following Lambada:

group:
  - math_word_problems_ppl
task: minerva_math_algebra_ppl
dataset_path: EleutherAI/hendrycks_math
process_docs: !function utils.process_docs
dataset_name: algebra
output_type: loglikelihood
training_split: train
test_split: test
doc_to_text:  !function utils.doc_to_text
doc_to_target: "\n{{solution}}"
should_decontaminate: true
doc_to_decontamination_query: "{{problem}}"
metric_list:
  - metric: perplexity
    aggregation: perplexity
metadata:
  version: 1.0

My doc_to_text example is PROMPT + "\n\n" + "Problem:" + "\n" + doc["problem"] + "\n\n" + "Solution:", like the normal minerva_math.
My doc_to_target is "\n" + doc["solution"].
The model is Mistral-7B-v0.1.
When I ran the script, I got this result:

| Tasks | Version | Filter | n-shot | Metric | Value | Stderr |
|---|---|---|---|---|---|---|
| minerva_math_algebra_ppl | 1 | none | 0 | perplexity | 8837890110290426318812505976351990739806158061568 | ± 31262258842398081354462387169045680497025906536087552 |

It almost overflows...

When I changed the aggregation to mean, I got a negative number:

| Tasks | Version | Filter | n-shot | Metric | Value | Stderr |
|---|---|---|---|---|---|---|
| minerva_math_algebra_ppl | 1 | none | 0 | perplexity | -112.7031 | ± 2.8005 |
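
I suspect this is because the aggregation is not normalized by the target length. A rough sketch of my guess, assuming wikitext-style aggregation formulas (the numbers below are made up):

import math

lls = [-120.5, -98.2, -119.4]   # made-up per-document summed loglikelihoods

# "mean" aggregation: just averages the raw summed loglikelihoods -> negative.
print(sum(lls) / len(lls))             # -112.7, the same ballpark as the value above

# "perplexity" aggregation: exp of the negated mean -> explodes for long targets.
print(math.exp(-sum(lls) / len(lls)))  # ~9e48, effectively an overflow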

haileyschoelkopf (Collaborator) commented:

Ah, I think that this is not normalizing perplexity by length, possibly!

I believe replacing the following code snippet may fix this: https://github.com/EleutherAI/lm-evaluation-harness/blob/1554066c9532a7a1cf171b02a5ddaa5bb38f1b78/lm_eval/api/task.py#L1041C1-L1047C14

        if self.OUTPUT_TYPE == "loglikelihood":
            results = results[0]
            ll, is_greedy = results
            return {
                **({"perplexity": ll} if "perplexity" in use_metric else {}),
                **({"acc": int(is_greedy)} if "acc" in use_metric else {}),
            }

with:

        if self.OUTPUT_TYPE == "loglikelihood":
            results = results[0]
            ll, is_greedy = results
            _bytes = self.count_bytes(self.doc_to_target(doc))
            return {
                **({"perplexity": ll} if "perplexity" in use_metric else {}),
                **({"acc": int(is_greedy)} if "acc" in use_metric else {}),
                **(
                    {"byte_perplexity": (ll, _bytes)}
                    if "byte_perplexity" in use_metric
                    else {}
                ),
            }

if you use

metric_list:
  - metric: byte_perplexity
    aggregation: byte_perplexity
    higher_is_better: false

in the config.

If this is indeed the solution, I will push a fix for this, for others who may want to calculate normalized target loglikelihood.
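
For reference, a rough sketch of how the (loglikelihood, byte_count) pairs would then be aggregated, assuming a wikitext-style weighted perplexity (exp of the negated summed loglikelihood over the summed byte count); the numbers below are made up:

import math

def byte_perplexity(items):
    # items: per-document (summed_loglikelihood, byte_count) pairs
    lls, byte_counts = zip(*items)
    return math.exp(-sum(lls) / sum(byte_counts))

items = [(-120.5, 410), (-98.2, 330), (-119.4, 395)]   # made-up examples
print(byte_perplexity(items))   # ~1.35: normalized per byte, so it stays small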


Luobots commented Jan 30, 2024

> Ah, I think that this is not normalizing perplexity by length, possibly!
>
> I believe replacing the following code snippet may fix this: 1554066/lm_eval/api/task.py#L1041C1-L1047C14
>
>         if self.OUTPUT_TYPE == "loglikelihood":
>             results = results[0]
>             ll, is_greedy = results
>             return {
>                 **({"perplexity": ll} if "perplexity" in use_metric else {}),
>                 **({"acc": int(is_greedy)} if "acc" in use_metric else {}),
>             }
>
> with:
>
>         if self.OUTPUT_TYPE == "loglikelihood":
>             results = results[0]
>             ll, is_greedy = results
>             _bytes = self.count_bytes(self.doc_to_target(doc))
>             return {
>                 **({"perplexity": ll} if "perplexity" in use_metric else {}),
>                 **({"acc": int(is_greedy)} if "acc" in use_metric else {}),
>                 **(
>                     {"byte_perplexity": (ll, _bytes)}
>                     if "byte_perplexity" in use_metric
>                     else {}
>                 ),
>             }
>
> if you use
>
> metric_list:
>   - metric: byte_perplexity
>     aggregation: byte_perplexity
>     higher_is_better: false
>
> in the config.
>
> If this is indeed the solution, I will push a fix for this, for others who may want to calculate normalized target loglikelihood.

Thank you!! Yes, I also think it is not normalized, but I am not familiar with the code.
The snippet helped me a lot.
However, I think the aggregation should be weighted_perplexity, or just the default (which is also weighted_perplexity), since there is no aggregation function named byte_perplexity.
I ran the fix on minerva_math_algebra and got 1.3538, which I think is a reasonable result.

| Tasks | Version | Filter | n-shot | Metric | Value | Stderr |
|---|---|---|---|---|---|---|
| minerva_math_algebra_ppl | 1 | none | 0 | byte_perplexity | 1.3538 | ± N/A |

haileyschoelkopf (Collaborator) commented:

Yes, I think this looks quite reasonable! Very similar to our results on Llemma for this eval.
