Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

spar mt prompt invar #184

Merged
merged 55 commits into from
Apr 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
55 commits
Select commit Hold shift + click to select a range
7b68f5f
add boolq_pt template and christykoh as included user
ChristyKoh Mar 9, 2023
edfdd6d
add templates.yaml for boolqpt
ChristyKoh Mar 14, 2023
d285798
pt yaml
reaganjlee Mar 26, 2023
6c91be2
Merge remote-tracking branch 'origin/main' into spar_boolq_pt
reaganjlee Apr 5, 2023
0ff1609
add ag_news template, translated to pt
ChristyKoh Apr 5, 2023
7180a64
add ag_news template, translated to pt
ChristyKoh Apr 5, 2023
b8f5e8b
save eval runs to separate subfolders by target dataset
ChristyKoh Apr 7, 2023
51fed5d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 7, 2023
9cebbdd
Merge branch 'main' into spar_pt
ChristyKoh Apr 7, 2023
6be825a
Merge branch 'spar_boolq_pt' into spar_pt
ChristyKoh Apr 7, 2023
2660844
eval multiple datasets
ChristyKoh Apr 7, 2023
f35a6a8
Merge branch 'eval_dirs' of github.com:EleutherAI/elk into eval_dirs
ChristyKoh Apr 7, 2023
b230209
change prompt answer chouces to portuguese
ChristyKoh Apr 7, 2023
74c9915
add imdb_pt template
ChristyKoh Apr 11, 2023
85fd9e4
implement prompt sharing, generate combined templates.yaml
ChristyKoh Apr 11, 2023
8383f26
fix num templates logic
ChristyKoh Apr 12, 2023
df41ab4
fix pt answer choice
ChristyKoh Apr 12, 2023
c03742d
Merge branch 'spar_pt' of github.com:EleutherAI/elk into spar_mt_prom…
ChristyKoh Apr 12, 2023
e536aa6
Merge branch 'main' of github.com:EleutherAI/elk into main
ChristyKoh Apr 12, 2023
208200d
Merge branch 'main' into spar_pt
ChristyKoh Apr 12, 2023
a132f40
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 12, 2023
9f3f218
Merge branch 'spar_pt' of github.com:EleutherAI/elk into spar_mt_prom…
ChristyKoh Apr 12, 2023
c71bf1c
remove empty prompt_dataset file
ChristyKoh Apr 12, 2023
1ec5787
fix empty prompters bug
ChristyKoh Apr 12, 2023
66c7a6b
fix multiclass label bug
ChristyKoh Apr 12, 2023
d91acac
move prompt combination to PromptConfig post_init logic
ChristyKoh Apr 12, 2023
89b2346
Merge branch 'main' of github.com:EleutherAI/elk into main
ChristyKoh Apr 12, 2023
715bba8
Merge branch 'main' into spar_mt_prompt_invar
ChristyKoh Apr 12, 2023
0c2f5c4
fix refactor bugs, runnable state
ChristyKoh Apr 12, 2023
066cd44
rewrite template merging, regenerate prompter every run
ChristyKoh Apr 12, 2023
846b78c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 12, 2023
1aecd7e
line len fixes
ChristyKoh Apr 12, 2023
b7bbee0
Merge branch 'spar_mt_prompt_invar' of github.com:EleutherAI/elk into…
ChristyKoh Apr 12, 2023
b0c0f63
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 12, 2023
34bc364
Merge remote-tracking branch 'origin/main' into spar_mt_prompt_invar
ChristyKoh Apr 12, 2023
2da069a
update README with prompt invariance argument
ChristyKoh Apr 12, 2023
4c6d344
fix bugs, add dataset col checks
ChristyKoh Apr 12, 2023
53d186b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 12, 2023
b78355f
fix prompter init typing
ChristyKoh Apr 12, 2023
3486ded
Merge branch 'spar_mt_prompt_invar' of github.com:EleutherAI/elk into…
ChristyKoh Apr 12, 2023
1975410
Update README.md
lauritowal Apr 12, 2023
8a5fb0d
try to fix typing again
ChristyKoh Apr 12, 2023
6d46f7e
Merge branch 'spar_mt_prompt_invar' of github.com:EleutherAI/elk into…
ChristyKoh Apr 12, 2023
a7f5a8b
assert datasettemplates type
ChristyKoh Apr 12, 2023
ba090b7
Merge remote-tracking branch 'origin/main' into eval_dirs
ChristyKoh Apr 13, 2023
7af1a1b
bugfix to run eval separately on each dataset
ChristyKoh Apr 13, 2023
74551fa
add combine_evals flag to differentiate a multi dataset eval from a b…
ChristyKoh Apr 13, 2023
0b3d3c9
Merge branch 'main' of github.com:EleutherAI/elk into main
ChristyKoh Apr 19, 2023
d8cee8b
Merge branch 'main' into eval_dirs
ChristyKoh Apr 19, 2023
54aa710
Merge branch 'eval_dirs' into spar_mt_prompt_invar, separate combinin…
ChristyKoh Apr 19, 2023
f7a4713
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 19, 2023
6e2f54c
define ds_name
ChristyKoh Apr 19, 2023
3b0765d
Merge branch 'spar_mt_prompt_invar' of github.com:EleutherAI/elk into…
ChristyKoh Apr 19, 2023
5f0f32a
fix ds_name bug
ChristyKoh Apr 20, 2023
8fa07b4
Update README.md
lauritowal Apr 20, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,11 @@ The following command will evaluate the probe from the run naughty-northcutt on
elk eval naughty-northcutt microsoft/deberta-v2-xxlarge-mnli imdb
```

For prompt invariance across multiple datasets, use the `--combined_template_output_path` command line argument, which will create a new `templates.yaml` file with templates from all the datasets.
```bash
elk elicit bigscience/bloomz-560m christykoh/ag_news_pt ag_news --combined_template_output_path=spar_w/ag_news
```

The following runs `elicit` on the Cartesian product of the listed models and datasets, storing it in a special folder ELK_DIR/sweeps/<memorable_name>. Moreover, `--add_pooled` adds an additional dataset that pools all of the datasets together.

```bash
Expand Down
14 changes: 11 additions & 3 deletions elk/evaluation/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,21 @@ class Eval(Serializable):
num_gpus: int = -1
out_dir: Path | None = None
skip_supervised: bool = False
combine_evals: bool = False

