Multi datasets #123
Merged: 46 commits, Mar 28, 2023

Changes from 1 of the 46 commits are shown below the commit list.

Commits:
681698d  add multiple datasets support (Benw8888, Mar 9, 2023)
ac1b9f1  Merge branch 'main' of github.com:EleutherAI/elk into multi-datasets (Benw8888, Mar 9, 2023)
b864c77  train_reporter works on a list of layers now (Benw8888, Mar 10, 2023)
7d7d97c  changing printed layer names (Benw8888, Mar 10, 2023)
4fe61e9  fixed concatenation bug (Benw8888, Mar 11, 2023)
fe61d67  minor edits (Benw8888, Mar 13, 2023)
74da878  fixed pyright issues (Benw8888, Mar 13, 2023)
569ef05  Merge branch 'main' of github.com:EleutherAI/elk into multi-datasets (Benw8888, Mar 14, 2023)
b62b679  Merge branch 'main' into multi-datasets (norabelrose, Mar 20, 2023)
fe94c22  Fix tests (norabelrose, Mar 20, 2023)
bba24d8  Now working sorta (norabelrose, Mar 22, 2023)
03ba6e0  Skip slow BalancedBatchSampler test (norabelrose, Mar 22, 2023)
15ab351  Slightly relax test_output_is_roughly_balanced (norabelrose, Mar 22, 2023)
a80369e  Make BalancedSampler deterministic (norabelrose, Mar 22, 2023)
d304ab3  InitVar (norabelrose, Mar 22, 2023)
761c82d  Support multi class again (norabelrose, Mar 22, 2023)
f29743b  Fix naming issue (norabelrose, Mar 22, 2023)
b7b7e23  Support few shot prompts (norabelrose, Mar 23, 2023)
1afb563  Merge branch 'main' into multi-datasets (norabelrose, Mar 23, 2023)
225d4c7  fix multiclass labels (AlexTMallen, Mar 23, 2023)
9368dc8  Merge branch 'multi-datasets' of github.com:EleutherAI/elk into multi… (AlexTMallen, Mar 23, 2023)
a858b65  Merge branch 'main' into multi-datasets (norabelrose, Mar 24, 2023)
5dc2ec6  Merge branch 'multi-datasets' of github.com:EleutherAI/elk into multi… (norabelrose, Mar 24, 2023)
b1b95e5  Fix dumb part of test failures (norabelrose, Mar 25, 2023)
ee3911e  Fix assert_allclose warning (norabelrose, Mar 25, 2023)
a55b3de  Switch to torch.testing.assert_close in EigenReporter test (norabelrose, Mar 25, 2023)
44dc25c  Shuffle load_prompts output by default (norabelrose, Mar 25, 2023)
93d8d87  Fix smoke test failure (norabelrose, Mar 25, 2023)
fad4d74  Remove debug prints (AlexTMallen, Mar 25, 2023)
0a054f4  Remove more debug print statements (AlexTMallen, Mar 25, 2023)
177eec2  make min_memory usable; broadcast mmax_examples in __post_init__ (AlexTMallen, Mar 26, 2023)
3a762b0  prompt loading refactor to enable better streaming (AlexTMallen, Mar 26, 2023)
f66c054  remove shuffle arg (AlexTMallen, Mar 26, 2023)
d3d87fc  remove unused @dataclass (lauritowal, Mar 26, 2023)
3d08147  merge (lauritowal, Mar 27, 2023)
c9a43e1  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Mar 27, 2023)
94290aa  add concatenated_layer_offset to eval (lauritowal, Mar 27, 2023)
f9298e4  Merge branch 'multi-datasets' of https://github.com/EleutherAI/elk in… (lauritowal, Mar 27, 2023)
3765c4f  add self. (lauritowal, Mar 27, 2023)
2b05193  replace target with data (lauritowal, Mar 27, 2023)
83731bb  add self. (lauritowal, Mar 27, 2023)
764fda9  remove second arg (lauritowal, Mar 27, 2023)
d2c66b0  fix passing the wrong params for world size / rank (thejaminator, Mar 28, 2023)
9186326  Update prompt_loading.py (lauritowal, Mar 28, 2023)
3f99a4d  fix pre-commit errors (lauritowal, Mar 28, 2023)
148130d  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Mar 28, 2023)
Commit 3a762b0dd995ed61ee61820214ee67f804647b35
prompt loading refactor to enable better streaming
(AlexTMallen committed Mar 26, 2023)
elk/extraction/balanced_sampler.py (8 changes: 4 additions & 4 deletions)
@@ -6,7 +6,7 @@
 from itertools import cycle
 from random import Random
 from torch.utils.data import IterableDataset as TorchIterableDataset
-from typing import Iterator, Optional
+from typing import Iterator, Optional, Iterable


 @dataclass
Review thread on the @dataclass line:

Collaborator: small detail, but maybe we should just remove @dataclass, if we use __init__ anyway... (?)

Member: Go ahead and change it

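For readers skimming the thread: @dataclass generates __init__ (along with __repr__ and __eq__) from class-level field annotations, so a class that writes its own __init__, as BalancedSampler does below, gets little from the decorator. A minimal sketch of the two forms (illustrative names, not from this repo):

    from dataclasses import dataclass


    @dataclass
    class WithDecorator:
        # __init__, __repr__, and __eq__ are generated from these annotations
        data: list
        buffer_size: int = 1000


    class WithInit:
        # A hand-written __init__ does the construction itself; stacking
        # @dataclass on top would add little beyond __repr__ and __eq__
        def __init__(self, data: list, buffer_size: int = 1000):
            self.data = data
            self.buffer_size = buffer_size

The decorator was in fact dropped in a later commit in this PR (d3d87fc, "remove unused @dataclass").
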
@@ -25,14 +25,14 @@ class BalancedSampler(TorchIterableDataset):
         divided between the two binary label values (0 and 1). Defaults to 1000.
     """

-    def __init__(self, dataset: IterableDataset, buffer_size: int = 1000):
-        self.dataset = dataset
+    def __init__(self, data: Iterable[dict], buffer_size: int = 1000):
+        self.data = data

         self.neg_buffer = deque(maxlen=buffer_size)
         self.pos_buffer = deque(maxlen=buffer_size)

     def __iter__(self):
-        for sample in self.dataset:
+        for sample in self.data:
             label = sample["label"]

             # Add the sample to the appropriate buffer
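
With the new signature, BalancedSampler accepts any iterable of dicts carrying a "label" key rather than only a HuggingFace IterableDataset. A usage sketch under that assumption (toy generator; the exact yielding behavior depends on the buffering logic elided above):

    from elk.extraction.balanced_sampler import BalancedSampler


    def toy_stream():
        # Any iterable of dicts with a "label" key is acceptable now
        for i in range(10):
            yield {"label": i % 2, "text": f"example {i}"}


    for sample in BalancedSampler(toy_stream(), buffer_size=4):
        print(sample["label"], sample["text"])
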
elk/extraction/extraction.py (14 changes: 6 additions & 8 deletions)
@@ -33,6 +33,7 @@
 import logging
 import os
 import torch
+from itertools import islice


 @dataclass
@@ -90,19 +91,14 @@ def extract_hiddens(
     # Silence datasets logging messages from all but the first process
     if rank != 0:
         logging.disable(logging.CRITICAL)
-    if rank == 0 and cfg.prompts.num_variants >= 1:
-        print(f"Using {cfg.prompts.num_variants} prompts per example")

-    limits = cfg.prompts.max_examples
     prompt_ds = load_prompts(
         *cfg.prompts.datasets,
-        max_examples=limits[0 if split_type == "train" else 1],
         split_type=split_type,
         stream=cfg.prompts.stream,
         rank=rank,
         world_size=world_size,
-    )
-    num_variants = prompt_ds.features["prompts"].length
+    )  # this dataset is already sharded, but hasn't been truncated to max_examples

     # AutoModel should do the right thing here in nearly all cases. We don't actually
     # care what head the model has, since we are just extracting hidden states.
@@ -131,7 +127,9 @@
     layer_indices = cfg.layers or tuple(range(model.config.num_hidden_layers))
     # print(f"Using {prompt_ds} variants for each dataset")

-    for example in BalancedSampler(prompt_ds):
+    max_examples = cfg.prompts.max_examples[0 if split_type == "train" else 1]
+    for example in islice(BalancedSampler(prompt_ds), max_examples):
+        num_variants = len(example["prompts"])
         hidden_dict = {
             f"hidden_{layer_idx}": torch.empty(
                 num_variants,
@@ -150,7 +148,7 @@

             # Iterate over answers
             for j in range(2):
-                text = record["text"][j]
+                text = record[j]["text"]
                 variant_inputs.append(text)

                 inputs = tokenizer(
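
The substantive change here is that truncation moves from the dataset to the consumer: instead of load_prompts taking max_examples and truncating internally, extract_hiddens caps the stream with itertools.islice. Because islice works on any iterator and stops after exactly n items without exhausting the source, the balanced sampler can keep buffering from the sharded stream underneath. A toy illustration of the pattern (stand-in names only, not from the codebase):

    from itertools import count, islice

    # Stand-in for BalancedSampler(prompt_ds): an effectively unbounded stream
    stream = ({"label": i % 2} for i in count())

    max_examples = 5
    for example in islice(stream, max_examples):
        print(example)
    # Stops after exactly max_examples items; the source is never fully consumed
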
elk/extraction/prompt_loading.py (138 changes: 62 additions & 76 deletions)
@@ -73,21 +73,20 @@ def __post_init__(self):

 def load_prompts(
     *dataset_strings: str,
-    max_examples: int = 0,
     num_shots: int = 0,
+    num_variants: int = -1,
     seed: int = 42,
     shuffle: bool = True,
     split_type: Literal["train", "val"] = "train",
     stream: bool = False,
     rank: int = 0,
     world_size: int = 1,
-) -> IterableDataset:
+) -> Iterator[dict]:
     """Load a dataset full of prompts generated from the specified datasets.

     Args:
         dataset_strings: Space-delimited names of the HuggingFace datasets to use,
             e.g. `"super_glue boolq"` or `"imdb"`.
-        max_examples: The maximum number of examples to use from the dataset.
         num_shots: The number of examples to use in few-shot prompts. If zero, prompts
             are zero-shot.
         seed: The seed to use for prompt randomization.
@@ -99,7 +98,6 @@ def load_prompts(
     Returns:
         An iterable dataset of prompts.
     """
-    prompt_datasets = []
     prompters = []
     raw_datasets = []
     train_datasets = []
@@ -117,94 +115,82 @@
         train_name, val_name = select_train_val_splits(ds_dict)
         split_name = val_name if split_type == "val" else train_name

-        # If we're not streaming, take the opportunity to shuffle the dataset.
-        # Note that when streaming we can only approximately shuffle the dataset
-        # using a buffer. Streaming shuffling is NOT an adequate shuffle for
-        # datasets like IMDB, which are sorted by label.
+        bad_streaming_datasets = ["imdb"]
+        assert not (
+            stream and ds_name in bad_streaming_datasets
+        ), f"Streaming is not supported for {ds_name}."
+        split = ds_dict[split_name].shuffle(seed=seed)
+        train_ds = ds_dict[train_name].shuffle(seed=seed)
         if not stream:
-            ds = assert_type(Dataset, ds_dict[split_name].shuffle(seed=seed))
-            train_ds = assert_type(Dataset, ds_dict[train_name].shuffle(seed=seed))
-            split = ds.to_iterable_dataset().cast(ds.features)
-        else:
-            train_ds = assert_type(IterableDataset, ds_dict[train_name])
-            split = assert_type(IterableDataset, ds_dict[split_name])
+            split = assert_type(Dataset, split)
+            split = split.to_iterable_dataset().cast(split.features)
+
+        # only keep the datapoints relevant to the current process
+        if world_size > 1:
+            # This prints to stdout which is slightly annoying
+            split = split_dataset_by_node(split, world_size, rank)

         raw_datasets.append(split)
         train_datasets.append(train_ds)

-    num_variants = min(len(prompter.templates) for prompter in prompters)
-
-    for ds, train_ds, prompter in zip(raw_datasets, train_datasets, prompters):
-        label_column = infer_label_column(ds.features)
-        num_classes = infer_num_classes(ds.features[label_column])
-
-        # Remove everything except the label column
-        extra_cols = list(assert_type(Features, ds.features))
-        extra_cols.remove(label_column)
-
-        if label_column != "label":
-            ds = ds.rename_column(label_column, "label")
-        if num_shots > 0:
-            fewshot = FewShotSampler(
-                train_ds,
-                num_shots=num_shots,
-                rng=rng,
-            )
-            fewshot_iter = iter(fewshot)
-        else:
-            fewshot_iter = None
-
-        # Canonicalize the name and dtype of the label column
-        ds = ds.map(
-            _convert_to_prompts,
-            fn_kwargs=dict(
-                label_column=label_column,
-                num_classes=num_classes,
-                num_variants=num_variants,
-                prompter=prompter,
-                rng=rng,
-                fewshot_iter=fewshot_iter,
-            ),
-            remove_columns=extra_cols,
-        ).map(
-            # Add the builder and config name to the records directly to make
-            # sure we don't forget what dataset they came from.
-            lambda _: dict(
-                builder_name=ds.info.builder_name,
-                config_name=ds.info.config_name,
-            ),
-            # Explicit typing makes interleave_datasets work a lot faster
-            features=Features(
-                {
-                    label_column: ClassLabel(names=["neg", "pos"]),
-                    "builder_name": "string",
-                    "config_name": "string",
-                    "prompts": Sequence(
-                        Sequence(
-                            {"answer": "string", "text": "string"},
-                            length=2,  # contrast pair
-                        ),
-                        length=num_variants,
-                    ),
-                    "template_names": Sequence("string"),
-                }
-            ),
-        )
-        prompt_datasets.append(ds)
-
-    master_ds = interleave_datasets(prompt_datasets)
-    if max_examples > 0:
-        master_ds = master_ds.take(max_examples)
-    if world_size > 1:
-        # This prints to stdout which is slightly annoying
-        master_ds = split_dataset_by_node(master_ds, rank, world_size)
-    if shuffle:
-        master_ds = master_ds.shuffle(seed=seed)
-
-    # Try to approximately shuffle the dataset if we're streaming. Note that this is
-    # NOT an adequate shuffle for datasets like IMDB, which are sorted by label.
-    if stream:
-        master_ds = master_ds.shuffle(seed=seed)
-
-    return master_ds
+    min_num_templates = min(len(prompter.templates) for prompter in prompters)
+    num_variants = (
+        min_num_templates
+        if num_variants == -1
+        else min(num_variants, min_num_templates)
+    )
+    assert num_variants > 0
+    if rank == 0:
+        print(f"Using {num_variants} variants of each prompt")
+
+    ds_iterators = [iter(ds) for ds in raw_datasets]
+    while True:  # terminates when the first dataset runs out of examples
+        for ds_iterator, ds, train_ds, prompter in zip(
+            ds_iterators, raw_datasets, train_datasets, prompters
+        ):
+            label_column = infer_label_column(ds.features)
+            num_classes = infer_num_classes(ds.features[label_column])
+
+            # Remove everything except the label column
+            extra_cols = list(assert_type(Features, ds.features))
+            extra_cols.remove(label_column)
+
+            if label_column != "label":
+                ds = ds.rename_column(label_column, "label")
+            if num_shots > 0:
+                fewshot = FewShotSampler(
+                    train_ds,  # TODO: not iterator
+                    num_shots=num_shots,
+                    rng=rng,
+                )
+                fewshot_iter = iter(fewshot)
+            else:
+                fewshot_iter = None
+
+            try:
+                example = next(ds_iterator)
+            except StopIteration:
+                return
+
+            example = _convert_to_prompts(
+                example,
+                label_column=label_column,
+                num_classes=num_classes,
+                num_variants=num_variants,
+                prompter=prompter,
+                rng=rng,
+                fewshot_iter=fewshot_iter,
+            )
+
+            # Add the builder and config name to the records directly to make
+            # sure we don't forget what dataset they came from.
+            example["builder_name"] = ds.info.builder_name
+            example["config_name"] = ds.info.config_name
+
+            yield example


 def _convert_to_prompts(
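
The shape of the refactor: load_prompts is now a generator that round-robins over one iterator per dataset, converts each raw example to prompts on the fly, and returns as soon as any dataset is exhausted, rather than building and interleaving materialized datasets. A self-contained sketch of that control flow (toy iterables standing in for the HuggingFace datasets):

    def round_robin(*iterables):
        """Yield one item from each iterable in turn, stopping as soon as
        any of them is exhausted, like the while True loop above."""
        iterators = [iter(it) for it in iterables]
        while True:
            for it in iterators:
                try:
                    item = next(it)
                except StopIteration:
                    return  # ends the whole generator, as in load_prompts
                yield item


    print(list(round_robin("abc", "xy")))  # ['a', 'x', 'b', 'y', 'c']
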
tests/test_load_prompts.py (6 changes: 3 additions & 3 deletions)
@@ -1,6 +1,6 @@
 from elk.extraction import load_prompts, PromptConfig
 from elk.promptsource.templates import DatasetTemplates
-from itertools import cycle
+from itertools import cycle, islice
 from typing import Literal
 import pytest
@@ -10,7 +10,6 @@ def test_load_prompts():
 def test_single_split(cfg: PromptConfig, split_type: Literal["train", "val"]):
     prompt_ds = load_prompts(
         *cfg.datasets,
-        max_examples=cfg.max_examples[0],
         shuffle=False,
         split_type=split_type,
     )
@@ -21,7 +20,8 @@ def test_single_split(cfg: PromptConfig, split_type: Literal["train", "val"]):
         prompter = DatasetTemplates(ds_name, config_name or None)
         prompters.append(prompter)

-    for prompter, record in zip(cycle(prompters), prompt_ds):
+    limit = cfg.max_examples[0 if split_type == "train" else 1]
+    for prompter, record in zip(cycle(prompters), islice(prompt_ds, limit)):
         true_template_names = prompter.all_template_names
         returned_template_names = record["template_names"]
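
The updated test mirrors how extraction.py now consumes load_prompts: the per-dataset prompters are cycled in lockstep with the interleaved record stream, and islice enforces the example cap that load_prompts no longer applies itself. The pairing pattern in isolation (stand-in values):

    from itertools import cycle, islice

    prompters = ["prompter_a", "prompter_b"]  # one per source dataset
    records = iter(range(100))  # stand-in for the interleaved prompt stream

    limit = 6
    for prompter, record in zip(cycle(prompters), islice(records, limit)):
        # The stream round-robins over datasets, so cycling the prompter
        # list in step keeps each record paired with its own prompter
        print(prompter, record)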