From 7b3be761928931f0f0516631f6903eef328648fa Mon Sep 17 00:00:00 2001
From: Nora Belrose <39116809+norabelrose@users.noreply.github.com>
Date: Thu, 20 Apr 2023 13:07:34 -0700
Subject: [PATCH] Revert "spar mt prompt invar"

---
 README.md                        |  5 --
 elk/evaluation/evaluate.py       | 14 ++----
 elk/extraction/extraction.py     |  2 -
 elk/extraction/prompt_dataset.py |  0
 elk/extraction/prompt_loading.py | 80 ++------------------------------
 elk/promptsource/templates.py    |  9 ----
 6 files changed, 7 insertions(+), 103 deletions(-)
 create mode 100644 elk/extraction/prompt_dataset.py

diff --git a/README.md b/README.md
index d3f55e5b..f7165573 100644
--- a/README.md
+++ b/README.md
@@ -32,11 +32,6 @@ The following command will evaluate the probe from the run naughty-northcutt on
 elk eval naughty-northcutt microsoft/deberta-v2-xxlarge-mnli imdb
 ```
 
-For prompt invariance across multiple datasets, use the `--combined_template_output_path` command line argument, which will create a new `templates.yaml` file with templates from all the datasets.
-```bash
-elk elicit bigscience/bloomz-560m christykoh/ag_news_pt ag_news --combined_template_output_path=spar_w/ag_news
-```
-
 The following runs `elicit` on the Cartesian product of the listed models and datasets, storing it in a special folder ELK_DIR/sweeps/. Moreover, `--add_pooled` adds an additional dataset that pools all of the datasets together.
 
 ```bash
diff --git a/elk/evaluation/evaluate.py b/elk/evaluation/evaluate.py
index 793f6b3e..fa350b18 100644
--- a/elk/evaluation/evaluate.py
+++ b/elk/evaluation/evaluate.py
@@ -42,21 +42,13 @@ class Eval(Serializable):
     num_gpus: int = -1
     out_dir: Path | None = None
     skip_supervised: bool = False
-    combine_evals: bool = False
 
     def execute(self):
-        datasets = self.data.prompts.datasets
-
         transfer_dir = elk_reporter_dir() / self.source / "transfer_eval"
 
-        if self.combine_evals:
-            run = Evaluate(cfg=self, out_dir=transfer_dir / ", ".join(datasets))
-        else:
-            # eval on each dataset separately
-            for dataset in datasets:
-                self.data.prompts.datasets = [dataset]
-                run = Evaluate(cfg=self, out_dir=transfer_dir / dataset)
-                run.evaluate()
+        for dataset in self.data.prompts.datasets:
+            run = Evaluate(cfg=self, out_dir=transfer_dir / dataset)
+            run.evaluate()
 
 
 @dataclass
diff --git a/elk/extraction/extraction.py b/elk/extraction/extraction.py
index 2d7e64e6..7d22f9ab 100644
--- a/elk/extraction/extraction.py
+++ b/elk/extraction/extraction.py
@@ -113,7 +113,6 @@ def extract_hiddens(
         stream=p_cfg.stream,
         rank=rank,
         world_size=world_size,
-        combined_template_output_path=cfg.prompts.combined_template_output_path,
     )
     # this dataset is already sharded, but hasn't been truncated to max_examples
     model = instantiate_model(
@@ -284,7 +283,6 @@ def get_splits() -> SplitDict:
 
     model_cfg = AutoConfig.from_pretrained(cfg.model)
 
-    # Retrieve info, used to get splits
    ds_name, _, config_name = cfg.prompts.datasets[0].partition(" ")
    info = get_dataset_config_info(ds_name, config_name or None)
 
diff --git a/elk/extraction/prompt_dataset.py b/elk/extraction/prompt_dataset.py
new file mode 100644
index 00000000..e69de29b
diff --git a/elk/extraction/prompt_loading.py b/elk/extraction/prompt_loading.py
index cca0c042..c56eeffd 100644
--- a/elk/extraction/prompt_loading.py
+++ b/elk/extraction/prompt_loading.py
@@ -2,7 +2,6 @@
 from copy import deepcopy
 from dataclasses import dataclass
 from itertools import zip_longest
-from os.path import exists
 from random import Random
 from typing import Any, Iterator, Literal, Optional
 
@@ -10,7 +9,6 @@
     Dataset,
     Features,
     load_dataset,
-    load_dataset_builder,
 )
 from datasets.distributed import split_dataset_by_node
 from simple_parsing.helpers import Serializable, field
@@ -29,9 +27,8 @@ class PromptConfig(Serializable):
     """
     Args:
-        datasets: List of space-delimited names of the HuggingFace datasets to use, e.g.
-            [`"super_glue boolq", "imdb"]`.
-        balance: Whether to force class balance in the dataset using undersampling.
+        dataset: List of space-delimited names of the HuggingFace dataset to use, e.g.
+            `"super_glue boolq"` or `"imdb"`.
         data_dir: The directory to use for caching the dataset. Defaults to
             `~/.cache/huggingface/datasets`.
         label_column: The column containing the labels. By default, we infer this from
@@ -44,13 +41,9 @@ class PromptConfig(Serializable):
         num_shots: The number of examples to use in few-shot prompts. If zero,
             prompts are zero-shot. Defaults to 0.
         num_variants: The number of prompt templates to apply to each predicate upon
-            call to __getitem__. Use -1 to apply all available templates. Defaults to
-            -1.
+            call to __getitem__. Use -1 to apply all available templates. Defaults to 1.
         seed: The seed to use for prompt randomization. Defaults to 42.
         stream: Whether to stream the dataset from the Internet. Defaults to False.
-        combined_template_output_path: Path to save a combined template file to, when
-            applying prompt invariance across multiple datasets. Interpreted as a
-            subpath of `combined_paths` in the templates dir. Defaults to empty string.
     """
 
     datasets: list[str] = field(positional=True)
