Skip to content

Commit

Permalink
rewrite template merging, regenerate prompter every run
Browse files Browse the repository at this point in the history
  • Loading branch information
ChristyKoh committed Apr 12, 2023
1 parent 0c2f5c4 commit 066cd44
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 22 deletions.
3 changes: 2 additions & 1 deletion elk/extraction/extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,8 @@ def extract_hiddens(
split_type=split_type,
stream=cfg.prompts.stream,
rank=rank,
world_size=world_size
world_size=world_size,
combined_template_path=cfg.prompts.combined_template_path
) # this dataset is already sharded, but hasn't been truncated to max_examples

model = instantiate_model(
Expand Down
34 changes: 20 additions & 14 deletions elk/extraction/prompt_loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ class PromptConfig(Serializable):
call to __getitem__. Use -1 to apply all available templates. Defaults to -1.
seed: The seed to use for prompt randomization. Defaults to 42.
stream: Whether to stream the dataset from the Internet. Defaults to False.
combined_prompter_path: Path to save a combined template file to, when testing
combined_template_path: Path to save a combined template file to, when testing
prompt invariance across multiple datasets, and will be interpreted as a subpath
of `combined_paths` in the promptsource templates dir. Defaults to empty string.
"""
Expand All @@ -58,7 +58,7 @@ class PromptConfig(Serializable):
num_variants: int = -1
seed: int = 42
stream: bool = False
combined_prompter_path: str = ""
combined_template_path: str = ""

def __post_init__(self):
if len(self.max_examples) > 2:
Expand All @@ -74,21 +74,20 @@ def __post_init__(self):
self.max_examples *= 2

# Combining prompts
if self.combined_prompter_path:
if self.combined_template_path:
print("Copying templates across datasets to combined_templates/ " +
f"{self.combined_prompter_path}/templates.yaml")
combined_prompter = DatasetTemplates("combined_templates", self.combined_prompter_path)
f"{self.combined_template_path}/templates.yaml")
combined_prompter = DatasetTemplates("combined_templates", self.combined_template_path)
combined_prompter.templates = {}
for ds_string in self.datasets:
ds_name, _, config_name = ds_string.partition(" ")
prompter = DatasetTemplates(ds_name, config_name)
combined_prompter.templates.update(prompter.get_templates_with_new_uuids())
print("len of prompter templates is ", len(combined_prompter.templates))
# TODO: Verify that cols are same; if not, warn that templates could not be combined.
combined_prompter.merge_templates_from(prompter)
# combined_prompter.templates.update(prompter.get_templates_with_new_uuids())
print("Total number of templates gathered: ", len(combined_prompter.templates))
combined_prompter.write_to_file()

# Update datasets reference to use combined prompter
self.datasets = [f"combined_templates {self.combined_prompter_path}"] * len(self.datasets)


def load_prompts(
*dataset_strings: str,
num_shots: int = 0,
Expand All @@ -97,7 +96,8 @@ def load_prompts(
split_type: Literal["train", "val"] = "train",
stream: bool = False,
rank: int = 0,
world_size: int = 1
world_size: int = 1,
combined_template_path: str = ""
) -> Iterator[dict]:
"""Load a dataset full of prompts generated from the specified datasets.
Expand All @@ -124,8 +124,10 @@ def load_prompts(
# templates for any dataset in order to make sure we don't run out of prompts.
for ds_string in dataset_strings:
ds_name, _, config_name = ds_string.partition(" ")
prompter = DatasetTemplates(ds_name, config_name)
prompters.append(DatasetTemplates(ds_name, config_name))

if combined_template_path == "":
prompter = DatasetTemplates(ds_name, config_name)
prompters.append(DatasetTemplates(ds_name, config_name))

ds_dict = assert_type(
dict, load_dataset(ds_name, config_name or None, streaming=stream)
Expand Down Expand Up @@ -156,6 +158,10 @@ def load_prompts(
raw_datasets.append(split)
train_datasets.append(train_ds)

if combined_template_path:
combined_prompter = DatasetTemplates("combined_templates", combined_template_path)
prompters = [combined_prompter] * len(dataset_strings)

min_num_templates = min(len(prompter.templates) for prompter in prompters)

num_variants = (
Expand Down
13 changes: 6 additions & 7 deletions elk/promptsource/templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -543,15 +543,14 @@ def delete_folder(self) -> None:
if len(os.listdir(base_folder)) == 0:
rmtree(base_folder)

def get_templates_with_new_uuids(self) -> dict:
def merge_templates_from(self, src: "DatasetTemplates"):
"""
Generate new uuids for templates, used when merging template datasets.
Merge templates from src.
"""
new_templates = {}
for template in self.templates.values():
template.id = str(uuid.uuid4())
new_templates[template.id] = template
return new_templates
for template in src.templates.values():
template_id = str(uuid.uuid4())
self.templates[template_id] = template
self.sync_mapping()

def __getitem__(self, template_key: str) -> "Template":
return self.templates[self.name_to_id_mapping[template_key]]
Expand Down

0 comments on commit 066cd44

Please sign in to comment.