Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

check is streamable #166

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
check is streamable on-the-fly
  • Loading branch information
AlexTMallen committed Apr 8, 2023
commit 7aff4562e418ad3f91e57556f110ba939d50a6e3
39 changes: 34 additions & 5 deletions elk/extraction/balanced_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,8 @@
from ..utils import infer_label_column
from ..utils.typing import assert_type
from collections import deque
from dataclasses import dataclass
from datasets import IterableDataset, Features
from itertools import cycle
from itertools import cycle, islice
from random import Random
from torch.utils.data import IterableDataset as TorchIterableDataset
from typing import Iterator, Optional, Iterable
Expand All @@ -25,11 +24,17 @@ class BalancedSampler(TorchIterableDataset):
divided between the two binary label values (0 and 1). Defaults to 1000.
"""

# NOTE(review): this span is a unified-diff fragment. The first `def __init__`
# line is the REMOVED (old) signature; the three lines after it are the ADDED
# (new) signature, which replaces the fixed `buffer_size` knob with
# `max_examples` plus a minimum buffer size.
def __init__(self, data: Iterable[dict], buffer_size: int = 1000):
def __init__(
self, data: Iterable[dict], max_examples: int, min_buffer_size: int = 100
):
# Source iterable of example dicts; presumably each dict carries a label
# field — TODO confirm against the (hidden) body of __iter__.
self.data = data

# NOTE(review): the next two `deque(maxlen=buffer_size)` lines are the
# REMOVED side of the diff; the replacement lines using `self.buffer_size`
# follow below.
self.neg_buffer = deque(maxlen=buffer_size)
self.pos_buffer = deque(maxlen=buffer_size)
# Cap on how many samples __iter__ will yield in total.
self.max_examples = max_examples
# Buffer at least `min_buffer_size` samples per class, but never less than
# the number of examples we ultimately want to emit.
self.buffer_size = max(min_buffer_size, max_examples)
# One FIFO per binary label value; `maxlen` makes each deque drop its
# oldest entry when full.
self.neg_buffer = deque(maxlen=self.buffer_size)
self.pos_buffer = deque(maxlen=self.buffer_size)

self.idx = 0 # The number of samples yielded so far

def __iter__(self):
for sample in self.data:
# NOTE(review): the following marker is a collapsed diff hunk — the code
# that classifies `sample` and appends negatives is hidden from view.
Expand All @@ -41,9 +46,33 @@ def __iter__(self):
else:
self.pos_buffer.append(sample)

# Check if the input was too unbalanced to begin with
# NOTE(review): `and` binds tighter than `or`, so this reads as
# (no negatives AND positives filled the buffer) OR
# (no positives AND negatives filled the buffer) —
# i.e. one class's buffer filled completely before a single sample of
# the other class appeared, and nothing has been yielded yet.
if self.idx == 0 and (
not self.neg_buffer
and len(self.pos_buffer) == self.buffer_size
or not self.pos_buffer
and len(self.neg_buffer) == self.buffer_size
):
raise ValueError(
"The input dataset was too unbalanced to balance while streaming. "
"If streaming a dataset such as IMDB where the data is sorted by "
"label, streaming data cannot be balanced. "
"Try removing the `--stream` flag."
)

# Emit samples strictly alternating neg/pos while both buffers are
# non-empty, stopping as soon as `max_examples` total have been yielded.
while self.neg_buffer and self.pos_buffer:
yield self.neg_buffer.popleft()
self.idx += 1
if self.idx == self.max_examples:
return

yield self.pos_buffer.popleft()
self.idx += 1
if self.idx == self.max_examples:
return

def __len__(self):
    # Length is exactly the `max_examples` cap: __iter__ stops after yielding
    # that many samples (assuming the input stream is long enough — if the
    # stream is exhausted first, fewer samples are produced; TODO confirm
    # callers tolerate that).
    return self.max_examples


class FewShotSampler:
Expand Down
3 changes: 1 addition & 2 deletions elk/extraction/extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,6 @@ def extract_hiddens(
*cfg.prompts.datasets,
split_type=split_type,
stream=cfg.prompts.stream,
max_examples=max_examples,
rank=rank,
world_size=world_size,
) # this dataset is already sharded, but hasn't been truncated to max_examples
Expand Down Expand Up @@ -136,7 +135,7 @@ def extract_hiddens(
layer_indices = cfg.layers or tuple(range(model.config.num_hidden_layers))
# print(f"Using {prompt_ds} variants for each dataset")

for example in islice(BalancedSampler(prompt_ds), max_examples):
for example in BalancedSampler(prompt_ds, max_examples):
num_variants = len(example["prompts"])
hidden_dict = {
f"hidden_{layer_idx}": torch.empty(
Expand Down
12 changes: 0 additions & 12 deletions elk/extraction/prompt_loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,14 @@
binarize,
infer_label_column,
infer_num_classes,
is_streamable,
select_train_val_splits,
)
from .balanced_sampler import FewShotSampler
from dataclasses import dataclass
from datasets import (
interleave_datasets,
load_dataset,
ClassLabel,
Dataset,
Features,
IterableDataset,
Sequence,
)
from datasets.distributed import split_dataset_by_node
from random import Random
Expand Down Expand Up @@ -79,7 +74,6 @@ def yield_prompts(
seed: int = 42,
split_type: Literal["train", "val"] = "train",
stream: bool = False,
max_examples: int = 750,
rank: int = 0,
world_size: int = 1,
) -> Iterator[dict]:
Expand Down Expand Up @@ -121,12 +115,6 @@ def yield_prompts(
if not stream:
split = assert_type(Dataset, split)
split = split.to_iterable_dataset().cast(split.features)
else:
if not is_streamable(split, max_examples=max_examples):
raise ValueError(
f"Streaming dataset {ds_name} is not streamable because the first "
f"{max_examples} examples are all of the same label."
)

# only keep the datapoints relevant to the current process
if world_size > 1:
Expand Down
1 change: 0 additions & 1 deletion elk/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
get_columns_all_equal,
infer_label_column,
infer_num_classes,
is_streamable,
select_train_val_splits,
)

Expand Down
12 changes: 0 additions & 12 deletions elk/utils/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,15 +120,3 @@ def binarize(template: Template, label: int, new_label: int, rng: Random) -> Tem
)

return new_template


def is_streamable(ds: IterableDataset, max_examples: int) -> bool:
    """Return True if the first `max_examples` examples span more than one label.

    Streaming datasets can only be approximately shuffled through a buffer,
    which is not an adequate shuffle for datasets such as IMDB that are
    sorted by label — this check catches that degenerate case up front.
    """
    label_col = infer_label_column(assert_type(Features, ds.features))
    # Collect the distinct labels seen in the prefix of the stream.
    observed_labels = {ex[label_col] for ex in islice(ds, max_examples)}
    return len(observed_labels) >= 2
2 changes: 1 addition & 1 deletion tests/test_samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def test_output_is_roughly_balanced():
)

col = infer_label_column(dataset.features)
reservoir = BalancedSampler(dataset)
reservoir = BalancedSampler(dataset, max_examples=3000)

# Count the number of samples for each label
counter = Counter()
Expand Down