Merge pull request #26 from bigscience-workshop/max-gen-fix
Fix max generation limit
StellaAthena committed Apr 29, 2022
2 parents 22155f7 + eda365f commit ad23a86
Showing 13 changed files with 17 additions and 118 deletions.
87 changes: 8 additions & 79 deletions lm_eval/base.py
@@ -375,11 +375,10 @@ def _collate(x):
).to(self.device)

if max_generation_length is None:
-                max_length = context_enc.shape[1] + self.max_gen_toks
+                max_length = self.max_gen_toks
else:
-                max_length = min(
-                    max_generation_length, context_enc.shape[1] + self.max_gen_toks
-                )
+                max_length = max_generation_length

cont = self._model_generate(
context_enc,
max_length,
@@ -595,78 +594,6 @@ def fewshot_description(self):
)
return ""

-    @utils.positional_deprecated
-    def fewshot_context(
-        self, doc, num_fewshot, provide_description=None, rnd=None, description=None
-    ):
-        """Returns a fewshot context string that is made up of a prepended description
-        (if provided), the `num_fewshot` number of examples, and an appended prompt example.
-        :param doc: str
-            The document as returned from training_docs, validation_docs, or test_docs.
-        :param num_fewshot: int
-            The number of fewshot examples to provide in the returned context string.
-        :param provide_description: bool
-            Not implemented, and this option is deprecated and will be removed in a future version in favor of a different description providing method
-        :param rnd: random.Random
-            The pseudo-random number generator used to randomly sample examples.
-            WARNING: This is currently a required arg although it's optionalized with a default `None`.
-        :param description: str
-            The task's description that will be prepended to the fewshot examples.
-        :returns: str
-            The fewshot context.
-        """
-        assert (
-            rnd is not None
-        ), "A `random.Random` generator argument must be provided to `rnd`"
-        assert not provide_description, (
-            "The `provide_description` arg will be removed in future versions. To prepend "
-            "a custom description to the context, supply the corresponding string via the "
-            "`description` arg."
-        )
-        if provide_description is not None:
-            # nudge people to not specify it at all
-            print(
-                "WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict"
-            )
-
-        description = description + "\n\n" if description else ""
-
-        if num_fewshot == 0:
-            labeled_examples = ""
-        else:
-            # for sets with no training docs, draw from other set *but ensure no overlap with current doc*
-            if self.has_training_docs():
-                fewshotex = self.fewshot_examples(k=num_fewshot, rnd=rnd)
-            else:
-                if self._fewshot_docs is None:
-                    self._fewshot_docs = list(
-                        self.validation_docs()
-                        if self.has_validation_docs()
-                        else self.test_docs()
-                    )
-
-                fewshotex = rnd.sample(self._fewshot_docs, num_fewshot + 1)
-                # get rid of the doc that's the one we're evaluating, if it's in the fewshot
-                fewshotex = [x for x in fewshotex if x != doc][:num_fewshot]
-
-            # See Webson & Pavlick (2022) https://arxiv.org/pdf/2109.01247.pdf
-            # for justification of this separator.
-            example_separator = "\n###\n"
-
-            labeled_examples = (
-                example_separator.join(
-                    [
-                        self.doc_to_text(doc) + self.doc_to_target(doc)
-                        for doc in fewshotex
-                    ]
-                )
-                + example_separator
-            )
-
-        example = self.doc_to_text(doc)
-        return description + labeled_examples + example


class PromptSourceTask(Task):
"""These are the metrics from promptsource that we have
@@ -691,10 +618,12 @@ def __init__(
self.prompt = prompt
self.save_examples = save_examples

-    def stopping_criteria(self) -> Optional[str]:
-        """Denote where the generation should end.
-
-        By default, its "\n###\n".
+    def stopping_criteria(self) -> Optional[str]:
+        """
+        Denote where the generation should end based on the few-shot example
+        separator: "\n###\n".
+        TODO: Handle other separators in the future.
"""
return "\n###\n"

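Taken together, the base.py changes redefine what `max_length` means on its way into `_model_generate`: it now budgets only the tokens to be generated, and each decoder-only model adds the prompt length back itself (see gpt2.py and gptj.py below). Previously, a per-task cap such as 64 was compared via `min(...)` against `context_enc.shape[1] + self.max_gen_toks`, a context-plus-generation total, so a few-shot prompt longer than the cap could leave no room to generate at all; that appears to be the limit bug this PR fixes. A minimal sketch of the new resolution step (the helper name is ours, not the repo's; batching and device handling from `greedy_until` are omitted):

```python
from typing import Optional

def resolve_generation_budget(max_generation_length: Optional[int],
                              max_gen_toks: int) -> int:
    """Return how many *new* tokens a request may generate.

    max_generation_length: optional per-task cap, e.g. 64 for e2e_nlg_cleaned.
    max_gen_toks: the model's default generation budget.
    """
    if max_generation_length is None:
        return max_gen_toks
    return max_generation_length
```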
2 changes: 1 addition & 1 deletion lm_eval/models/gpt2.py
@@ -151,7 +151,7 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwa

def _model_generate(self, context, max_length, stopping_criteria_ids, num_fewshot):
        stopping_criteria = self._get_stopping_criteria(stopping_criteria_ids)
-
+        max_length = max_length + context.size(1)
if num_fewshot == 0:
generations = self.gpt2.generate(
context,
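For decoder-only models, Hugging Face's `generate` counts the prompt toward `max_length`, so the generation budget resolved in base.py has to be offset by `context.size(1)` before the call; gptj.py below gets the identical fix. A runnable sketch of the same arithmetic outside the harness (model choice and prompt are illustrative):

```python
from transformers import GPT2LMHeadModel, GPT2TokenizerFast

tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")

context = tokenizer("Q: What is 2 + 2?\nA:", return_tensors="pt").input_ids
gen_budget = 16  # new tokens allowed, as resolved in base.py

# Offset by the prompt length so the model can actually emit gen_budget tokens.
out = model.generate(
    context,
    max_length=gen_budget + context.size(1),
    pad_token_id=tokenizer.eos_token_id,
)
print(tokenizer.decode(out[0, context.size(1):]))
```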
2 changes: 1 addition & 1 deletion lm_eval/models/gptj.py
@@ -118,7 +118,7 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwa

def _model_generate(self, context, max_length, stopping_criteria_ids, num_fewshot):
        stopping_criteria = self._get_stopping_criteria(stopping_criteria_ids)
-
+        max_length = max_length + context.size(1)
if num_fewshot == 0:
generations = self.gptj.generate(
context,
1 change: 0 additions & 1 deletion lm_eval/models/t5.py
@@ -188,7 +188,6 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwa

def _model_generate(self, context, max_length, stopping_criteria_ids, num_fewshot):
        stopping_criteria = self._get_stopping_criteria(stopping_criteria_ids)
-
if num_fewshot == 0:
generations = self.t5.generate(
context,
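Note the asymmetry with t5.py: no prompt-length offset is added there (this hunk only drops a stray blank line). For encoder-decoder models, Hugging Face's `max_length` counts only decoder-side, i.e. generated, tokens, so the budget resolved in base.py can be passed through unchanged.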
2 changes: 0 additions & 2 deletions lm_eval/tasks/drop.py
@@ -92,8 +92,6 @@ def parse_answer(cls, answer):
# """
# conts = [rf.greedy_until(ctx, ["."])]
# return conts
-    # def stopping_criteria(self):
-    #     return "."

def process_results(self, doc, results):
"""Take a single document and the LM results and evaluates, returning a
6 changes: 2 additions & 4 deletions lm_eval/tasks/e2e_nlg_cleaned.py
@@ -61,9 +61,6 @@ def test_docs(self):
def max_generation_length(self):
return 64

-    # def stopping_criteria(self):
-    #     return '\n\n'
-
def invalid_doc_for_prompt(self, doc) -> bool:
"""The QA prompts are not applicable to all the examples, we want to filter these out."""
return self.prompt.name.endswith("_qa") or self.prompt.name == "family_friendly_yes_no"
@@ -73,7 +70,7 @@ def doc_to_text(self, doc) -> str:
text = self.prompt.apply(doc)[0]
return text

-    def construct_requests(self, doc, ctx):
+    def construct_requests(self, doc, ctx, args):
"""Uses RequestFactory to construct Requests and returns an iterable of
Requests which will be sent to the LM.
@@ -90,6 +87,7 @@ def construct_requests(self, doc, ctx):
request_args = {
"stopping_criteria": self.stopping_criteria(),
"max_generation_length": self.max_generation_length(),
"num_fewshot": args["num_fewshot"],
}

# Skip examples for which the templates are not applicable
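The new `args` parameter threads `num_fewshot` from the evaluator through the request and into `_model_generate`, where the models above branch on `num_fewshot == 0`. A trimmed sketch of the updated method; the dict shape follows this diff, while the `rf.greedy_until` call is assumed from the pattern commented out in drop.py above:

```python
def construct_requests(self, doc, ctx, args):
    # Bundle per-request generation settings for the model side.
    request_args = {
        "stopping_criteria": self.stopping_criteria(),          # e.g. "\n###\n"
        "max_generation_length": self.max_generation_length(),  # e.g. 64
        "num_fewshot": args["num_fewshot"],                     # lets models special-case 0-shot
    }
    return rf.greedy_until(ctx, request_args)
```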
6 changes: 0 additions & 6 deletions lm_eval/tasks/gem_asset_turk.py
@@ -78,15 +78,9 @@ def validation_docs(self):
def test_docs(self):
return self.dataset[str(self.SPLIT)]

-    # def stopping_criteria(self):
-    #     return None
-
def max_generation_length(self):
return 200

-    # def higher_is_better(self):
-    #     return {"bleu": True, "rouge": True}
-

class AssetTest(AssetTurk):
SPLIT = "test_asset"
6 changes: 0 additions & 6 deletions lm_eval/tasks/gem_mlsum.py
@@ -50,9 +50,6 @@ def test_docs(self):
if self.has_test_docs():
return self.dataset["test"]

-    def stopping_criteria(self):
-        return "."
-
class GEMMLSUMEs(GEMMLSUMEsBase):
'''this is for train/validation/test'''
SPLIT = ''
@@ -98,9 +95,6 @@ def test_docs(self):
if self.has_test_docs():
return self.dataset["test"]

-    def stopping_criteria(self):
-        return "."
-
class GEMMLSUMDe(GEMMLSUMDeBase):
'''this is for train/validation/test'''
SPLIT = ''
6 changes: 0 additions & 6 deletions lm_eval/tasks/gem_webnlg.py
@@ -70,15 +70,9 @@ def test_docs(self):
else:
return self.dataset["test"]

-    # def stopping_criteria(self):
-    #     return None
-
def max_generation_length(self):
return 250

-    # def higher_is_better(self):
-    #     return {"bleu": True, "rouge": True}
-

class WebNLGRu(WebNLG):
DATASET_NAME = "ru"
2 changes: 0 additions & 2 deletions lm_eval/tasks/gem_xsum.py
@@ -42,8 +42,6 @@ def has_validation_docs(self):
def has_test_docs(self):
return True

-    def stopping_criteria(self):
-        return '.'
def training_docs(self):
if self.has_training_docs():
# We cache training documents in `self._training_docs` for faster
3 changes: 0 additions & 3 deletions lm_eval/tasks/glue.py
@@ -236,9 +236,6 @@ def has_validation_docs(self):
def has_test_docs(self):
return False

-    # def stopping_criteria(self):
-    #     return "\n###\n"
-
def training_docs(self):
if self._training_docs is None:
self._training_docs = list(self.dataset["train"])
3 changes: 0 additions & 3 deletions lm_eval/tasks/wino_bias.py
@@ -54,9 +54,6 @@ def validation_docs(self):
def test_docs(self):
return self.dataset["test"]

-    # def stopping_criteria(self):
-    #     return "\n"
-
def process_results(self, doc, results):
"""Take a single document and the LM results and evaluates, returning a
dict where keys are the names of submetrics and values are the values of
9 changes: 5 additions & 4 deletions templates/new_task.py
@@ -72,11 +72,12 @@ def test_docs(self):
# named differently than the default `"test"`.
return self.dataset["test"]

-    def stopping_criteria(self):
-        # Only define this method when you want to control few-shot generations on specific tokens.
-        # The default is set to '\n###\n'.
+    def max_generation_length(self):
+        # Define this method when you want to control the length of few-shot
+        # generations. The default is `None`, which gets mapped to a model's
+        # default max generation token length, e.g. see `lm_eval/models/gpt2.py:max_gen_toks()`.
         # NOTE: You may delete this function if the task does not require generation.
-        return "\n###\n"
+        return None

def construct_requests(self, doc, ctx):
"""Uses RequestFactory to construct Requests and returns an iterable of
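For task authors, the template now steers overrides toward `max_generation_length` (a token count, `None` by default) rather than `stopping_criteria` (which keeps its "\n###\n" default from `PromptSourceTask`). A minimal hypothetical override, mirroring e2e_nlg_cleaned.py above (the class name and cap are invented for illustration):

```python
class MyGenerationTask(PromptSourceTask):
    def max_generation_length(self):
        # Allow at most 64 new tokens per generation; return None to fall
        # back to the model default (see lm_eval/models/gpt2.py:max_gen_toks()).
        return 64
```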