def execute(self):
datasets = self.data.prompts.datasets

transfer_dir = elk_reporter_dir() / self.source / "transfer_eval"

for dataset in self.data.prompts.datasets:
run = Evaluate(cfg=self, out_dir=transfer_dir / dataset)
run.evaluate()
if self.combine_evals:
run = Evaluate(cfg=self, out_dir=transfer_dir / ", ".join(datasets))
else:
# eval on each dataset separately
for dataset in datasets:
self.data.prompts.datasets = [dataset]
run = Evaluate(cfg=self, out_dir=transfer_dir / dataset)
run.evaluate()


@dataclass
Expand Down
2 changes: 2 additions & 0 deletions elk/extraction/extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ def extract_hiddens(
stream=p_cfg.stream,
rank=rank,
world_size=world_size,
combined_template_output_path=cfg.prompts.combined_template_output_path,
) # this dataset is already sharded, but hasn't been truncated to max_examples

model = instantiate_model(
Expand Down Expand Up @@ -283,6 +284,7 @@ def get_splits() -> SplitDict:

model_cfg = AutoConfig.from_pretrained(cfg.model)

# Retrieve info, used to get splits
ds_name, _, config_name = cfg.prompts.datasets[0].partition(" ")
info = get_dataset_config_info(ds_name, config_name or None)

Expand Down
Empty file removed elk/extraction/prompt_dataset.py
Empty file.
80 changes: 76 additions & 4 deletions elk/extraction/prompt_loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@
from copy import deepcopy
from dataclasses import dataclass
from itertools import zip_longest
from os.path import exists
from random import Random
from typing import Any, Iterator, Literal, Optional

from datasets import (
Dataset,
Features,
load_dataset,
load_dataset_builder,
)
from datasets.distributed import split_dataset_by_node
from simple_parsing.helpers import Serializable, field
Expand All @@ -27,8 +29,9 @@
class PromptConfig(Serializable):
"""
Args:
dataset: List of space-delimited names of the HuggingFace dataset to use, e.g.
`"super_glue boolq"` or `"imdb"`.
datasets: List of space-delimited names of the HuggingFace datasets to use, e.g.
[`"super_glue boolq", "imdb"]`.
balance: Whether to force class balance in the dataset using undersampling.
data_dir: The directory to use for caching the dataset. Defaults to
`~/.cache/huggingface/datasets`.
label_column: The column containing the labels. By default, we infer this from
Expand All @@ -41,9 +44,13 @@ class PromptConfig(Serializable):
num_shots: The number of examples to use in few-shot prompts. If zero, prompts
are zero-shot. Defaults to 0.
num_variants: The number of prompt templates to apply to each predicate upon
call to __getitem__. Use -1 to apply all available templates. Defaults to 1.
call to __getitem__. Use -1 to apply all available templates. Defaults to
-1.
seed: The seed to use for prompt randomization. Defaults to 42.
stream: Whether to stream the dataset from the Internet. Defaults to False.
combined_template_output_path: Path to save a combined template file to, when
applying prompt invariance across multiple datasets. Interpreted as a
subpath of `combined_paths` in the templates dir. Defaults to empty string.
"""

datasets: list[str] = field(positional=True)
Expand All @@ -55,6 +62,7 @@ class PromptConfig(Serializable):
num_variants: int = -1
seed: int = 42
stream: bool = False
combined_template_output_path: str = ""

def __post_init__(self):
if len(self.max_examples) > 2:
Expand All @@ -65,6 +73,8 @@ def __post_init__(self):
if not self.max_examples:
self.max_examples = [int(1e100)]

self.combine_templates()

# Broadcast the limit to all splits
if len(self.max_examples) == 1:
self.max_examples *= 2
Expand All @@ -86,6 +96,62 @@ def __post_init__(self):
f" but got {len(self.label_columns)}"
)

