Multi datasets #123
Changes from 1 commit
```diff
@@ -1,12 +1,14 @@
 from ..utils import infer_label_column
 from collections import Counter
 from dataclasses import dataclass, field, InitVar
 from datasets import IterableDataset
 from itertools import cycle
 from torch.utils.data import IterableDataset as TorchIterableDataset
 from typing import Iterator, Optional
 import numpy as np


 @dataclass
 class BalancedSampler(TorchIterableDataset):
     """
     Approximately balances a binary classification dataset in a streaming fashion.

@@ -22,24 +24,27 @@ class BalancedSampler(TorchIterableDataset):
     divided between the two binary label values (0 and 1). Defaults to 1000.
     """

-    def __init__(self, dataset: IterableDataset):
-        self.dataset = dataset
-        self.class_counts = np.zeros(2)
+    dataset: IterableDataset
+    label_counts: np.ndarray = field(default_factory=lambda: np.zeros(2))
+    seed: int = 42
+
+    def __post_init__(self):
+        self.rng = np.random.default_rng(self.seed)

     def __iter__(self):
         for sample in self.dataset:
             label = sample["label"]

             # Update class counts
-            self.class_counts[label] += 1
-            current_balance = self.class_counts / self.class_counts.sum()
+            self.label_counts[label] += 1
+            current_balance = self.label_counts / self.label_counts.sum()

             # Check if the sample should be dropped
             majority_class = np.argmax(current_balance)
             if label == majority_class:
                 # Solution of n * p * q / [n * (1 - p) + n * p * q] = 0.5 for q
                 keep_prob = 1 / current_balance[majority_class] - 1
-                if np.random.rand() < 1 - keep_prob:
+                if self.rng.uniform() < 1 - keep_prob:
                     continue

             yield sample
```

Review thread on `for sample in self.dataset:`:

> **Reviewer:** If the dataset we're streaming from isn't shuffled (e.g. all the movie reviews about Batman come first), our sampling will be distributionally incorrect, no?
>
> **Reviewer:** Also, one of the tests broke.
>
> **Author:** Yeah, working on the test thing; it's a pretty bizarre error. The shuffling isn't really an issue. See the HF docs on this: https://huggingface.co/docs/datasets/stream#shuffle. We just need to make sure we're actually calling …
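On the shuffling point above: `datasets` shuffles a streaming `IterableDataset` approximately, by filling a fixed-size buffer and emitting random picks from it. A minimal standalone sketch of that buffer strategy (an illustration of the idea, not the library's actual implementation):

```python
import random


def buffered_shuffle(stream, buffer_size, seed=42):
    """Approximately shuffle an iterable with a fixed-size buffer,
    similar in spirit to IterableDataset.shuffle(buffer_size=...)."""
    rng = random.Random(seed)
    buffer = []
    for item in stream:
        if len(buffer) < buffer_size:
            # Fill the buffer before emitting anything
            buffer.append(item)
        else:
            # Emit a random buffered item and replace it with the new one
            idx = rng.randrange(buffer_size)
            yield buffer[idx]
            buffer[idx] = item
    # Flush whatever remains in random order
    rng.shuffle(buffer)
    yield from buffer


out = list(buffered_shuffle(range(100), buffer_size=16))
print(out[:10])  # same elements overall, locally scrambled order
```

The larger the buffer relative to any label-correlated runs in the stream, the closer this gets to a true shuffle, which is why the docs recommend pairing it with a shuffled shard order.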
Review thread on the class definition:

> **Reviewer:** Small detail, but maybe we should just remove `@dataclass`, if we use init anyway... (?)
>
> **Author:** Go ahead and change it.
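The keep-probability in the diff comes from solving n·p·q / [n·(1−p) + n·p·q] = 0.5 for q, giving q = 1/p − 1, where p is the running fraction of the majority label. A quick standalone simulation of the sampler's core loop (the synthetic 90/10 stream below is made up for this demo) shows the output converging to roughly 50/50:

```python
import numpy as np


def balanced_stream(samples, seed=42):
    """Yield samples from a binary-labeled stream, probabilistically
    dropping majority-class items so the output is ~50/50.

    Standalone re-implementation of the sampler's core logic for
    illustration; `samples` is any iterable of dicts with a "label" key.
    """
    rng = np.random.default_rng(seed)
    label_counts = np.zeros(2)
    for sample in samples:
        label = sample["label"]
        label_counts[label] += 1
        current_balance = label_counts / label_counts.sum()
        majority = np.argmax(current_balance)
        if label == majority:
            # q = 1/p - 1 solves n*p*q / [n*(1-p) + n*p*q] = 0.5 for q
            keep_prob = 1 / current_balance[majority] - 1
            if rng.uniform() < 1 - keep_prob:
                continue
        yield sample


# Synthetic stream: ~90% label 0, ~10% label 1
gen = np.random.default_rng(0)
stream = [{"label": int(gen.uniform() < 0.1)} for _ in range(10_000)]
kept = [s["label"] for s in balanced_stream(stream)]
print(sum(kept) / len(kept))  # roughly 0.5
```

Note this drops most majority-class samples rather than reweighting them, so the output stream is much shorter than the input; that is the trade-off for doing the balancing in one streaming pass.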