Skip to content

Commit

Permalink
Merge pull request EleutherAI#205 from EleutherAI/revert-184-spar_mt_prompt_invar
Browse files Browse the repository at this point in the history

Revert "spar mt prompt invar"
  • Loading branch information
lauritowal committed Apr 20, 2023
2 parents 89a2122 + 7b3be76 commit 7d7c175
Show file tree
Hide file tree
Showing 6 changed files with 7 additions and 103 deletions.
5 changes: 0 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,6 @@ The following command will evaluate the probe from the run naughty-northcutt on
elk eval naughty-northcutt microsoft/deberta-v2-xxlarge-mnli imdb
```

For prompt invariance across multiple datasets, use the `--combined_template_output_path` command line argument, which will create a new `templates.yaml` file with templates from all the datasets.
```bash
elk elicit bigscience/bloomz-560m christykoh/ag_news_pt ag_news --combined_template_output_path=spar_w/ag_news
```

The following runs `elicit` on the Cartesian product of the listed models and datasets, storing it in a special folder ELK_DIR/sweeps/<memorable_name>. Moreover, `--add_pooled` adds an additional dataset that pools all of the datasets together.

```bash
Expand Down
14 changes: 3 additions & 11 deletions elk/evaluation/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,21 +42,13 @@ class Eval(Serializable):
num_gpus: int = -1
out_dir: Path | None = None
skip_supervised: bool = False
combine_evals: bool = False

def execute(self):
datasets = self.data.prompts.datasets

transfer_dir = elk_reporter_dir() / self.source / "transfer_eval"

if self.combine_evals:
run = Evaluate(cfg=self, out_dir=transfer_dir / ", ".join(datasets))
else:
# eval on each dataset separately
for dataset in datasets:
self.data.prompts.datasets = [dataset]
run = Evaluate(cfg=self, out_dir=transfer_dir / dataset)
run.evaluate()
for dataset in self.data.prompts.datasets:
run = Evaluate(cfg=self, out_dir=transfer_dir / dataset)
run.evaluate()


@dataclass
Expand Down
2 changes: 0 additions & 2 deletions elk/extraction/extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,6 @@ def extract_hiddens(
stream=p_cfg.stream,
rank=rank,
world_size=world_size,
combined_template_output_path=cfg.prompts.combined_template_output_path,
) # this dataset is already sharded, but hasn't been truncated to max_examples

model = instantiate_model(
Expand Down Expand Up @@ -284,7 +283,6 @@ def get_splits() -> SplitDict:

model_cfg = AutoConfig.from_pretrained(cfg.model)

# Retrieve info, used to get splits
ds_name, _, config_name = cfg.prompts.datasets[0].partition(" ")
info = get_dataset_config_info(ds_name, config_name or None)

Expand Down
Empty file.
80 changes: 4 additions & 76 deletions elk/extraction/prompt_loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,13 @@
from copy import deepcopy
from dataclasses import dataclass
from itertools import zip_longest
from os.path import exists
from random import Random
from typing import Any, Iterator, Literal, Optional

from datasets import (
Dataset,
Features,
load_dataset,
load_dataset_builder,
)
from datasets.distributed import split_dataset_by_node
from simple_parsing.helpers import Serializable, field
Expand All @@ -29,9 +27,8 @@
class PromptConfig(Serializable):
"""
Args:
datasets: List of space-delimited names of the HuggingFace datasets to use, e.g.
`["super_glue boolq", "imdb"]`.
balance: Whether to force class balance in the dataset using undersampling.
datasets: List of space-delimited names of the HuggingFace datasets to use, e.g.
`"super_glue boolq"` or `"imdb"`.
data_dir: The directory to use for caching the dataset. Defaults to
`~/.cache/huggingface/datasets`.
label_column: The column containing the labels. By default, we infer this from
Expand All @@ -44,13 +41,9 @@ class PromptConfig(Serializable):
num_shots: The number of examples to use in few-shot prompts. If zero, prompts
are zero-shot. Defaults to 0.
num_variants: The number of prompt templates to apply to each predicate upon
call to __getitem__. Use -1 to apply all available templates. Defaults to
-1.
call to __getitem__. Use -1 to apply all available templates. Defaults to -1.
seed: The seed to use for prompt randomization. Defaults to 42.
stream: Whether to stream the dataset from the Internet. Defaults to False.
combined_template_output_path: Path to save a combined template file to, when
applying prompt invariance across multiple datasets. Interpreted as a
subpath of `combined_paths` in the templates dir. Defaults to empty string.
"""

datasets: list[str] = field(positional=True)
Expand All @@ -62,7 +55,6 @@ class PromptConfig(Serializable):
num_variants: int = -1
seed: int = 42
stream: bool = False
combined_template_output_path: str = ""

def __post_init__(self):
if len(self.max_examples) > 2:
Expand All @@ -73,8 +65,6 @@ def __post_init__(self):
if not self.max_examples:
self.max_examples = [int(1e100)]

self.combine_templates()

# Broadcast the limit to all splits
if len(self.max_examples) == 1:
self.max_examples *= 2
Expand All @@ -96,62 +86,6 @@ def __post_init__(self):
f" but got {len(self.label_columns)}"
)