def combine_templates(self):
if not self.combined_template_output_path:
return

print(
"Copying templates across datasets to combined_templates/ "
+ f"{self.combined_template_output_path}/templates.yaml"
)
combined_prompter = DatasetTemplates(
"combined_templates", self.combined_template_output_path
)
combined_prompter.templates = {}
ref_ds_builder = None
for i, ds_string in enumerate(self.datasets):
ds_name, _, config_name = ds_string.partition(" ")
ds_builder = load_dataset_builder(ds_name, config_name or None)
if i == 0:
# Set first dataset as reference
ref_ds_builder = ds_builder
elif not self.verify_cols(ref_ds_builder, ds_builder, ds_name):
return

# Once verified, merge templates.
prompter = DatasetTemplates(ds_name, config_name)
combined_prompter.merge_templates_from(prompter)
print("Total number of templates: ", len(combined_prompter.templates))
combined_prompter.write_to_file()
print(
"Saved to promptsource/templates/combined_templates/"
+ f"{self.combined_template_output_path}.yaml"
)

def verify_cols(self, ref_ds_builder, ds_builder, ds_name) -> bool:
"""Verify that number of features and number of classes for ClassLabel
match the expected values.
"""
expected_features = len(ref_ds_builder.info.features)
expected_classes = ref_ds_builder.info.features["label"].num_classes
num_features = len(ds_builder.info.features)
num_classes = ds_builder.info.features["label"].num_classes
if expected_features > 0 and num_features != expected_features:
print(
"WARNING: Datasets do not have the same number of features;",
f"{ds_name} has {num_features} features while first dataset has",
f"{expected_features}. Prompting datasets separately.",
)
return False
if expected_classes > 0 and num_classes != expected_classes:
print(
"WARNING: Datasets do not have the same number of ClassLabel classes",
f"{ds_name} has {num_classes} classes while first dataset has",
f"{expected_classes}. Prompting datasets separately.",
)
return False
return True

def explode(self) -> list["PromptConfig"]:
"""Explode the config into a list of configs, one for each dataset."""
copies = []
Expand Down Expand Up @@ -113,6 +179,7 @@ def load_prompts(
stream: bool = False,
rank: int = 0,
world_size: int = 1,
combined_template_output_path: str = "",
) -> Iterator[dict]:
"""Load a dataset full of prompts generated from the specified dataset.

Expand All @@ -131,7 +198,12 @@ def load_prompts(
An iterable of prompt dictionaries.
"""
ds_name, _, config_name = ds_string.partition(" ")
prompter = DatasetTemplates(ds_name, config_name)

prompter = None
if combined_template_output_path and exists(combined_template_output_path):
prompter = DatasetTemplates("combined_templates", combined_template_output_path)
else:
prompter = DatasetTemplates(ds_name, config_name)

ds_dict = assert_type(
dict, load_dataset(ds_name, config_name or None, streaming=stream)
Expand Down
9 changes: 9 additions & 0 deletions elk/promptsource/templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -543,6 +543,15 @@ def delete_folder(self) -> None:
if len(os.listdir(base_folder)) == 0:
rmtree(base_folder)

def merge_templates_from(self, src: "DatasetTemplates"):
"""
Merge templates from src.
"""
for template in src.templates.values():
template_id = str(uuid.uuid4())
self.templates[template_id] = template
self.sync_mapping()

def __getitem__(self, template_key: str) -> "Template":
return self.templates[self.name_to_id_mapping[template_key]]

Expand Down