-
Notifications
You must be signed in to change notification settings - Fork 32
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Multi datasets #123
Multi datasets #123
Changes from 15 commits
681698d
ac1b9f1
b864c77
7d7d97c
4fe61e9
fe61d67
74da878
569ef05
b62b679
fe94c22
bba24d8
03ba6e0
15ab351
a80369e
d304ab3
761c82d
f29743b
b7b7e23
1afb563
225d4c7
9368dc8
a858b65
5dc2ec6
b1b95e5
ee3911e
a55b3de
44dc25c
93d8d87
fad4d74
0a054f4
177eec2
3a762b0
f66c054
d3d87fc
3d08147
c9a43e1
94290aa
f9298e4
3765c4f
2b05193
83731bb
764fda9
d2c66b0
9186326
3f99a4d
148130d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1 @@ | ||
from .extraction import extract_hiddens, ExtractionConfig, PromptDataset | ||
from .extraction import extract_hiddens, ExtractionConfig |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
from .balanced_sampler import BalancedBatchSampler, BalancedSampler | ||
from .extraction import ExtractionConfig, extract_hiddens, extract | ||
from .generator import _GeneratorConfig, _GeneratorBuilder | ||
from .prompt_dataset import PromptDataset, PromptConfig | ||
from .prompt_loading import PromptConfig, load_prompts |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
from ..utils import infer_label_column | ||
from collections import Counter | ||
from dataclasses import dataclass, field, InitVar | ||
from datasets import IterableDataset | ||
from itertools import cycle | ||
from torch.utils.data import IterableDataset as TorchIterableDataset | ||
from typing import Iterator, Optional | ||
import numpy as np | ||
|
||
|
||
@dataclass
class BalancedSampler(TorchIterableDataset):
    """
    Approximately balances a binary classification dataset in a streaming fashion.
    Written mostly by GPT-4.

    Samples whose class is currently over-represented are dropped
    stochastically so that the yielded stream converges to a 50/50 mix.

    Attributes:
        dataset: The HuggingFace IterableDataset to balance. Each sample must
            carry an integer ``"label"`` key equal to 0 or 1.
        label_counts: Running per-class counts of samples *seen* so far
            (dropped samples included) — this is the estimate of the raw
            stream's class balance, not of the output's.
        seed: Init-only seed for the RNG used for stochastic dropping.
    """

    dataset: IterableDataset
    label_counts: np.ndarray = field(default_factory=lambda: np.zeros(2))
    seed: InitVar[int] = 42

    def __post_init__(self, seed: int):
        # Dedicated generator so sampling is reproducible and independent
        # of the global numpy RNG state.
        self.rng = np.random.default_rng(seed)

    def __iter__(self):
        for sample in self.dataset:
            label = sample["label"]

            # Update class counts (before the keep/drop decision, so the
            # balance estimate tracks the raw stream).
            self.label_counts[label] += 1
            current_balance = self.label_counts / self.label_counts.sum()

            # Check if the sample should be dropped
            majority_class = np.argmax(current_balance)
            if label == majority_class:
                # Solution of n * p * q / [n * (1 - p) + n * p * q] = 0.5 for q
                keep_prob = 1 / current_balance[majority_class] - 1
                if self.rng.uniform() < 1 - keep_prob:
                    continue

            yield sample
||
|
||
class BalancedBatchSampler:
    """Yields precisely balanced batches from a binary classification dataset.

    Written by a human being because GPT-4 couldn't figure out how to do it.

    Args:
        dataset: Dataset to draw samples from. It is cycled over indefinitely,
            so iteration never terminates on its own — the caller must stop
            consuming.
        label_col: Name of the binary label column. If ``None``, it is
            inferred from ``dataset.features``.
        batch_size: Number of samples per batch. Must be even, since each
            batch holds exactly ``batch_size // 2`` samples of each class.

    Raises:
        ValueError: If ``batch_size`` is odd. An odd batch could never be
            filled (each class is capped at ``batch_size // 2``), so
            iteration would spin forever without yielding anything.
    """

    def __init__(
        self,
        dataset: IterableDataset,
        label_col: Optional[str] = None,
        batch_size: int = 32,
    ):
        if batch_size % 2:
            raise ValueError(f"batch_size must be even, got {batch_size}")

        self.batch_size = batch_size
        self.dataset = dataset
        self.label_col = label_col or infer_label_column(dataset.features)

    def __iter__(self) -> Iterator[list[dict]]:
        batch = []

        # Each class may fill at most half of the batch.
        max_count = self.batch_size // 2
        label_counts = Counter()

        # Infinite loop!
        for sample in cycle(self.dataset):
            label = sample[self.label_col]
            # Skip samples whose class already hit its per-batch quota.
            if label_counts[label] >= max_count:
                continue

            batch.append(sample)
            label_counts[label] += 1

            if len(batch) == self.batch_size:
                yield batch

                batch = []
                label_counts.clear()
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
small detail, but maybe we should just remove @DataClass, if we use init anyway... (?)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Go ahead and change it