
Commit

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
pre-commit-ci[bot] committed Apr 12, 2023
1 parent 066cd44 commit 846b78c
Showing 2 changed files with 17 additions and 8 deletions.
elk/extraction/extraction.py (2 changes: 1 addition & 1 deletion)

@@ -103,7 +103,7 @@ def extract_hiddens(
         stream=cfg.prompts.stream,
         rank=rank,
         world_size=world_size,
-        combined_template_path=cfg.prompts.combined_template_path
+        combined_template_path=cfg.prompts.combined_template_path,
     )  # this dataset is already sharded, but hasn't been truncated to max_examples

     model = instantiate_model(
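The hook configuration is not part of this diff, but the edits are characteristic of black: calls that are already split across lines gain a magic trailing comma (as in the hunk above), and over-long call sites get broken up (as in prompt_loading.py below). A minimal sketch of reproducing that reformatting locally, assuming black is installed; black.format_str and black.FileMode are real black entry points, while the wrapped source string is only an illustration:

import black

# An over-long assignment inside an indented block, roughly as it looked
# before this commit (undefined names are fine; black only parses the code).
src = (
    "def __post_init__(self):\n"
    "    if self.combined_template_path:\n"
    "        combined_prompter = DatasetTemplates("
    '"combined_templates", self.combined_template_path)\n'
)

# black re-breaks the long line into the three-line call seen in the
# prompt_loading.py hunks below.
print(black.format_str(src, mode=black.FileMode(line_length=88)))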
elk/extraction/prompt_loading.py (23 changes: 16 additions & 7 deletions)

@@ -75,19 +75,26 @@ def __post_init__(self):

         # Combining prompts
         if self.combined_template_path:
-            print("Copying templates across datasets to combined_templates/ " +
-                  f"{self.combined_template_path}/templates.yaml")
-            combined_prompter = DatasetTemplates("combined_templates", self.combined_template_path)
+            print(
+                "Copying templates across datasets to combined_templates/ "
+                + f"{self.combined_template_path}/templates.yaml"
+            )
+            combined_prompter = DatasetTemplates(
+                "combined_templates", self.combined_template_path
+            )
             combined_prompter.templates = {}
             for ds_string in self.datasets:
                 ds_name, _, config_name = ds_string.partition(" ")
                 prompter = DatasetTemplates(ds_name, config_name)
                 # TODO: Verify that cols are same; if not, warn that templates could not be combined.
                 combined_prompter.merge_templates_from(prompter)
                 # combined_prompter.templates.update(prompter.get_templates_with_new_uuids())
-            print("Total number of templates gathered: ", len(combined_prompter.templates))
+            print(
+                "Total number of templates gathered: ", len(combined_prompter.templates)
+            )
             combined_prompter.write_to_file()


 def load_prompts(
     *dataset_strings: str,
     num_shots: int = 0,
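In effect, __post_init__ now copies every listed dataset's promptsource templates into one DatasetTemplates collection stored under combined_templates/<combined_template_path>/templates.yaml. A standalone sketch of that step with hypothetical inputs; the import path is an assumption (elk vendors its own promptsource), and merge_templates_from / write_to_file are taken on faith from the code above:

# Sketch only: the dataset strings and the output name are hypothetical, and
# the import path is assumed rather than confirmed by this diff.
from elk.promptsource import DatasetTemplates

datasets = ["imdb", "amazon_polarity"]      # hypothetical PromptConfig.datasets
combined_template_path = "my_combined_run"  # hypothetical output directory name

# Start a fresh collection under combined_templates/<combined_template_path>/.
combined_prompter = DatasetTemplates("combined_templates", combined_template_path)
combined_prompter.templates = {}

for ds_string in datasets:
    ds_name, _, config_name = ds_string.partition(" ")  # "dataset config" pairs
    prompter = DatasetTemplates(ds_name, config_name)
    # Pull this dataset's templates into the combined collection.
    combined_prompter.merge_templates_from(prompter)

print("Total number of templates gathered:", len(combined_prompter.templates))
combined_prompter.write_to_file()  # persists templates.yaml to disk

The TODO in the hunk still applies: nothing yet verifies that the merged datasets share compatible columns before their templates are combined.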
@@ -97,7 +104,7 @@ def load_prompts(
     stream: bool = False,
     rank: int = 0,
     world_size: int = 1,
-    combined_template_path: str = ""
+    combined_template_path: str = "",
 ) -> Iterator[dict]:
     """Load a dataset full of prompts generated from the specified datasets.
@@ -157,9 +164,11 @@

         raw_datasets.append(split)
         train_datasets.append(train_ds)

     if combined_template_path:
-        combined_prompter = DatasetTemplates("combined_templates", combined_template_path)
+        combined_prompter = DatasetTemplates(
+            "combined_templates", combined_template_path
+        )
         prompters = [combined_prompter] * len(dataset_strings)

     min_num_templates = min(len(prompter.templates) for prompter in prompters)
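On the consumer side, a hedged sketch of how the updated load_prompts signature might be exercised; the dataset strings and template path are made up, and the keyword arguments not shown in this diff are assumed to keep their defaults (extract_hiddens builds the real call from cfg.prompts):

from elk.extraction.prompt_loading import load_prompts

# Hypothetical arguments; combined_template_path="" (the default) keeps the
# original one-prompter-per-dataset behavior.
prompt_stream = load_prompts(
    "imdb",
    "amazon_polarity",
    num_shots=0,
    stream=True,
    rank=0,
    world_size=1,
    combined_template_path="my_combined_run",
)

for example in prompt_stream:
    ...  # each yielded item is a dict describing one prompted example

With combined_template_path set, load_prompts reuses the single merged DatasetTemplates for every dataset string (prompters = [combined_prompter] * len(dataset_strings)), so min_num_templates is computed over the merged template pool.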
