Skip to content

Commit

Permalink
Make BalancedSampler deterministic
Browse files Browse the repository at this point in the history
  • Loading branch information
norabelrose committed Mar 22, 2023
1 parent 15ab351 commit a80369e
Showing 1 changed file with 11 additions and 6 deletions.
17 changes: 11 additions & 6 deletions elk/extraction/balanced_sampler.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
from ..utils import infer_label_column
from collections import Counter
from dataclasses import dataclass, field, InitVar
from datasets import IterableDataset
from itertools import cycle
from torch.utils.data import IterableDataset as TorchIterableDataset
from typing import Iterator, Optional
import numpy as np


@dataclass
class BalancedSampler(TorchIterableDataset):
"""
Approximately balances a binary classification dataset in a streaming fashion.
Expand All @@ -22,24 +24,27 @@ class BalancedSampler(TorchIterableDataset):
divided between the two binary label values (0 and 1). Defaults to 1000.
"""

# NOTE(review): this hand-written constructor looks like the removed
# (pre-commit) side of the diff; the dataclass fields and `__post_init__`
# that follow it in this hunk appear to replace it. Confirm against the repo.
def __init__(self, dataset: IterableDataset):
# Stream of samples that `__iter__` walks over.
self.dataset = dataset
# Running per-label sample tally for the two binary labels, starting at zero;
# `__iter__` indexes it by the label value.
self.class_counts = np.zeros(2)
dataset: IterableDataset
label_counts: np.ndarray = field(default_factory=lambda: np.zeros(2))
seed: int = 42

# Dataclass post-init hook: builds this sampler's own random generator from
# `self.seed`. `__iter__` draws its drop decisions from this generator, so the
# sequence of draws is fixed by the seed rather than by NumPy's global state.
def __post_init__(self):
self.rng = np.random.default_rng(self.seed)

# Streams samples from `self.dataset`, randomly dropping samples whose label is
# currently in the majority so that the two labels are yielded in roughly
# equal proportion.
# NOTE(review): this is a diff rendering, so each changed statement appears
# twice in a row: the old form (`class_counts`, `np.random.rand()`) is
# immediately followed by its replacement (`label_counts`,
# `self.rng.uniform()`). Only the replacement lines look like part of the
# committed method; confirm against the repository.
# The counts live on the instance and are not reset here, so a second pass
# over the sampler continues from the counts left by the first.
def __iter__(self):
for sample in self.dataset:
# Used directly as an index into the two-element count array below.
label = sample["label"]

# Update class counts
self.class_counts[label] += 1
current_balance = self.class_counts / self.class_counts.sum()
# Replacement lines: the same update, on the dataclass field `label_counts`.
self.label_counts[label] += 1
# Fraction of all samples seen so far (this one included) carrying each label.
current_balance = self.label_counts / self.label_counts.sum()

# Check if the sample should be dropped
# np.argmax returns the first maximal index, so an exact tie resolves to label 0.
majority_class = np.argmax(current_balance)
if label == majority_class:
# Solution of n * p * q / [n * (1 - p) + n * p * q] = 0.5 for q
# i.e. q = (1 - p) / p = 1 / p - 1, where p is the majority fraction and q the
# probability of keeping a majority sample. While only one label has been
# seen, p == 1 and q == 0, so every sample is dropped.
keep_prob = 1 / current_balance[majority_class] - 1
if np.random.rand() < 1 - keep_prob:
# Replacement line: same test, drawn from the seeded per-instance generator.
if self.rng.uniform() < 1 - keep_prob:
continue

yield sample
Expand Down

0 comments on commit a80369e

Please sign in to comment.