Skip to content

Commit

Permalink
Merge branch 'spar_mt_prompt_invar' of github.com:EleutherAI/elk into…
Browse files Browse the repository at this point in the history
… spar_mt_prompt_invar
  • Loading branch information
ChristyKoh committed Apr 12, 2023
2 parents b78355f + 53d186b commit 3486ded
Showing 1 changed file with 16 additions and 10 deletions.
26 changes: 16 additions & 10 deletions elk/extraction/prompt_loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,30 +97,36 @@ def __post_init__(self):
ds_builder = load_dataset_builder(ds_name, config_name or None)
num_features = len(ds_builder.info.features)
if prev_num_features > 0 and num_features != prev_num_features:
print("WARNING: Datasets do not have the same number of features;",
f"{ds_name} has {num_features} features while prev has",
f"{prev_num_features}. Prompting datasets separately.")
print(
"WARNING: Datasets do not have the same number of features;",
f"{ds_name} has {num_features} features while prev has",
f"{prev_num_features}. Prompting datasets separately.",
)
self.combined_template_path = ""
break
prev_num_features = num_features
num_classes = ds_builder.info.features['label'].num_classes
num_classes = ds_builder.info.features["label"].num_classes
if prev_num_label_classes > 0 and num_classes != prev_num_label_classes:
print("WARNING: Datasets do not have the same number of ClassLabel",
f"classes; {ds_name} has {num_classes} classes while prev has",
f"{prev_num_label_classes}. Prompting datasets separately.")
print(
"WARNING: Datasets do not have the same number of ClassLabel",
f"classes; {ds_name} has {num_classes} classes while prev has",
f"{prev_num_label_classes}. Prompting datasets separately.",
)
self.combined_template_path = ""
break
prev_num_label_classes = num_classes

# Once verified, merge templates.
combined_prompter.merge_templates_from(prompter)

# Write to file if successfully merged all prompts.
if self.combined_template_path:
print("Total number of templates: ", len(combined_prompter.templates))
combined_prompter.write_to_file()
print("Saved to promptsource/templates/combined_templates/" +
f"{self.combined_template_path}.yaml")
print(
"Saved to promptsource/templates/combined_templates/"
+ f"{self.combined_template_path}.yaml"
)


def load_prompts(
Expand Down

0 comments on commit 3486ded

Please sign in to comment.