@@ -62,7 +55,6 @@
     num_variants: int = -1
     seed: int = 42
     stream: bool = False
-    combined_template_output_path: str = ""
 
     def __post_init__(self):
         if len(self.max_examples) > 2:
@@ -73,8 +65,6 @@ def __post_init__(self):
         if not self.max_examples:
             self.max_examples = [int(1e100)]
 
-        self.combine_templates()
-
         # Broadcast the limit to all splits
         if len(self.max_examples) == 1:
             self.max_examples *= 2
@@ -96,62 +86,6 @@ def __post_init__(self):
                 f" but got {len(self.label_columns)}"
             )
 
-    def combine_templates(self):
-        if not self.combined_template_output_path:
-            return
-
-        print(
-            "Copying templates across datasets to combined_templates/ "
-            + f"{self.combined_template_output_path}/templates.yaml"
-        )
-        combined_prompter = DatasetTemplates(
-            "combined_templates", self.combined_template_output_path
-        )
-        combined_prompter.templates = {}
-        ref_ds_builder = None
-        for i, ds_string in enumerate(self.datasets):
-            ds_name, _, config_name = ds_string.partition(" ")
-            ds_builder = load_dataset_builder(ds_name, config_name or None)
-            if i == 0:
-                # Set first dataset as reference
-                ref_ds_builder = ds_builder
-            elif not self.verify_cols(ref_ds_builder, ds_builder, ds_name):
-                return
-
-            # Once verified, merge templates.
-            prompter = DatasetTemplates(ds_name, config_name)
-            combined_prompter.merge_templates_from(prompter)
-        print("Total number of templates: ", len(combined_prompter.templates))
-        combined_prompter.write_to_file()
-        print(
-            "Saved to promptsource/templates/combined_templates/"
-            + f"{self.combined_template_output_path}.yaml"
-        )
-
-    def verify_cols(self, ref_ds_builder, ds_builder, ds_name) -> bool:
-        """Verify that number of features and number of classes for ClassLabel
-        match the expected values.
-        """
-        expected_features = len(ref_ds_builder.info.features)
-        expected_classes = ref_ds_builder.info.features["label"].num_classes
-        num_features = len(ds_builder.info.features)
-        num_classes = ds_builder.info.features["label"].num_classes
-        if expected_features > 0 and num_features != expected_features:
-            print(
-                "WARNING: Datasets do not have the same number of features;",
-                f"{ds_name} has {num_features} features while first dataset has",
-                f"{expected_features}. Prompting datasets separately.",
-            )
-            return False
-        if expected_classes > 0 and num_classes != expected_classes:
-            print(
-                "WARNING: Datasets do not have the same number of ClassLabel classes",
-                f"{ds_name} has {num_classes} classes while first dataset has",
-                f"{expected_classes}. Prompting datasets separately.",
-            )
-            return False
-        return True
-
     def explode(self) -> list["PromptConfig"]:
         """Explode the config into a list of configs, one for each dataset."""
         copies = []
@@ -179,7 +113,6 @@ def load_prompts(
     stream: bool = False,
     rank: int = 0,
     world_size: int = 1,
-    combined_template_output_path: str = "",
 ) -> Iterator[dict]:
     """Load a dataset full of prompts generated from the specified dataset.
 
@@ -198,12 +131,7 @@ def load_prompts(
         An iterable of prompt dictionaries.
     """
     ds_name, _, config_name = ds_string.partition(" ")
-
-    prompter = None
-    if combined_template_output_path and exists(combined_template_output_path):
-        prompter = DatasetTemplates("combined_templates", combined_template_output_path)
-    else:
-        prompter = DatasetTemplates(ds_name, config_name)
+    prompter = DatasetTemplates(ds_name, config_name)
 
     ds_dict = assert_type(
         dict, load_dataset(ds_name, config_name or None, streaming=stream)
diff --git a/elk/promptsource/templates.py b/elk/promptsource/templates.py
index 38d3f87f..ea4e9196 100644
--- a/elk/promptsource/templates.py
+++ b/elk/promptsource/templates.py
@@ -543,15 +543,6 @@ def delete_folder(self) -> None:
         if len(os.listdir(base_folder)) == 0:
             rmtree(base_folder)
 
-    def merge_templates_from(self, src: "DatasetTemplates"):
-        """
-        Merge templates from src.
-        """
-        for template in src.templates.values():
-            template_id = str(uuid.uuid4())
-            self.templates[template_id] = template
-        self.sync_mapping()
-
     def __getitem__(self, template_key: str) -> "Template":
         return self.templates[self.name_to_id_mapping[template_key]]
 