Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

spar mt prompt invar #184

Merged
merged 55 commits into from
Apr 20, 2023
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
Show all changes
55 commits
Select commit Hold shift + click to select a range
7b68f5f
add boolq_pt template and christykoh as included user
ChristyKoh Mar 9, 2023
edfdd6d
add templates.yaml for boolqpt
ChristyKoh Mar 14, 2023
d285798
pt yaml
reaganjlee Mar 26, 2023
6c91be2
Merge remote-tracking branch 'origin/main' into spar_boolq_pt
reaganjlee Apr 5, 2023
0ff1609
add ag_news template, translated to pt
ChristyKoh Apr 5, 2023
7180a64
add ag_news template, translated to pt
ChristyKoh Apr 5, 2023
b8f5e8b
save eval runs to separate subfolders by target dataset
ChristyKoh Apr 7, 2023
51fed5d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 7, 2023
9cebbdd
Merge branch 'main' into spar_pt
ChristyKoh Apr 7, 2023
6be825a
Merge branch 'spar_boolq_pt' into spar_pt
ChristyKoh Apr 7, 2023
2660844
eval multiple datasets
ChristyKoh Apr 7, 2023
f35a6a8
Merge branch 'eval_dirs' of github.com:EleutherAI/elk into eval_dirs
ChristyKoh Apr 7, 2023
b230209
change prompt answer chouces to portuguese
ChristyKoh Apr 7, 2023
74c9915
add imdb_pt template
ChristyKoh Apr 11, 2023
85fd9e4
implement prompt sharing, generate combined templates.yaml
ChristyKoh Apr 11, 2023
8383f26
fix num templates logic
ChristyKoh Apr 12, 2023
df41ab4
fix pt answer choice
ChristyKoh Apr 12, 2023
c03742d
Merge branch 'spar_pt' of github.com:EleutherAI/elk into spar_mt_prom…
ChristyKoh Apr 12, 2023
e536aa6
Merge branch 'main' of github.com:EleutherAI/elk into main
ChristyKoh Apr 12, 2023
208200d
Merge branch 'main' into spar_pt
ChristyKoh Apr 12, 2023
a132f40
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 12, 2023
9f3f218
Merge branch 'spar_pt' of github.com:EleutherAI/elk into spar_mt_prom…
ChristyKoh Apr 12, 2023
c71bf1c
remove empty prompt_dataset file
ChristyKoh Apr 12, 2023
1ec5787
fix empty prompters bug
ChristyKoh Apr 12, 2023
66c7a6b
fix multiclass label bug
ChristyKoh Apr 12, 2023
d91acac
move prompt combination to PromptConfig post_init logic
ChristyKoh Apr 12, 2023
89b2346
Merge branch 'main' of github.com:EleutherAI/elk into main
ChristyKoh Apr 12, 2023
715bba8
Merge branch 'main' into spar_mt_prompt_invar
ChristyKoh Apr 12, 2023
0c2f5c4
fix refactor bugs, runnable state
ChristyKoh Apr 12, 2023
066cd44
rewrite template merging, regenerate prompter every run
ChristyKoh Apr 12, 2023
846b78c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 12, 2023
1aecd7e
line len fixes
ChristyKoh Apr 12, 2023
b7bbee0
Merge branch 'spar_mt_prompt_invar' of github.com:EleutherAI/elk into…
ChristyKoh Apr 12, 2023
b0c0f63
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 12, 2023
34bc364
Merge remote-tracking branch 'origin/main' into spar_mt_prompt_invar
ChristyKoh Apr 12, 2023
2da069a
update README with prompt invariance argument
ChristyKoh Apr 12, 2023
4c6d344
fix bugs, add dataset col checks
ChristyKoh Apr 12, 2023
53d186b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 12, 2023
b78355f
fix prompter init typing
ChristyKoh Apr 12, 2023
3486ded
Merge branch 'spar_mt_prompt_invar' of github.com:EleutherAI/elk into…
ChristyKoh Apr 12, 2023
1975410
Update README.md
lauritowal Apr 12, 2023
8a5fb0d
try to fix typing again
ChristyKoh Apr 12, 2023
6d46f7e
Merge branch 'spar_mt_prompt_invar' of github.com:EleutherAI/elk into…
ChristyKoh Apr 12, 2023
a7f5a8b
assert datasettemplates type
ChristyKoh Apr 12, 2023
ba090b7
Merge remote-tracking branch 'origin/main' into eval_dirs
ChristyKoh Apr 13, 2023
7af1a1b
bugfix to run eval separately on each dataset
ChristyKoh Apr 13, 2023
74551fa
add combine_evals flag to differentiate a multi dataset eval from a b…
ChristyKoh Apr 13, 2023
0b3d3c9
Merge branch 'main' of github.com:EleutherAI/elk into main
ChristyKoh Apr 19, 2023
d8cee8b
Merge branch 'main' into eval_dirs
ChristyKoh Apr 19, 2023
54aa710
Merge branch 'eval_dirs' into spar_mt_prompt_invar, separate combinin…
ChristyKoh Apr 19, 2023
f7a4713
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 19, 2023
6e2f54c
define ds_name
ChristyKoh Apr 19, 2023
3b0765d
Merge branch 'spar_mt_prompt_invar' of github.com:EleutherAI/elk into…
ChristyKoh Apr 19, 2023
5f0f32a
fix ds_name bug
ChristyKoh Apr 20, 2023
8fa07b4
Update README.md
lauritowal Apr 20, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,12 @@ For prompt invariance across multiple datasets, use the `--combined_template_pat
elk elicit bigscience/bloomz-560m christykoh/ag_news_pt ag_news --combined_template_path=spar_w/ag_news
```

