check is streamable #166

Closed · wants to merge 7 commits
12 changes: 6 additions & 6 deletions elk/extraction/__init__.py
@@ -1,16 +1,16 @@
 from .balanced_sampler import BalancedSampler, FewShotSampler
 from .extraction import Extract, extract, extract_hiddens
 from .generator import _GeneratorBuilder, _GeneratorConfig
-from .prompt_loading import PromptConfig, load_prompts
+from .prompt_loading import PromptConfig, yield_prompts

 __all__ = [
+    "_GeneratorBuilder",
+    "_GeneratorConfig",
     "BalancedSampler",
+    "FewShotSampler",
+    "Extract",
     "extract_hiddens",
     "extract",
-    "_GeneratorConfig",
-    "_GeneratorBuilder",
-    "Extract",
-    "FewShotSampler",
     "PromptConfig",
-    "load_prompts",
+    "yield_prompts",
 ]
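The export list is also re-sorted, but the only functional change in this file is the rename: `load_prompts` becomes `yield_prompts`, signalling that the function is a generator that produces examples lazily rather than loading a split eagerly. A minimal sketch of the renamed public API (the `"imdb"` dataset string is illustrative; the keyword arguments mirror the call sites later in this diff):

```python
from elk.extraction import yield_prompts

# yield_prompts is a generator: examples are produced lazily, one at a
# time, which is what makes streaming extraction possible downstream.
for example in yield_prompts("imdb", split_type="train", stream=True):
    num_variants = len(example["prompts"])  # same field extraction.py reads
    break
```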
36 changes: 33 additions & 3 deletions elk/extraction/balanced_sampler.py
@@ -25,11 +25,17 @@ class BalancedSampler(TorchIterableDataset):
             divided between the two binary label values (0 and 1). Defaults to 1000.
     """

-    def __init__(self, data: Iterable[dict], buffer_size: int = 1000):
+    def __init__(
+        self, data: Iterable[dict], max_examples: int, min_buffer_size: int = 100
+    ):
         self.data = data

-        self.neg_buffer = deque(maxlen=buffer_size)
-        self.pos_buffer = deque(maxlen=buffer_size)
+        self.max_examples = max_examples
+        self.buffer_size = max(min_buffer_size, max_examples)
+        self.neg_buffer = deque(maxlen=self.buffer_size)
+        self.pos_buffer = deque(maxlen=self.buffer_size)
+
+        self.idx = 0  # The number of samples yielded so far

     def __iter__(self):
         for sample in self.data:
@@ -41,9 +47,33 @@ def __iter__(self):
             else:
                 self.pos_buffer.append(sample)

+            # Check if the input was too unbalanced to begin with
+            if self.idx == 0 and (
+                not self.neg_buffer
+                and len(self.pos_buffer) == self.buffer_size
+                or not self.pos_buffer
+                and len(self.neg_buffer) == self.buffer_size
+            ):
+                raise ValueError(
+                    "The input dataset was too unbalanced to balance while streaming. "
+                    "Datasets that are sorted by label, such as IMDB, cannot be "
+                    "balanced on the fly. "
+                    "Try removing the `--stream` flag."
+                )
+
             while self.neg_buffer and self.pos_buffer:
                 yield self.neg_buffer.popleft()
+                self.idx += 1
+                if self.idx == self.max_examples:
+                    return
+
                 yield self.pos_buffer.popleft()
+                self.idx += 1
+                if self.idx == self.max_examples:
+                    return
+
+    def __len__(self):
+        return self.max_examples


 class FewShotSampler:
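Taken together, the sampler now owns the example cap (hence the new `__len__`), sizes its buffers from `max_examples`, and fails fast when one buffer fills while the other is still empty before anything has been yielded. A self-contained toy sketch of the same logic, assuming examples are dicts with a binary `label` field as in the diff:

```python
from collections import deque
from typing import Iterable, Iterator

def balanced(stream: Iterable[dict], max_examples: int,
             min_buffer_size: int = 4) -> Iterator[dict]:
    """Toy rendition of BalancedSampler.__iter__ for illustration."""
    buffer_size = max(min_buffer_size, max_examples)
    neg, pos = deque(maxlen=buffer_size), deque(maxlen=buffer_size)
    idx = 0  # number of samples yielded so far
    for sample in stream:
        (pos if sample["label"] else neg).append(sample)
        # One buffer full while the other is empty and nothing yielded yet:
        # the stream is too unbalanced (e.g. sorted by label) to balance.
        if idx == 0 and (
            (not neg and len(pos) == buffer_size)
            or (not pos and len(neg) == buffer_size)
        ):
            raise ValueError("input stream too unbalanced to balance")
        while neg and pos:
            for s in (neg.popleft(), pos.popleft()):
                yield s
                idx += 1
                if idx == max_examples:
                    return

interleaved = [{"label": i % 2} for i in range(20)]
assert len(list(balanced(interleaved, max_examples=8))) == 8

sorted_by_label = [{"label": 0}] * 10 + [{"label": 1}] * 10
# Raises ValueError: the negative buffer fills before any positive arrives.
# list(balanced(sorted_by_label, max_examples=8))
```

Note that the guard only fires while `idx == 0`; once a single balanced pair has been emitted, a later one-sided run simply accumulates in its buffer (bounded by `maxlen`) without raising.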
21 changes: 10 additions & 11 deletions elk/extraction/extraction.py
@@ -2,7 +2,6 @@
 import logging
 import os
 from dataclasses import InitVar, dataclass
-from itertools import islice
 from typing import Any, Iterable, Literal, Optional

 import torch
@@ -34,7 +33,7 @@
 )
 from .balanced_sampler import BalancedSampler
 from .generator import _GeneratorBuilder
-from .prompt_loading import PromptConfig, load_prompts
+from .prompt_loading import PromptConfig, yield_prompts


 @dataclass
@@ -97,7 +96,14 @@ def extract_hiddens(
     if rank != 0:
         logging.disable(logging.CRITICAL)

-    prompt_ds = load_prompts(
+    global_max_examples = cfg.prompts.max_examples[0 if split_type == "train" else 1]
+    # break `max_examples` among the processes roughly equally
+    max_examples = global_max_examples // world_size
+    # the last process gets the remainder (which is usually small)
+    if rank == world_size - 1:
+        max_examples += global_max_examples % world_size
+
+    prompt_ds = yield_prompts(
         *cfg.prompts.datasets,
         split_type=split_type,
         stream=cfg.prompts.stream,
@@ -118,14 +124,7 @@
     # Iterating over questions
     layer_indices = cfg.layers or tuple(range(model.config.num_hidden_layers))

-    global_max_examples = cfg.prompts.max_examples[0 if split_type == "train" else 1]
-    # break `max_examples` among the processes roughly equally
-    max_examples = global_max_examples // world_size
-    # the last process gets the remainder (which is usually small)
-    if rank == world_size - 1:
-        max_examples += global_max_examples % world_size
-
-    for example in islice(BalancedSampler(prompt_ds), max_examples):
+    for example in BalancedSampler(prompt_ds, max_examples):
         num_variants = len(example["prompts"])
         hidden_dict = {
             f"hidden_{layer_idx}": torch.empty(
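Two things happen in this file: the per-rank cap is computed before the prompt stream is built, and the cap moves from an `islice` wrapper into `BalancedSampler` itself, which needs it to size its buffers and to implement `__len__`. The split arithmetic deserves a worked example (the numbers are illustrative):

```python
# How global_max_examples is divided across ranks (illustrative values).
global_max_examples = 1000
world_size = 3

for rank in range(world_size):
    max_examples = global_max_examples // world_size  # 333 per rank
    if rank == world_size - 1:
        max_examples += global_max_examples % world_size  # last rank: +1
    print(rank, max_examples)

# rank 0 -> 333, rank 1 -> 333, rank 2 -> 334; total is exactly 1000
```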
9 changes: 1 addition & 8 deletions elk/extraction/prompt_loading.py
@@ -70,7 +70,7 @@ def __post_init__(self):
             self.max_examples *= 2


-def load_prompts(
+def yield_prompts(
     *dataset_strings: str,
     num_shots: int = 0,
     num_variants: int = -1,
@@ -113,13 +113,6 @@ def load_prompts(
         train_name, val_name = select_train_val_splits(ds_dict)
         split_name = val_name if split_type == "val" else train_name

-        # Note that when streaming we can only approximately shuffle the dataset
-        # using a buffer. Streaming shuffling is NOT an adequate shuffle for
-        # datasets like IMDB, which are sorted by label.
-        bad_streaming_datasets = ["imdb"]
-        assert not (
-            stream and ds_name in bad_streaming_datasets
-        ), f"Streaming is not supported for {ds_name}."
         split = ds_dict[split_name].shuffle(seed=seed)
         train_ds = ds_dict[train_name].shuffle(seed=seed)
         if not stream:
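The deleted assertion hard-coded IMDB as un-streamable; the equivalent check now lives in `BalancedSampler`, which catches any label-sorted stream at runtime. The deleted comment is still the right intuition: a streaming shuffle only permutes within a fixed-size buffer, so an element can move at most about a buffer's width, and a dataset sorted by label stays essentially sorted. A toy illustration of the failure mode (plain Python, not the `datasets` API):

```python
import random

def buffered_shuffle(stream, buffer_size: int, seed: int = 0):
    """Streaming shuffle: hold buffer_size items, emit one at random."""
    rng = random.Random(seed)
    buffer = []
    for item in stream:
        buffer.append(item)
        if len(buffer) > buffer_size:
            yield buffer.pop(rng.randrange(len(buffer)))
    while buffer:
        yield buffer.pop(rng.randrange(len(buffer)))

# 10_000 negatives followed by 10_000 positives, "shuffled" through a
# 1_000-item buffer: the first 9_000 outputs are still all negative,
# because no positive even enters the buffer until then.
labels = [0] * 10_000 + [1] * 10_000
out = list(buffered_shuffle(labels, buffer_size=1_000))
assert sum(out[:9_000]) == 0
```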
4 changes: 2 additions & 2 deletions tests/test_load_prompts.py
@@ -3,14 +3,14 @@

 import pytest

-from elk.extraction import PromptConfig, load_prompts
+from elk.extraction import PromptConfig, yield_prompts
 from elk.promptsource.templates import DatasetTemplates


 @pytest.mark.filterwarnings("ignore:Unable to find a decoding function")
 def test_load_prompts():
     def test_single_split(cfg: PromptConfig, split_type: Literal["train", "val"]):
-        prompt_ds = load_prompts(
+        prompt_ds = yield_prompts(
             *cfg.datasets,
             split_type=split_type,
         )
2 changes: 1 addition & 1 deletion tests/test_samplers.py
@@ -41,7 +41,7 @@ def test_output_is_roughly_balanced():
     )

     col = infer_label_column(dataset.features)
-    reservoir = BalancedSampler(dataset)
+    reservoir = BalancedSampler(dataset, max_examples=3000)

     # Count the number of samples for each label
     counter = Counter()
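Since `max_examples` is now a required argument, the test constructs the sampler with an explicit cap. For reference, the counting pattern the test relies on, condensed (assuming `dataset` and `col` are set up as in the surrounding test):

```python
from collections import Counter

from elk.extraction import BalancedSampler

reservoir = BalancedSampler(dataset, max_examples=3000)

counter = Counter()
for sample in reservoir:
    counter[sample[col]] += 1

# The sampler emits negative/positive pairs, so by the time max_examples
# is reached the two label counts can differ by at most one.
assert abs(counter[0] - counter[1]) <= 1
```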