Commit: Fix smoke test failure

norabelrose committed Mar 25, 2023
1 parent 44dc25c commit 93d8d87
Showing 4 changed files with 48 additions and 24 deletions.
elk/extraction/balanced_sampler.py (15 additions, 20 deletions)
@@ -1,12 +1,12 @@
-from ..utils import infer_label_column
 from ..math_util import stochastic_round_constrained
-from dataclasses import dataclass, field, InitVar
+from ..utils import infer_label_column
+from collections import deque
+from dataclasses import dataclass
 from datasets import IterableDataset
 from itertools import cycle
 from random import Random
 from torch.utils.data import IterableDataset as TorchIterableDataset
 from typing import Iterator, Optional
-import numpy as np
 
 
 @dataclass
@@ -25,30 +25,25 @@ class BalancedSampler(TorchIterableDataset):
         divided between the two binary label values (0 and 1). Defaults to 1000.
     """
 
-    dataset: IterableDataset
-    label_counts: np.ndarray = field(default_factory=lambda: np.zeros(2))
-    seed: InitVar[int] = 42
+    def __init__(self, dataset: IterableDataset, buffer_size: int = 1000):
+        self.dataset = dataset
 
-    def __post_init__(self, seed: int):
-        self.rng = np.random.default_rng(seed)
+        self.neg_buffer = deque(maxlen=buffer_size)
+        self.pos_buffer = deque(maxlen=buffer_size)
 
     def __iter__(self):
         for sample in self.dataset:
             label = sample["label"]
 
-            # Update class counts
-            self.label_counts[label] += 1
-            current_balance = self.label_counts / self.label_counts.sum()
-
-            # Check if the sample should be dropped
-            majority_class = np.argmax(current_balance)
-            if label == majority_class:
-                # Solution of n * p * q / [n * (1 - p) + n * p * q] = 0.5 for q
-                keep_prob = 1 / current_balance[majority_class] - 1
-                if self.rng.uniform() < 1 - keep_prob:
-                    continue
+            # Add the sample to the appropriate buffer
+            if label == 0:
+                self.neg_buffer.append(sample)
+            else:
+                self.pos_buffer.append(sample)
 
-            yield sample
+            while self.neg_buffer and self.pos_buffer:
+                yield self.neg_buffer.popleft()
+                yield self.pos_buffer.popleft()
 
 
 class FewShotSampler:
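For context on the change above: the removed rejection-sampling path kept a majority-class sample with probability q = 1/p - 1, where p is the current majority-class fraction; that value is exactly the solution for q of the deleted comment's equation n * p * q / [n * (1 - p) + n * p * q] = 0.5. The new implementation drops the bookkeeping entirely: it parks samples in one bounded deque per class and emits one of each whenever both are non-empty. Below is a minimal self-contained sketch of that buffering strategy, assuming binary 0/1 labels as in the diff; the function name and toy stream are illustrative, not part of the repo.

```python
from collections import deque
from typing import Iterator

def balanced(stream: Iterator[dict], buffer_size: int = 1000) -> Iterator[dict]:
    """Yield samples from `stream` with labels strictly alternating 0, 1, 0, 1, ..."""
    neg_buffer: deque = deque(maxlen=buffer_size)
    pos_buffer: deque = deque(maxlen=buffer_size)

    for sample in stream:
        # Route each sample to its class buffer.
        (pos_buffer if sample["label"] else neg_buffer).append(sample)

        # Emit one negative and one positive whenever both are on hand.
        while neg_buffer and pos_buffer:
            yield neg_buffer.popleft()
            yield pos_buffer.popleft()

# Toy usage: an 80/20 imbalanced stream comes out exactly 50/50.
samples = ({"label": int(i % 5 == 0)} for i in range(100))
out = list(balanced(samples))
assert sum(s["label"] for s in out) * 2 == len(out)
```

One trade-off worth noting: `deque(maxlen=...)` silently evicts the oldest queued samples of the over-represented class once the buffer fills, so a heavily skewed stream loses data rather than blocking.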
elk/extraction/extraction.py (3 additions, 0 deletions)
@@ -96,6 +96,7 @@ def extract_hiddens(
         *cfg.prompts.datasets,
         max_examples=limits[0 if split_type == "train" else 1],
         split_type=split_type,
+        stream=cfg.prompts.stream,
         rank=rank,
         world_size=world_size,
     )
@@ -128,7 +129,9 @@ def extract_hiddens(
     layer_indices = cfg.layers or tuple(range(model.config.num_hidden_layers))
     # print(f"Using {prompt_ds} variants for each dataset")
 
+    print("wowza")
     for example in BalancedSampler(prompt_ds):
+        print("holy crap")
         hidden_dict = {
             f"hidden_{layer_idx}": torch.empty(
                 num_variants,
elk/extraction/generator.py (3 additions, 0 deletions)
@@ -53,5 +53,8 @@ def _split_generators(self, dl_manager):
 
     def _generate_examples(self, **gen_kwargs):
         assert self.config.generator is not None, "generator must be specified"
+
+        print("wow")
         for idx, ex in enumerate(self.config.generator(**gen_kwargs)):
+            print(f"iter {idx}")
             yield idx, ex
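For readers unfamiliar with the hook being instrumented here: `_generate_examples` follows the Hugging Face `datasets` builder contract of yielding `(key, example)` pairs with a unique key per example. A rough sketch of the same pattern through the public `Dataset.from_generator` API, with a made-up generator purely for illustration:

```python
from datasets import Dataset

def gen():
    # Yielded dicts become rows, enumerated internally much like the
    # (idx, ex) pairs that _generate_examples produces in the diff above.
    for idx in range(3):
        yield {"id": idx, "text": f"example {idx}"}

ds = Dataset.from_generator(gen)
print(ds[0])  # {'id': 0, 'text': 'example 0'}
```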
elk/extraction/prompt_loading.py (27 additions, 4 deletions)
@@ -12,6 +12,7 @@
     interleave_datasets,
     load_dataset,
     ClassLabel,
+    Dataset,
     Features,
     IterableDataset,
     Sequence,
@@ -40,9 +41,10 @@ class PromptConfig(Serializable):
             the val split. If empty, use all examples. Defaults to empty.
         num_shots: The number of examples to use in few-shot prompts. If zero, prompts
             are zero-shot. Defaults to 0.
-        seed: The seed to use for prompt randomization. Defaults to 42.
         num_variants: The number of prompt templates to apply to each predicate upon
             call to __getitem__. Use -1 to apply all available templates. Defaults to 1.
+        seed: The seed to use for prompt randomization. Defaults to 42.
+        stream: Whether to stream the dataset from the Internet. Defaults to False.
     """
 
     datasets: list[str] = field(positional=True)
@@ -53,6 +55,7 @@ def __post_init__(self):
     num_shots: int = 0
     num_variants: int = -1
     seed: int = 42
+    stream: bool = False
 
     def __post_init__(self):
         if len(self.max_examples) > 2:
@@ -69,6 +72,7 @@ def load_prompts(
     seed: int = 42,
     shuffle: bool = True,
     split_type: Literal["train", "val"] = "train",
+    stream: bool = False,
     rank: int = 0,
     world_size: int = 1,
 ) -> IterableDataset:
@@ -82,6 +86,7 @@
             are zero-shot.
         seed: The seed to use for prompt randomization.
         split_type: Whether to use the train or val split of the dataset.
+        stream: Whether to stream the dataset from the Internet. Defaults to False.
         rank: The rank of the current process. Defaults to 0.
         world_size: The number of processes. Defaults to 1.
@@ -101,14 +106,26 @@
         prompters.append(DatasetTemplates(ds_name, config_name))
 
         ds_dict = assert_type(
-            dict, load_dataset(ds_name, config_name or None, streaming=True)
+            dict, load_dataset(ds_name, config_name or None, streaming=stream)
         )
         train_name, val_name = select_train_val_splits(ds_dict)
         split_name = val_name if split_type == "val" else train_name
-        raw_datasets.append(assert_type(IterableDataset, ds_dict[split_name]))
-        train_datasets.append(assert_type(IterableDataset, ds_dict[train_name]))
 
+        # If we're not streaming, take the opportunity to shuffle the dataset.
+        if not stream:
+            ds = assert_type(Dataset, ds_dict[split_name].shuffle(seed=seed))
+            train_ds = assert_type(Dataset, ds_dict[train_name].shuffle(seed=seed))
+            split = ds.to_iterable_dataset().cast(ds.features)
+        else:
+            train_ds = assert_type(IterableDataset, ds_dict[train_name])
+            split = assert_type(IterableDataset, ds_dict[split_name])
+
+        raw_datasets.append(split)
+        train_datasets.append(train_ds)
+
     num_variants = min(len(prompter.templates) for prompter in prompters)
     assert num_variants > 0
 
     for ds, train_ds, prompter in zip(raw_datasets, train_datasets, prompters):
         label_column = infer_label_column(ds.features)
         num_classes = infer_num_classes(ds.features[label_column])
@@ -176,6 +193,11 @@ def load_prompts(
     if shuffle:
         master_ds = master_ds.shuffle(seed=seed)
 
+    # Try to approximately shuffle the dataset if we're streaming. Note that this is
+    # NOT an adequate shuffle for datasets like IMDB, which are sorted by label.
+    if stream:
+        master_ds = master_ds.shuffle(seed=seed)
+
     return master_ds
@@ -189,6 +211,7 @@ def _convert_to_prompts(
     fewshot_iter: Optional[Iterator[list[dict]]] = None,
 ) -> dict[str, Any]:
     """Prompt-generating function to pass to `IterableDataset.map`."""
+    print(f"label: {example[label_column]}")
     prompts = []
     templates = list(prompter.templates.values())
     if num_variants < len(templates):
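The prompt_loading.py changes above create two shuffling regimes. With stream=False the split is fully materialized, so `Dataset.shuffle` produces a true random permutation before the data is converted back to an iterable view; with stream=True, `IterableDataset.shuffle` only draws from a fixed-size shuffle buffer, which is why the added comment warns it is not an adequate shuffle for label-sorted datasets like IMDB. A rough sketch of the two paths against the `datasets` API; the dataset choice and buffer size are illustrative:

```python
from datasets import Dataset, IterableDataset, load_dataset

stream = False  # mirrors PromptConfig.stream

if not stream:
    # Non-streaming: the whole split is on disk, so shuffle() builds a
    # true permutation; to_iterable_dataset() then restores the streaming
    # interface the rest of the pipeline expects.
    ds = load_dataset("imdb", split="train")
    assert isinstance(ds, Dataset)
    it = ds.shuffle(seed=42).to_iterable_dataset()
else:
    # Streaming: shuffle() fills a buffer of `buffer_size` examples and
    # samples from it, so a label-sorted stream stays mostly sorted.
    it = load_dataset("imdb", split="train", streaming=True)
    assert isinstance(it, IterableDataset)
    it = it.shuffle(seed=42, buffer_size=1000)

print(next(iter(it))["label"])
```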