The following runs `elicit` on the Cartesian product of the listed models and datasets, storing it in a special folder ELK_DIR/sweeps/<memorable_name>. Moreover, `--add_pooled` adds an additional dataset that pools all of the datasets together.

```bash
elk sweep --models gpt2-{medium,large,xl} --datasets imdb amazon_polarity --add_pooled
```

## Caching

The hidden states resulting from `elk elicit` are cached as a HuggingFace dataset to avoid having to recompute them every time we want to train a probe. The cache is stored in the same place as all other HuggingFace datasets, which is usually `~/.cache/huggingface/datasets`.
Expand Down
14 changes: 11 additions & 3 deletions elk/evaluation/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,21 @@ class Eval(Serializable):
num_gpus: int = -1
out_dir: Path | None = None
skip_supervised: bool = False
combine_evals: bool = False

def execute(self):
datasets = self.data.prompts.datasets

transfer_dir = elk_reporter_dir() / self.source / "transfer_eval"

for dataset in self.data.prompts.datasets:
run = Evaluate(cfg=self, out_dir=transfer_dir / dataset)
run.evaluate()
if self.combine_evals:
run = Evaluate(cfg=self, out_dir=transfer_dir / ", ".join(datasets))
else:
# eval on each dataset separately
for dataset in datasets:
self.data.prompts.datasets = [dataset]
run = Evaluate(cfg=self, out_dir=transfer_dir / dataset)
run.evaluate()


@dataclass
Expand Down
2 changes: 1 addition & 1 deletion elk/extraction/extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def extract_hiddens(
stream=p_cfg.stream,
rank=rank,
world_size=world_size,
combined_template_path=cfg.prompts.combined_template_path,
combined_template_output_path=cfg.prompts.combined_template_output_path,
) # this dataset is already sharded, but hasn't been truncated to max_examples

