Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

spar mt prompt invar #184

Merged
merged 55 commits into from
Apr 20, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
55 commits
Select commit Hold shift + click to select a range
7b68f5f
add boolq_pt template and christykoh as included user
ChristyKoh Mar 9, 2023
edfdd6d
add templates.yaml for boolqpt
ChristyKoh Mar 14, 2023
d285798
pt yaml
reaganjlee Mar 26, 2023
6c91be2
Merge remote-tracking branch 'origin/main' into spar_boolq_pt
reaganjlee Apr 5, 2023
0ff1609
add ag_news template, translated to pt
ChristyKoh Apr 5, 2023
7180a64
add ag_news template, translated to pt
ChristyKoh Apr 5, 2023
b8f5e8b
save eval runs to separate subfolders by target dataset
ChristyKoh Apr 7, 2023
51fed5d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 7, 2023
9cebbdd
Merge branch 'main' into spar_pt
ChristyKoh Apr 7, 2023
6be825a
Merge branch 'spar_boolq_pt' into spar_pt
ChristyKoh Apr 7, 2023
2660844
eval multiple datasets
ChristyKoh Apr 7, 2023
f35a6a8
Merge branch 'eval_dirs' of github.com:EleutherAI/elk into eval_dirs
ChristyKoh Apr 7, 2023
b230209
change prompt answer chouces to portuguese
ChristyKoh Apr 7, 2023
74c9915
add imdb_pt template
ChristyKoh Apr 11, 2023
85fd9e4
implement prompt sharing, generate combined templates.yaml
ChristyKoh Apr 11, 2023
8383f26
fix num templates logic
ChristyKoh Apr 12, 2023
df41ab4
fix pt answer choice
ChristyKoh Apr 12, 2023
c03742d
Merge branch 'spar_pt' of github.com:EleutherAI/elk into spar_mt_prom…
ChristyKoh Apr 12, 2023
e536aa6
Merge branch 'main' of github.com:EleutherAI/elk into main
ChristyKoh Apr 12, 2023
208200d
Merge branch 'main' into spar_pt
ChristyKoh Apr 12, 2023
a132f40
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 12, 2023
9f3f218
Merge branch 'spar_pt' of github.com:EleutherAI/elk into spar_mt_prom…
ChristyKoh Apr 12, 2023
c71bf1c
remove empty prompt_dataset file
ChristyKoh Apr 12, 2023
1ec5787
fix empty prompters bug
ChristyKoh Apr 12, 2023
66c7a6b
fix multiclass label bug
ChristyKoh Apr 12, 2023
d91acac
move prompt combination to PromptConfig post_init logic
ChristyKoh Apr 12, 2023
89b2346
Merge branch 'main' of github.com:EleutherAI/elk into main
ChristyKoh Apr 12, 2023
715bba8
Merge branch 'main' into spar_mt_prompt_invar
ChristyKoh Apr 12, 2023
0c2f5c4
fix refactor bugs, runnable state
ChristyKoh Apr 12, 2023
066cd44
rewrite template merging, regenerate prompter every run
ChristyKoh Apr 12, 2023
846b78c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 12, 2023
1aecd7e
line len fixes
ChristyKoh Apr 12, 2023
b7bbee0
Merge branch 'spar_mt_prompt_invar' of github.com:EleutherAI/elk into…
ChristyKoh Apr 12, 2023
b0c0f63
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 12, 2023
34bc364
Merge remote-tracking branch 'origin/main' into spar_mt_prompt_invar
ChristyKoh Apr 12, 2023
2da069a
update README with prompt invariance argument
ChristyKoh Apr 12, 2023
4c6d344
fix bugs, add dataset col checks
ChristyKoh Apr 12, 2023
53d186b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 12, 2023
b78355f
fix prompter init typing
ChristyKoh Apr 12, 2023
3486ded
Merge branch 'spar_mt_prompt_invar' of github.com:EleutherAI/elk into…
ChristyKoh Apr 12, 2023
1975410
Update README.md
lauritowal Apr 12, 2023
8a5fb0d
try to fix typing again
ChristyKoh Apr 12, 2023
6d46f7e
Merge branch 'spar_mt_prompt_invar' of github.com:EleutherAI/elk into…
ChristyKoh Apr 12, 2023
a7f5a8b
assert datasettemplates type
ChristyKoh Apr 12, 2023
ba090b7
Merge remote-tracking branch 'origin/main' into eval_dirs
ChristyKoh Apr 13, 2023
7af1a1b
bugfix to run eval separately on each dataset
ChristyKoh Apr 13, 2023
74551fa
add combine_evals flag to differentiate a multi dataset eval from a b…
ChristyKoh Apr 13, 2023
0b3d3c9
Merge branch 'main' of github.com:EleutherAI/elk into main
ChristyKoh Apr 19, 2023
d8cee8b
Merge branch 'main' into eval_dirs
ChristyKoh Apr 19, 2023
54aa710
Merge branch 'eval_dirs' into spar_mt_prompt_invar, separate combinin…
ChristyKoh Apr 19, 2023
f7a4713
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 19, 2023
6e2f54c
define ds_name
ChristyKoh Apr 19, 2023
3b0765d
Merge branch 'spar_mt_prompt_invar' of github.com:EleutherAI/elk into…
ChristyKoh Apr 19, 2023
5f0f32a
fix ds_name bug
ChristyKoh Apr 20, 2023
8fa07b4
Update README.md
lauritowal Apr 20, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
implement prompt sharing, generate combined templates.yaml
  • Loading branch information