def combine_templates(self):
if not self.combined_template_output_path:
return

print(
"Copying templates across datasets to combined_templates/ "
+ f"{self.combined_template_output_path}/templates.yaml"
)
combined_prompter = DatasetTemplates(
"combined_templates", self.combined_template_output_path
)
combined_prompter.templates = {}
ref_ds_builder = None
for i, ds_string in enumerate(self.datasets):
ds_name, _, config_name = ds_string.partition(" ")
ds_builder = load_dataset_builder(ds_name, config_name or None)
if i == 0:
# Set first dataset as reference
ref_ds_builder = ds_builder
elif not self.verify_cols(ref_ds_builder, ds_builder, ds_name):
return

# Once verified, merge templates.
prompter = DatasetTemplates(ds_name, config_name)
combined_prompter.merge_templates_from(prompter)
print("Total number of templates: ", len(combined_prompter.templates))
combined_prompter.write_to_file()
print(
"Saved to promptsource/templates/combined_templates/"
+ f"{self.combined_template_output_path}.yaml"
)

def verify_cols(self, ref_ds_builder, ds_builder, ds_name) -> bool:
"""Verify that number of features and number of classes for ClassLabel
match the expected values.
"""
expected_features = len(ref_ds_builder.info.features)
expected_classes = ref_ds_builder.info.features["label"].num_classes
num_features = len(ds_builder.info.features)
num_classes = ds_builder.info.features["label"].num_classes
if expected_features > 0 and num_features != expected_features:
print(
"WARNING: Datasets do not have the same number of features;",
f"{ds_name} has {num_features} features while first dataset has",
f"{expected_features}. Prompting datasets separately.",
)
return False
if expected_classes > 0 and num_classes != expected_classes:
print(
"WARNING: Datasets do not have the same number of ClassLabel classes",
f"{ds_name} has {num_classes} classes while first dataset has",
f"{expected_classes}. Prompting datasets separately.",
)
return False
return True

def explode(self) -> list["PromptConfig"]:
"""Explode the config into a list of configs, one for each dataset."""
copies = []
Expand Down Expand Up @@ -179,7 +113,6 @@ def load_prompts(
stream: bool = False,
rank: int = 0,
world_size: int = 1,
combined_template_output_path: str = "",
) -> Iterator[dict]:
"""Load a dataset full of prompts generated from the specified dataset.
Expand All @@ -198,12 +131,7 @@ def load_prompts(
An iterable of prompt dictionaries.
"""
ds_name, _, config_name = ds_string.partition(" ")

prompter = None
if combined_template_output_path and exists(combined_template_output_path):
prompter = DatasetTemplates("combined_templates", combined_template_output_path)
else:
prompter = DatasetTemplates(ds_name, config_name)
prompter = DatasetTemplates(ds_name, config_name)

ds_dict = assert_type(
dict, load_dataset(ds_name, config_name or None, streaming=stream)
Expand Down
9 changes: 0 additions & 9 deletions elk/promptsource/templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -543,15 +543,6 @@ def delete_folder(self) -> None:
if len(os.listdir(base_folder)) == 0:
rmtree(base_folder)

def merge_templates_from(self, src: "DatasetTemplates"):
"""
Merge templates from src.
"""
for template in src.templates.values():
template_id = str(uuid.uuid4())
self.templates[template_id] = template
self.sync_mapping()

def __getitem__(self, template_key: str) -> "Template":
return self.templates[self.name_to_id_mapping[template_key]]

Expand Down

0 comments on commit 7d7c175

Please sign in to comment.