model = instantiate_model(
Expand Down
178 changes: 108 additions & 70 deletions elk/extraction/prompt_loading.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from os.path import exists
from collections import Counter
from copy import deepcopy
from dataclasses import dataclass
Expand Down Expand Up @@ -47,8 +48,8 @@ class PromptConfig(Serializable):
-1.
seed: The seed to use for prompt randomization. Defaults to 42.
stream: Whether to stream the dataset from the Internet. Defaults to False.
combined_template_path: Path to save a combined template file to, when testing
prompt invariance across multiple datasets, and will be interpreted as a
combined_template_output_path: Path to save a combined template file to, when
applying prompt invariance across multiple datasets. Interpreted as a
subpath of `combined_paths` in the templates dir. Defaults to empty string.
"""

Expand All @@ -61,7 +62,7 @@ class PromptConfig(Serializable):
num_variants: int = -1
seed: int = 42
stream: bool = False
combined_template_path: str = ""
combined_template_output_path: str = ""

def __post_init__(self):
if len(self.max_examples) > 2:
Expand All @@ -72,63 +73,99 @@ def __post_init__(self):
if not self.max_examples:
self.max_examples = [int(1e100)]

self.combine_templates()

# Broadcast the limit to all splits
if len(self.max_examples) == 1:
self.max_examples *= 2

# Combining prompts
combined_prompter = None
if self.combined_template_path:
print(
"Copying templates across datasets to combined_templates/ "
+ f"{self.combined_template_path}/templates.yaml"
# Broadcast the dataset name to all data_dirs and label_columns
if len(self.data_dirs) == 1:
self.data_dirs *= len(self.datasets)
elif self.data_dirs and len(self.data_dirs) != len(self.datasets):
raise ValueError(
"data_dirs should be a list of length 0, 1, or len(datasets),"
f" but got {len(self.data_dirs)}"
)
combined_prompter = DatasetTemplates(
"combined_templates", self.combined_template_path

if len(self.label_columns) == 1:
self.label_columns *= len(self.datasets)
elif self.label_columns and len(self.label_columns) != len(self.datasets):
raise ValueError(
"label_columns should be a list of length 0, 1, or len(datasets),"
f" but got {len(self.label_columns)}"
)

def combine_templates(self):
if not self.combined_template_output_path:
return

print(
"Copying templates across datasets to combined_templates/ "
+ f"{self.combined_template_output_path}/templates.yaml"
)
combined_prompter = DatasetTemplates(
"combined_templates", self.combined_template_output_path
)
combined_prompter.templates = {}
ref_ds_builder = None
for i, ds_string in enumerate(self.datasets):
ds_name, _, config_name = ds_string.partition(" ")
ds_builder = load_dataset_builder(ds_name, config_name or None)
if i == 0:
# Set first dataset as reference
ref_ds_builder = ds_builder
elif not self.verify_cols(ds_builder, ref_ds_builder):
return

# Once verified, merge templates.
prompter = DatasetTemplates(ds_name, config_name)
combined_prompter.merge_templates_from(prompter)
print("Total number of templates: ", len(combined_prompter.templates))
combined_prompter.write_to_file()
print(
"Saved to promptsource/templates/combined_templates/"
+ f"{self.combined_template_output_path}.yaml"
)

def verify_cols(self, ds_builder, ref_ds_builder) -> bool:
'''Verify that number of features and number of classes for ClassLabel
match the expected values.
'''
expected_features = len(ref_ds_builder.info.features)
expected_classes = ref_ds_builder.info.features["label"].num_classes
num_features = len(ds_builder.info.features)
num_classes = ds_builder.info.features["label"].num_classes
if expected_features > 0 and num_features != expected_features:
print(
"WARNING: Datasets do not have the same number of features;",
f"{ds_name} has {num_features} features while first dataset has",
f"{expected_features}. Prompting datasets separately.",
)
combined_prompter.templates = {}
prev_num_features = 0
prev_num_label_classes = 0
for ds_string in self.datasets:
ds_name, _, config_name = ds_string.partition(" ")
prompter = DatasetTemplates(ds_name, config_name)

# Verify that number of features and number of classes for ClassLabel
# are the same across datasets.
ds_builder = load_dataset_builder(ds_name, config_name or None)
num_features = len(ds_builder.info.features)
if prev_num_features > 0 and num_features != prev_num_features:
print(
"WARNING: Datasets do not have the same number of features;",
f"{ds_name} has {num_features} features while prev has",
f"{prev_num_features}. Prompting datasets separately.",
)
self.combined_template_path = ""
break
prev_num_features = num_features
num_classes = ds_builder.info.features["label"].num_classes
if prev_num_label_classes > 0 and num_classes != prev_num_label_classes:
print(
"WARNING: Datasets do not have the same number of ClassLabel",
f"classes; {ds_name} has {num_classes} classes while prev has",
f"{prev_num_label_classes}. Prompting datasets separately.",
)
self.combined_template_path = ""
break
prev_num_label_classes = num_classes

# Once verified, merge templates.
combined_prompter.merge_templates_from(prompter)

# Write to file if successfully merged all prompts.
if self.combined_template_path:
prompter = assert_type(DatasetTemplates, combined_prompter)
print("Total number of templates: ", len(prompter.templates))
prompter.write_to_file()
return False
if expected_classes > 0 and num_classes != expected_classes:
print(
"Saved to promptsource/templates/combined_templates/"
+ f"{self.combined_template_path}.yaml"
"WARNING: Datasets do not have the same number of ClassLabel classes",
f"{ds_name} has {num_classes} classes while first dataset has",
f"{expected_classes}. Prompting datasets separately.",
)
return False
return True

def explode(self) -> list["PromptConfig"]:
"""Explode the config into a list of configs, one for each dataset."""
copies = []

for ds, data_dir, col in zip_longest(
self.datasets, self.data_dirs, self.label_columns
):
copy = deepcopy(self)
copy.datasets = [ds]
copy.data_dirs = [data_dir] if data_dir else []
copy.label_columns = [col] if col else []
copies.append(copy)

return copies


def load_prompts(
Expand All @@ -142,7 +179,7 @@ def load_prompts(
stream: bool = False,
rank: int = 0,
world_size: int = 1,
combined_template_path: str = "",
combined_template_output_path: str = "",
) -> Iterator[dict]:
"""Load a dataset full of prompts generated from the specified dataset.

Expand All @@ -161,16 +198,20 @@ def load_prompts(
An iterable of prompt dictionaries.
"""
ds_name, _, config_name = ds_string.partition(" ")
prompter = DatasetTemplates(ds_name, config_name)

# First load the datasets and prompters. We need to know the minimum number of
# templates for any dataset in order to make sure we don't run out of prompts.
for ds_string in dataset_strings:
ds_name, _, config_name = ds_string.partition(" ")

if combined_template_path == "":
prompter = DatasetTemplates(ds_name, config_name)
prompters.append(DatasetTemplates(ds_name, config_name))

prompter = None
if combined_template_output_path and exists(combined_template_output_path):
prompter = DatasetTemplates(
"combined_templates", combined_template_output_path
)
else:
prompter = DatasetTemplates(ds_name, config_name)

ds_dict = assert_type(
dict, load_dataset(ds_name, config_name or None, streaming=stream)
)
train_name, val_name = select_train_val_splits(ds_dict)
split_name = val_name if split_type == "val" else train_name

ds = ds_dict[split_name].shuffle(seed=seed)
train_ds = ds_dict[train_name].shuffle(seed=seed)
Expand All @@ -181,14 +222,11 @@ def load_prompts(

ds = ds.to_iterable_dataset().cast(ds.features)

if combined_template_path:
combined_prompter = DatasetTemplates(
"combined_templates", combined_template_path
)
prompters = [combined_prompter] * len(dataset_strings)

min_num_templates = min(len(prompter.templates) for prompter in prompters)
elif world_size > 1:
# This prints to stdout which is slightly annoying
ds = split_dataset_by_node(dataset=ds, rank=rank, world_size=world_size)

num_templates = len(prompter.templates)
num_variants = (
num_templates if num_variants == -1 else min(num_variants, num_templates)
)
Expand Down