-
Notifications
You must be signed in to change notification settings - Fork 32
/
test_samplers.py
58 lines (44 loc) · 2 KB
/
test_samplers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
from collections import Counter
from itertools import islice
from random import Random
from datasets import IterableDataset, load_dataset
from elk.extraction import BalancedSampler, FewShotSampler
from elk.utils import assert_type, infer_label_column
def test_output_batches_are_balanced():
    """Check that ``FewShotSampler`` yields label-balanced batches.

    With an even number of shots the counts of label 0 and label 1 must match
    exactly. With an odd number they may differ by at most one sample. Each
    batch is also required to hold exactly ``num_shots`` samples, so an empty
    or truncated batch cannot satisfy the balance check vacuously.
    """
    # Load an example dataset for testing
    dataset = assert_type(
        IterableDataset,
        load_dataset("super_glue", "boolq", split="train", streaming=True),
    )
    label_col = infer_label_column(dataset.features)  # type: ignore

    # (num_shots, max allowed |count(0) - count(1)|): an even shot count must
    # be exactly balanced; an odd one can only be balanced to within 1 sample.
    for num_shots, max_diff in ((6, 0), (5, 1)):
        # Fresh RNG with the same seed for each case keeps runs reproducible.
        sampler = FewShotSampler(dataset, num_shots, rng=Random(42))
        for batch in islice(sampler, 5):
            # Guard against vacuous success: an empty batch trivially has
            # equal (zero) label counts.
            assert len(batch) == num_shots, (
                f"Expected {num_shots} samples per batch, got {len(batch)}"
            )
            counter = Counter(sample[label_col] for sample in batch)
            assert abs(counter[0] - counter[1]) <= max_diff, (
                f"Unbalanced {num_shots}-shot batch: {dict(counter)}"
            )
def test_output_is_roughly_balanced():
    """Check that ``BalancedSampler`` emits labels 0 and 1 in near-equal numbers.

    Draws up to 3000 samples from the streamed BoolQ training split. It then
    asserts that the relative imbalance ``|n0 - n1| / (n0 + n1)`` is below 1%.
    """
    # Load an example dataset for testing
    dataset = assert_type(
        IterableDataset,
        load_dataset("super_glue", "boolq", split="train", streaming=True),
    )
    col = infer_label_column(dataset.features)  # type: ignore
    reservoir = BalancedSampler(dataset, {0, 1})

    # Count the number of samples for each label
    counter = Counter(sample[col] for sample in islice(reservoir, 3000))

    # Check if the output is roughly balanced
    label_0_count = counter[0]
    label_1_count = counter[1]
    total = label_0_count + label_1_count
    # Guard the division below: a sampler that yields nothing (or no 0/1
    # labels) should fail with a readable message, not a ZeroDivisionError.
    assert total > 0, f"No samples with label 0 or 1 were drawn: {dict(counter)}"
    imbalance = abs(label_0_count - label_1_count) / total

    # Set a tolerance threshold for the imbalance ratio (e.g., 1%)
    tol = 0.01
    assert imbalance < tol, f"Imbalance ratio {imbalance} exceeded tolerance {tol}"