From 846b78c0f664a48c556231f41a7717cf399b348b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 12 Apr 2023 12:16:15 +0000 Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- elk/extraction/extraction.py | 2 +- elk/extraction/prompt_loading.py | 23 ++++++++++++++++------- 2 files changed, 17 insertions(+), 8 deletions(-) diff --git a/elk/extraction/extraction.py b/elk/extraction/extraction.py index fe70603a..037aa828 100644 --- a/elk/extraction/extraction.py +++ b/elk/extraction/extraction.py @@ -103,7 +103,7 @@ def extract_hiddens( stream=cfg.prompts.stream, rank=rank, world_size=world_size, - combined_template_path=cfg.prompts.combined_template_path + combined_template_path=cfg.prompts.combined_template_path, ) # this dataset is already sharded, but hasn't been truncated to max_examples model = instantiate_model( diff --git a/elk/extraction/prompt_loading.py b/elk/extraction/prompt_loading.py index c840f212..7f8b5b36 100644 --- a/elk/extraction/prompt_loading.py +++ b/elk/extraction/prompt_loading.py @@ -75,9 +75,13 @@ def __post_init__(self): # Combining prompts if self.combined_template_path: - print("Copying templates across datasets to combined_templates/ " + - f"{self.combined_template_path}/templates.yaml") - combined_prompter = DatasetTemplates("combined_templates", self.combined_template_path) + print( + "Copying templates across datasets to combined_templates/ " + + f"{self.combined_template_path}/templates.yaml" + ) + combined_prompter = DatasetTemplates( + "combined_templates", self.combined_template_path + ) combined_prompter.templates = {} for ds_string in self.datasets: ds_name, _, config_name = ds_string.partition(" ") @@ -85,9 +89,12 @@ def __post_init__(self): # TODO: Verify that cols are same; if not, warn that templates could not be combined. combined_prompter.merge_templates_from(prompter) # combined_prompter.templates.update(prompter.get_templates_with_new_uuids()) - print("Total number of templates gathered: ", len(combined_prompter.templates)) + print( + "Total number of templates gathered: ", len(combined_prompter.templates) + ) combined_prompter.write_to_file() + def load_prompts( *dataset_strings: str, num_shots: int = 0, @@ -97,7 +104,7 @@ def load_prompts( stream: bool = False, rank: int = 0, world_size: int = 1, - combined_template_path: str = "" + combined_template_path: str = "", ) -> Iterator[dict]: """Load a dataset full of prompts generated from the specified datasets. @@ -157,9 +164,11 @@ def load_prompts( raw_datasets.append(split) train_datasets.append(train_ds) - + if combined_template_path: - combined_prompter = DatasetTemplates("combined_templates", combined_template_path) + combined_prompter = DatasetTemplates( + "combined_templates", combined_template_path + ) prompters = [combined_prompter] * len(dataset_strings) min_num_templates = min(len(prompter.templates) for prompter in prompters)