Skip to content

Commit

Permalink
Merge branch 'spar_mt_prompt_invar' of github.com:EleutherAI/elk into…
Browse files (browse the repository at this point in the history)
… spar_mt_prompt_invar
  • Loading branch information
ChristyKoh committed Apr 12, 2023
2 parents 1aecd7e + 846b78c commit b7bbee0
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 8 deletions.
2 changes: 1 addition & 1 deletion elk/extraction/extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def extract_hiddens(
stream=cfg.prompts.stream,
rank=rank,
world_size=world_size,
combined_template_path=cfg.prompts.combined_template_path
combined_template_path=cfg.prompts.combined_template_path,
) # this dataset is already sharded, but hasn't been truncated to max_examples

model = instantiate_model(
Expand Down
20 changes: 13 additions & 7 deletions elk/extraction/prompt_loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,10 +76,13 @@ def __post_init__(self):

# Combining prompts
if self.combined_template_path:
print("Copying templates across datasets to combined_templates/ " +
f"{self.combined_template_path}/templates.yaml")
combined_prompter = DatasetTemplates("combined_templates",
self.combined_template_path)
print(
"Copying templates across datasets to combined_templates/ "
+ f"{self.combined_template_path}/templates.yaml"
)
combined_prompter = DatasetTemplates(
"combined_templates", self.combined_template_path
)
combined_prompter.templates = {}
for ds_string in self.datasets:
ds_name, _, config_name = ds_string.partition(" ")
Expand All @@ -90,6 +93,7 @@ def __post_init__(self):
print("Total number of templates: ", len(combined_prompter.templates))
combined_prompter.write_to_file()


def load_prompts(
*dataset_strings: str,
num_shots: int = 0,
Expand All @@ -99,7 +103,7 @@ def load_prompts(
stream: bool = False,
rank: int = 0,
world_size: int = 1,
combined_template_path: str = ""
combined_template_path: str = "",
) -> Iterator[dict]:
"""Load a dataset full of prompts generated from the specified datasets.
Expand Down Expand Up @@ -159,9 +163,11 @@ def load_prompts(

raw_datasets.append(split)
train_datasets.append(train_ds)

if combined_template_path:
combined_prompter = DatasetTemplates("combined_templates", combined_template_path)
combined_prompter = DatasetTemplates(
"combined_templates", combined_template_path
)
prompters = [combined_prompter] * len(dataset_strings)

min_num_templates = min(len(prompter.templates) for prompter in prompters)
Expand Down

0 comments on commit b7bbee0

Please sign in to comment.