Skip to content

Commit

Permalink
fix bugs, add dataset col checks
Browse files Browse the repository at this point in the history
  • Loading branch information
ChristyKoh committed Apr 12, 2023
1 parent 2da069a commit 4c6d344
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 5 deletions.
3 changes: 1 addition & 2 deletions elk/extraction/extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,10 +266,9 @@ def get_splits() -> SplitDict:
model_cfg = AutoConfig.from_pretrained(cfg.model)
num_variants = cfg.prompts.num_variants

# extraneous, remove ?
# Retrieve info, used to get splits
ds_name, _, config_name = cfg.prompts.datasets[0].partition(" ")
info = get_dataset_config_info(ds_name, config_name or None)
# ? end

layer_cols = {
f"hidden_{layer}": Array3D(
Expand Down
33 changes: 30 additions & 3 deletions elk/extraction/prompt_loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
Dataset,
Features,
load_dataset,
load_dataset_builder,
)
from datasets.distributed import split_dataset_by_node
from simple_parsing.helpers import Serializable, field
Expand Down Expand Up @@ -84,14 +85,41 @@ def __post_init__(self):
"combined_templates", self.combined_template_path
)
combined_prompter.templates = {}
prev_num_features = 0
prev_num_label_classes = 0
for ds_string in self.datasets:
ds_name, _, config_name = ds_string.partition(" ")
prompter = DatasetTemplates(ds_name, config_name)
# TODO: Verify that cols are same; if not, warn that templates
# could not be combined.

# Verify that number of features and number of classes for ClassLabel
# are the same across datasets.
ds_builder = load_dataset_builder(ds_name, config_name or None)
num_features = len(ds_builder.info.features)
if prev_num_features > 0 and num_features != prev_num_features:
print("WARNING: Datasets do not have the same number of features;",
f"{ds_name} has {num_features} features while prev has",
f"{prev_num_features}. Prompting datasets separately.")
self.combined_template_path = ""
break
prev_num_features = num_features
num_classes = ds_builder.info.features['label'].num_classes
if prev_num_label_classes > 0 and num_classes != prev_num_label_classes:
print("WARNING: Datasets do not have the same number of ClassLabel",
f"classes; {ds_name} has {num_classes} classes while prev has",
f"{prev_num_label_classes}. Prompting datasets separately.")
self.combined_template_path = ""
break
prev_num_label_classes = num_classes

# Once verified, merge templates.
combined_prompter.merge_templates_from(prompter)

# Write to file if successfully merged all prompts.
if self.combined_template_path:
print("Total number of templates: ", len(combined_prompter.templates))
combined_prompter.write_to_file()
print("Saved to promptsource/templates/combined_templates/" +
f"{self.combined_template_path}.yaml")


def load_prompts(
Expand Down Expand Up @@ -177,7 +205,6 @@ def load_prompts(
if num_variants == -1
else min(num_variants, min_num_templates)
)
print()
assert num_variants > 0
if rank == 0:
print(f"Using {num_variants} variants of each prompt")
Expand Down

0 comments on commit 4c6d344

Please sign in to comment.