Merge pull request #123 from EleutherAI/multi-datasets
Multi datasets
lauritowal committed Mar 28, 2023
2 parents 12e2c12 + 148130d commit 77e2aeb
Showing 26 changed files with 616 additions and 539 deletions.
2 changes: 1 addition & 1 deletion elk/__init__.py
@@ -1 +1 @@
from .extraction import extract_hiddens, Extract, PromptDataset
from .extraction import extract_hiddens, Extract
4 changes: 3 additions & 1 deletion elk/__main__.py
@@ -1,11 +1,13 @@
"""Main entry point for `elk`."""

from dataclasses import dataclass
from pathlib import Path
from typing import Union

from simple_parsing import ArgumentParser

from elk.evaluation.evaluate import Eval
from elk.extraction.extraction import Extract

from elk.training.train import Elicit


Empty file added elk/evaluation/__init__.py
30 changes: 24 additions & 6 deletions elk/evaluation/evaluate.py
@@ -1,17 +1,30 @@
import csv
import os
from dataclasses import dataclass
from functools import partial
from pathlib import Path
from typing import Literal, Optional, Callable
from typing import Callable, Literal, Optional, cast

import torch
from simple_parsing import Serializable, field
import torch.multiprocessing as mp
from simple_parsing.helpers import Serializable, field
from torch import Tensor
from tqdm.auto import tqdm

from datasets import DatasetDict
from elk.evaluation.evaluate_log import EvalLog
from elk.extraction.extraction import Extract
from elk.files import elk_reporter_dir
from elk.run import Run
from elk.training import Reporter
from elk.evaluation.evaluate_log import EvalLog
from elk.utils import select_usable_devices

from ..files import elk_reporter_dir, memorably_named_dir
from ..training.preprocessing import normalize
from ..utils import (
    assert_type,
    int16_to_float32,
    select_train_val_splits,
    select_usable_devices,
)


@dataclass
@@ -39,6 +52,8 @@ class Eval(Serializable):
    out_dir: Optional[Path] = None
    num_gpus: int = -1

    concatenated_layer_offset: int = 0

    def execute(self):
        transfer_eval = elk_reporter_dir() / self.source / "transfer_eval"

@@ -80,7 +95,10 @@ def evaluate_reporter(

    def evaluate(self):
        """Evaluate the reporter on all layers."""
        devices = select_usable_devices(self.cfg.num_gpus)
        devices = select_usable_devices(
            self.cfg.num_gpus, min_memory=self.cfg.data.min_gpu_mem
        )

        num_devices = len(devices)
        func: Callable[[int], EvalLog] = partial(
            self.evaluate_reporter, devices=devices, world_size=num_devices
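The diff cuts off before `func` is used, but the pattern it sets up is worth a sketch: bind the shared `devices` and `world_size` with `functools.partial`, then map the resulting one-argument callable over layer indices. The snippet below is illustrative only and is not the project's actual evaluation loop; `evaluate_layer`, `num_layers`, and the hard-coded device list are placeholders invented for this example.

```python
# Illustrative sketch only, not elk's actual code: shows how a per-layer
# callable built with functools.partial can be mapped over layers with a
# process pool. Names below are placeholders, not part of the elk API.
from functools import partial

import torch.multiprocessing as mp


def evaluate_layer(layer: int, devices: list[str], world_size: int) -> dict:
    # Each task picks a GPU from the shared device list, e.g. round-robin.
    device = devices[layer % world_size]
    return {"layer": layer, "device": device}


if __name__ == "__main__":
    devices = ["cuda:0", "cuda:1"]  # stand-in for select_usable_devices(...)
    num_layers = 12                 # stand-in for the model's hidden layer count
    func = partial(evaluate_layer, devices=devices, world_size=len(devices))

    with mp.Pool(len(devices)) as pool:
        results = pool.map(func, range(num_layers))
```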
3 changes: 2 additions & 1 deletion elk/extraction/__init__.py
@@ -1,3 +1,4 @@
from .balanced_sampler import BalancedSampler, FewShotSampler
from .extraction import Extract, extract_hiddens, extract
from .generator import _GeneratorConfig, _GeneratorBuilder
from .prompt_dataset import PromptDataset, PromptConfig
from .prompt_loading import PromptConfig, load_prompts
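For orientation, the public surface of `elk.extraction` after this change (private `_Generator*` helpers aside) can be imported as shown below. Only the names are taken from the diff; nothing here implies particular call signatures.

```python
# Names re-exported by elk/extraction/__init__.py after this commit.
# PromptDataset is no longer part of the public API; prompt loading now
# goes through PromptConfig and load_prompts instead.
from elk.extraction import (
    BalancedSampler,
    Extract,
    FewShotSampler,
    PromptConfig,
    extract,
    extract_hiddens,
    load_prompts,
)
```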
93 changes: 93 additions & 0 deletions elk/extraction/balanced_sampler.py
@@ -0,0 +1,93 @@
from ..math_util import stochastic_round_constrained
from ..utils import infer_label_column
from collections import deque
from dataclasses import dataclass
from datasets import IterableDataset
from itertools import cycle
from random import Random
from torch.utils.data import IterableDataset as TorchIterableDataset
from typing import Iterator, Optional, Iterable


class BalancedSampler(TorchIterableDataset):
    """
    Approximately balances a binary classification dataset in a streaming fashion.
    Written mostly by GPT-4.

    Args:
        data (Iterable[dict]): The examples to balance. Each dict must contain a
            binary "label" key whose value is 0 or 1.
        buffer_size (int, optional): Maximum number of examples to hold per label
            while waiting for a matching example of the other label. Defaults to
            1000.
    """

    def __init__(self, data: Iterable[dict], buffer_size: int = 1000):
        self.data = data

        self.neg_buffer = deque(maxlen=buffer_size)
        self.pos_buffer = deque(maxlen=buffer_size)

    def __iter__(self):
        for sample in self.data:
            label = sample["label"]

            # Add the sample to the appropriate buffer
            if label == 0:
                self.neg_buffer.append(sample)
            else:
                self.pos_buffer.append(sample)

            while self.neg_buffer and self.pos_buffer:
                yield self.neg_buffer.popleft()
                yield self.pos_buffer.popleft()


class FewShotSampler:
    """Yields batches of few-shot examples that are as balanced as possible.

    If `num_shots` is even, each batch contains exactly `num_shots // 2` examples
    of each class. Otherwise, `stochastic_round_constrained` randomly decides
    which class contributes the extra example, keeping each batch as close to
    balanced as possible.
    """

    def __init__(
        self,
        dataset: IterableDataset,
        num_shots: int,
        rng: Random,
        label_col: Optional[str] = None,
    ):
        self.dataset = dataset
        self.label_col = label_col or infer_label_column(dataset.features)
        self.num_shots = num_shots
        self.rng = rng

    def __iter__(self) -> Iterator[list[dict]]:
        neg_buf, pos_buf = [], []

        # Infinite loop over the dataset!
        for sample in cycle(self.dataset):
            label = sample[self.label_col]
            if label == 0:
                neg_buf.append(sample)
            elif label == 1:
                pos_buf.append(sample)
            else:
                raise ValueError(f"Expected label to be 0 or 1, got {label}")

            neg_count, pos_count = stochastic_round_constrained(
                [self.num_shots / 2, self.num_shots / 2], self.rng
            )
            while len(neg_buf) >= neg_count and len(pos_buf) >= pos_count:
                batch = []
                for _ in range(neg_count):
                    batch.append(neg_buf.pop())
                for _ in range(pos_count):
                    batch.append(pos_buf.pop())

                self.rng.shuffle(batch)
                yield batch
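To make the behaviour of the two new samplers concrete, here is a small usage sketch with toy in-memory data. The plain lists stand in for the streaming HuggingFace datasets the real pipeline uses, and the `text`/`label` fields are illustrative.

```python
from itertools import islice
from random import Random

from elk.extraction import BalancedSampler, FewShotSampler

# A skewed toy stream: one positive for every two negatives.
toy = [{"text": f"example {i}", "label": int(i % 3 == 0)} for i in range(30)]

# BalancedSampler yields one negative then one positive at a time, so the
# skewed stream comes out 50/50; unmatched leftovers are simply dropped.
balanced = list(BalancedSampler(toy, buffer_size=16))

# FewShotSampler cycles over its input forever, so take a fixed number of
# batches. A plain list stands in for the IterableDataset the type hint
# expects; passing label_col explicitly skips the feature-inference path.
few_shot = FewShotSampler(toy, num_shots=4, rng=Random(0), label_col="label")
batches = list(islice(few_shot, 3))  # three shuffled, balanced 4-shot batches
```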
