Skip to content

Commit

Permalink
line len fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
ChristyKoh committed Apr 12, 2023
1 parent 066cd44 commit 1aecd7e
Showing 1 changed file with 9 additions and 7 deletions.
16 changes: 9 additions & 7 deletions elk/extraction/prompt_loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,13 @@ class PromptConfig(Serializable):
num_shots: The number of examples to use in few-shot prompts. If zero, prompts
are zero-shot. Defaults to 0.
num_variants: The number of prompt templates to apply to each predicate upon
call to __getitem__. Use -1 to apply all available templates. Defaults to -1.
call to __getitem__. Use -1 to apply all available templates. Defaults to
-1.
seed: The seed to use for prompt randomization. Defaults to 42.
stream: Whether to stream the dataset from the Internet. Defaults to False.
combined_template_path: Path to save a combined template file to, when testing
prompt invariance across multiple datasets, and will be interpreted as a subpath
of `combined_paths` in the promptsource templates dir. Defaults to empty string.
prompt invariance across multiple datasets, and will be interpreted as a
subpath of `combined_paths` in the templates dir. Defaults to empty string.
"""

datasets: list[str] = field(positional=True)
Expand Down Expand Up @@ -77,15 +78,16 @@ def __post_init__(self):
if self.combined_template_path:
print("Copying templates across datasets to combined_templates/ " +
f"{self.combined_template_path}/templates.yaml")
combined_prompter = DatasetTemplates("combined_templates", self.combined_template_path)
combined_prompter = DatasetTemplates("combined_templates",
self.combined_template_path)
combined_prompter.templates = {}
for ds_string in self.datasets:
ds_name, _, config_name = ds_string.partition(" ")
prompter = DatasetTemplates(ds_name, config_name)
# TODO: Verify that cols are same; if not, warn that templates could not be combined.
# TODO: Verify that cols are same; if not, warn that templates
# could not be combined.
combined_prompter.merge_templates_from(prompter)
# combined_prompter.templates.update(prompter.get_templates_with_new_uuids())
print("Total number of templates gathered: ", len(combined_prompter.templates))
print("Total number of templates: ", len(combined_prompter.templates))
combined_prompter.write_to_file()

def load_prompts(
Expand Down

0 comments on commit 1aecd7e

Please sign in to comment.