ChristyKoh committed Apr 11, 2023
commit 85fd9e40b56ff6ba4b103d48c6962c0b24b12b6b
2 changes: 2 additions & 0 deletions elk/extraction/extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ class Extract(Serializable):
token_loc: Literal["first", "last", "mean"] = "last"
min_gpu_mem: Optional[int] = None
num_gpus: int = -1
combined_prompter_path: Optional[str] = None # if template file does not exist, combine from datasets and save to this path

def __post_init__(self, layer_stride: int):
if self.layers and layer_stride > 1:
Expand Down Expand Up @@ -98,6 +99,7 @@ def extract_hiddens(
stream=cfg.prompts.stream,
rank=rank,
world_size=world_size,
combined_prompter_path=cfg.combined_prompter_path
) # this dataset is already sharded, but hasn't been truncated to max_examples

# AutoModel should do the right thing here in nearly all cases. We don't actually
Expand Down
31 changes: 29 additions & 2 deletions elk/extraction/prompt_loading.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from os.path import exists
from dataclasses import dataclass
from random import Random
from typing import Any, Iterator, Literal, Optional
Expand Down Expand Up @@ -40,7 +41,7 @@ class PromptConfig(Serializable):
num_shots: The number of examples to use in few-shot prompts. If zero, prompts
are zero-shot. Defaults to 0.
num_variants: The number of prompt templates to apply to each predicate upon
call to __getitem__. Use -1 to apply all available templates. Defaults to 1.
call to __getitem__. Use -1 to apply all available templates. Defaults to -1.
seed: The seed to use for prompt randomization. Defaults to 42.
stream: Whether to stream the dataset from the Internet. Defaults to False.
"""
Expand Down Expand Up @@ -78,6 +79,7 @@ def load_prompts(
stream: bool = False,
rank: int = 0,
world_size: int = 1,
combined_prompter_path: str = ""
) -> Iterator[dict]:
"""Load a dataset full of prompts generated from the specified datasets.

Expand All @@ -100,10 +102,24 @@ def load_prompts(
train_datasets = []
rng = Random(seed)

# If combined template is not empty and does not exist as a file yet, need to aggregate
# Init/create a new file for combining templates
combined_prompter = None
if combined_prompter_path:
print("Combining templates into shared prompter.")
combined_prompter = DatasetTemplates("combined_templates", combined_prompter_path)
# should_aggregate_templates = (combined_prompter and not exists(combined_prompter.yaml_path))
# print("should aggregate: ", should_aggregate_templates)

# First load the datasets and prompters. We need to know the minimum number of
# templates for any dataset in order to make sure we don't run out of prompts.
for ds_string in dataset_strings:
ds_name, _, config_name = ds_string.partition(" ")
prompter = DatasetTemplates(ds_name, config_name)
# Populate combined prompter with templates from different datasets
# if should_aggregate_templates:
combined_prompter.templates.update(prompter.get_templates_with_new_uuids())
print("len of prompter templates is ", len(combined_prompter.templates))
prompters.append(DatasetTemplates(ds_name, config_name))

ds_dict = assert_type(
Expand Down Expand Up @@ -136,11 +152,22 @@ def load_prompts(
train_datasets.append(train_ds)

min_num_templates = min(len(prompter.templates) for prompter in prompters)
# if should_aggregate_templates:

if combined_prompter:
# save combined templates to yaml file
print("saving aggregate templates")
combined_prompter.sync_mapping()
combined_prompter.write_to_file()
min_num_templates = len(combined_prompter.templates)
print("length of combined_prompter templates is ", min_num_templates)

num_variants = (
min_num_templates
if num_variants == -1
else min(num_variants, min_num_templates)
)
print()
ChristyKoh marked this conversation as resolved.
Show resolved Hide resolved
assert num_variants > 0
if rank == 0:
print(f"Using {num_variants} variants of each prompt")
Expand Down Expand Up @@ -179,7 +206,7 @@ def load_prompts(
label_column=label_column,
num_classes=num_classes,
num_variants=num_variants,
prompter=prompter,
prompter=prompter if not combined_prompter else combined_prompter,
rng=rng,
fewshot_iter=fewshot_iter,
)
Expand Down
10 changes: 10 additions & 0 deletions elk/promptsource/templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -543,6 +543,16 @@ def delete_folder(self) -> None:
if len(os.listdir(base_folder)) == 0:
rmtree(base_folder)

def get_templates_with_new_uuids(self) -> dict:
"""
Generate new uuids for templates, used when merging template datasets.
"""
new_templates = {}
for template in self.templates.values():
template.id = str(uuid.uuid4())
new_templates[template.id] = template
return new_templates

def __getitem__(self, template_key: str) -> "Template":
return self.templates[self.name_to_id_mapping[template_key]]

Expand Down