check is streamable #166

Closed
wants to merge 7 commits
Changes from 1 commit
check is streamable
AlexTMallen committed Apr 4, 2023
commit 3a24b56e32d15de881c1738ee92150ec53cf6cff
elk/extraction/__init__.py (2 changes: 1 addition & 1 deletion)

@@ -1,4 +1,4 @@
 from .balanced_sampler import BalancedSampler, FewShotSampler
 from .extraction import Extract, extract_hiddens, extract
 from .generator import _GeneratorConfig, _GeneratorBuilder
-from .prompt_loading import PromptConfig, load_prompts
+from .prompt_loading import PromptConfig, yield_prompts
elk/extraction/extraction.py (21 changes: 10 additions & 11 deletions)

@@ -30,7 +30,7 @@
 )
 from .balanced_sampler import BalancedSampler
 from .generator import _GeneratorBuilder
-from .prompt_loading import PromptConfig, load_prompts
+from .prompt_loading import PromptConfig, yield_prompts


 @dataclass
@@ -93,10 +93,18 @@ def extract_hiddens(
     if rank != 0:
         logging.disable(logging.CRITICAL)

-    prompt_ds = load_prompts(
+    global_max_examples = cfg.prompts.max_examples[0 if split_type == "train" else 1]
+    # break `max_examples` among the processes roughly equally
+    max_examples = global_max_examples // world_size
+    # the last process gets the remainder (which is usually small)
+    if rank == world_size - 1:
+        max_examples += global_max_examples % world_size
+
+    prompt_ds = yield_prompts(
         *cfg.prompts.datasets,
         split_type=split_type,
         stream=cfg.prompts.stream,
+        max_examples=max_examples,
         rank=rank,
         world_size=world_size,
     )  # this dataset is already sharded, but hasn't been truncated to max_examples

@@ -128,15 +136,6 @@ def extract_hiddens(
     layer_indices = cfg.layers or tuple(range(model.config.num_hidden_layers))
     # print(f"Using {prompt_ds} variants for each dataset")

-    global_max_examples = cfg.prompts.max_examples[0 if split_type == "train" else 1]
-    # break `max_examples` among the processes roughly equally
-    max_examples = global_max_examples // world_size
-    # the last process gets the remainder (which is usually small)
-    if rank == world_size - 1:
-        max_examples += global_max_examples % world_size
-
-    print(f"Extracting {max_examples} examples from {prompt_ds} on {device}")
-
     for example in islice(BalancedSampler(prompt_ds), max_examples):
         num_variants = len(example["prompts"])
         hidden_dict = {
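As a sanity check on the sharding arithmetic moved above: each rank takes the floor of an even share, and the last rank absorbs the remainder, so the per-rank budgets always sum to `global_max_examples`. A minimal standalone sketch (the `shard_max_examples` helper is hypothetical, written only for illustration):

# Hypothetical helper mirroring the arithmetic in extract_hiddens above.
def shard_max_examples(global_max_examples: int, rank: int, world_size: int) -> int:
    """Split a global example budget roughly equally across ranks."""
    max_examples = global_max_examples // world_size
    # The last rank also takes the remainder, which is at most world_size - 1.
    if rank == world_size - 1:
        max_examples += global_max_examples % world_size
    return max_examples

# A budget of 750 examples over 4 ranks: the shares sum back to 750.
assert [shard_max_examples(750, r, 4) for r in range(4)] == [187, 187, 187, 189]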
elk/extraction/prompt_loading.py (17 changes: 9 additions & 8 deletions)

@@ -4,6 +4,7 @@
     binarize,
     infer_label_column,
     infer_num_classes,
+    is_streamable,
     select_train_val_splits,
 )
 from .balanced_sampler import FewShotSampler

@@ -71,13 +72,14 @@ def __post_init__(self):
         self.max_examples *= 2


-def load_prompts(
+def yield_prompts(
     *dataset_strings: str,
     num_shots: int = 0,
     num_variants: int = -1,
     seed: int = 42,
     split_type: Literal["train", "val"] = "train",
     stream: bool = False,
+    max_examples: int = 750,
     rank: int = 0,
     world_size: int = 1,
 ) -> Iterator[dict]:
@@ -114,18 +116,17 @@ def load_prompts(
         train_name, val_name = select_train_val_splits(ds_dict)
         split_name = val_name if split_type == "val" else train_name

-        # Note that when streaming we can only approximately shuffle the dataset
-        # using a buffer. Streaming shuffling is NOT an adequate shuffle for
-        # datasets like IMDB, which are sorted by label.
-        bad_streaming_datasets = ["imdb"]
-        assert not (
-            stream and ds_name in bad_streaming_datasets
-        ), f"Streaming is not supported for {ds_name}."
         split = ds_dict[split_name].shuffle(seed=seed)
         train_ds = ds_dict[train_name].shuffle(seed=seed)
         if not stream:
             split = assert_type(Dataset, split)
             split = split.to_iterable_dataset().cast(split.features)
+        else:
+            if not is_streamable(split, max_examples=max_examples):
+                raise ValueError(
+                    f"Streaming dataset {ds_name} is not streamable because the first "
+                    f"{max_examples} examples are all of the same label."
+                )

         # only keep the datapoints relevant to the current process
         if world_size > 1:
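To see why the old blanket IMDB ban existed, and what the new is_streamable check guards against, consider a toy model of buffered streaming shuffling. The `buffered_shuffle` function below is a simplified stand-in for the reservoir-style shuffle that `datasets` applies to streamed data, not the library's actual implementation: when the buffer is far smaller than a label-sorted run, the first examples to stream out all share one label.

import random
from itertools import islice

def buffered_shuffle(items, buffer_size: int, seed: int = 42):
    """Simplified buffer shuffle: yield a random element of a sliding buffer."""
    rng = random.Random(seed)
    buffer = []
    for item in items:
        buffer.append(item)
        if len(buffer) > buffer_size:
            yield buffer.pop(rng.randrange(len(buffer)))
    rng.shuffle(buffer)
    yield from buffer

# IMDB's train split is sorted by label: 12,500 negatives then 12,500 positives.
labels = [0] * 12_500 + [1] * 12_500
first_750 = list(islice(buffered_shuffle(labels, buffer_size=1000), 750))
assert set(first_750) == {0}  # all one label, exactly what is_streamable rejects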
elk/utils/__init__.py (1 change: 1 addition & 0 deletions)

@@ -3,6 +3,7 @@
     get_columns_all_equal,
     infer_label_column,
     infer_num_classes,
+    is_streamable,
     select_train_val_splits,
 )
elk/utils/data_utils.py (18 changes: 15 additions & 3 deletions)

@@ -4,13 +4,13 @@
     ClassLabel,
     DatasetDict,
     Features,
+    IterableDataset,
     Split,
     Value,
 )
 from random import Random
-import torch
-from typing import Iterable, Optional, List, Any
-import numpy as np
+from itertools import islice
+from typing import Iterable, List, Any
 import copy

@@ -120,3 +120,15 @@ def binarize(template: Template, label: int, new_label: int, rng: Random) -> Template:
     )

     return new_template
+
+
+def is_streamable(ds: IterableDataset, max_examples: int) -> bool:
+    """Checks that the first `max_examples` examples are not all of the same label.
+
+    Note that when streaming we can only approximately shuffle the dataset
+    using a buffer. Streaming shuffling is NOT an adequate shuffle for
+    datasets like IMDB, which are sorted by label.
+    """
+    label_column = infer_label_column(assert_type(Features, ds.features))
+    labels = [ex[label_column] for ex in islice(ds, max_examples)]
+    return len(set(labels)) > 1
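A hedged usage sketch of the new helper, assuming `datasets` is installed and `is_streamable` is imported via the elk/utils/__init__.py export added above; the streamed IMDB load is illustrative, not part of this PR:

from datasets import load_dataset
from elk.utils import is_streamable  # exported in elk/utils/__init__.py above

# Stream IMDB with the library's approximate buffer shuffle, then check it.
ds = load_dataset("imdb", split="train", streaming=True).shuffle(seed=42)
if not is_streamable(ds, max_examples=750):
    raise ValueError("imdb: the first 750 streamed examples all share one label")

With the default shuffle buffer of 1,000 examples against IMDB's 12,500-example label runs, the check fails and this sketch raises, which is the failure mode the PR replaces the hard-coded `bad_streaming_datasets` list with.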