-
Notifications
You must be signed in to change notification settings - Fork 32
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Multi datasets #123
Multi datasets #123
Changes from 3 commits
681698d
ac1b9f1
b864c77
7d7d97c
4fe61e9
fe61d67
74da878
569ef05
b62b679
fe94c22
bba24d8
03ba6e0
15ab351
a80369e
d304ab3
761c82d
f29743b
b7b7e23
1afb563
225d4c7
9368dc8
a858b65
5dc2ec6
b1b95e5
ee3911e
a55b3de
44dc25c
93d8d87
fad4d74
0a054f4
177eec2
3a762b0
f66c054
d3d87fc
3d08147
c9a43e1
94290aa
f9298e4
3765c4f
2b05193
83731bb
764fda9
d2c66b0
9186326
3f99a4d
148130d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,4 @@ | ||
from .balanced_sampler import BalancedBatchSampler, BalancedSampler | ||
from .balanced_sampler import BalancedSampler, FewShotSampler | ||
from .extraction import ExtractionConfig, extract_hiddens, extract | ||
from .generator import _GeneratorConfig, _GeneratorBuilder | ||
from .prompt_loading import PromptConfig, load_prompts |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,8 +1,9 @@ | ||
from ..utils import infer_label_column | ||
from collections import Counter | ||
from ..math_util import stochastic_round_constrained | ||
from dataclasses import dataclass, field, InitVar | ||
from datasets import IterableDataset | ||
from itertools import cycle | ||
from random import Random | ||
from torch.utils.data import IterableDataset as TorchIterableDataset | ||
from typing import Iterator, Optional | ||
import numpy as np | ||
|
@@ -50,39 +51,49 @@ def __iter__(self): | |
yield sample | ||
|
||
|
||
class BalancedBatchSampler: | ||
"""Yields precisely balanced batches from a binary classification dataset. | ||
class FewShotSampler: | ||
"""Yields batches of few-shot examples that are as balanced as possible. | ||
|
||
Written by a human being because GPT-4 couldn't figure out how to do it. | ||
If the number of examples is divisible by the number of shots, this sampler | ||
will yield batches of exactly `num_shots` examples. Otherwise, it will | ||
use `stochastic_round_constrained` to get as close to balanced batches as | ||
possible. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
dataset: IterableDataset, | ||
num_shots: int, | ||
rng: Random, | ||
label_col: Optional[str] = None, | ||
batch_size: int = 32, | ||
): | ||
self.batch_size = batch_size | ||
self.dataset = dataset | ||
self.label_col = label_col or infer_label_column(dataset.features) | ||
self.num_shots = num_shots | ||
self.rng = rng | ||
|
||
def __iter__(self) -> Iterator[list[dict]]: | ||
batch = [] | ||
neg_buf, pos_buf = [], [] | ||
|
||
max_count = self.batch_size // 2 | ||
label_counts = Counter() | ||
|
||
# Infinite loop! | ||
# Infinite loop over the dataset! | ||
for sample in cycle(self.dataset): | ||
label = sample[self.label_col] | ||
if label_counts[label] >= max_count: | ||
continue | ||
|
||
batch.append(sample) | ||
label_counts[label] += 1 | ||
if label == 0: | ||
neg_buf.append(sample) | ||
elif label == 1: | ||
pos_buf.append(sample) | ||
else: | ||
raise ValueError(f"Expected label to be 0 or 1, got {label}") | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. So we don't support few-shot examples for dbpedia or ag_news (or other multiclass datasets)? I guess this is fine until someone wants to do this. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Yeah, I don't want to think about how to support that right now. |
||
|
||
neg_count, pos_count = stochastic_round_constrained( | ||
[self.num_shots / 2, self.num_shots / 2], self.rng | ||
) | ||
while len(neg_buf) >= neg_count and len(pos_buf) >= pos_count: | ||
batch = [] | ||
for _ in range(neg_count): | ||
batch.append(neg_buf.pop()) | ||
for _ in range(pos_count): | ||
batch.append(pos_buf.pop()) | ||
|
||
if len(batch) == self.batch_size: | ||
self.rng.shuffle(batch) | ||
yield batch | ||
|
||
batch = [] | ||
label_counts.clear() |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,3 @@ | ||
from ..math_util import stochastic_round_constrained | ||
from ..promptsource import DatasetTemplates | ||
from ..utils import ( | ||
assert_type, | ||
|
@@ -7,6 +6,7 @@ | |
infer_num_classes, | ||
select_train_val_splits, | ||
) | ||
from .balanced_sampler import FewShotSampler | ||
from dataclasses import dataclass | ||
from datasets import ( | ||
interleave_datasets, | ||
|
@@ -19,7 +19,7 @@ | |
from datasets.distributed import split_dataset_by_node | ||
from random import Random | ||
from simple_parsing.helpers import field, Serializable | ||
from typing import Any, Literal, Optional | ||
from typing import Any, Iterator, Literal, Optional | ||
|
||
|
||
@dataclass | ||
|
@@ -65,6 +65,7 @@ def __post_init__(self): | |
def load_prompts( | ||
*dataset_strings: str, | ||
max_examples: int = 0, | ||
num_shots: int = 0, | ||
seed: int = 42, | ||
split_type: Literal["train", "val"] = "train", | ||
rank: int = 0, | ||
|
@@ -76,6 +77,8 @@ def load_prompts( | |
dataset_strings: Space-delimited names of the HuggingFace datasets to use, | ||
e.g. `"super_glue boolq"` or `"imdb"`. | ||
max_examples: The maximum number of examples to use from the dataset. | ||
num_shots: The number of examples to use in few-shot prompts. If zero, prompts | ||
are zero-shot. | ||
seed: The seed to use for prompt randomization. | ||
split_type: Whether to use the train or val split of the dataset. | ||
rank: The rank of the current process. Defaults to 0. | ||
|
@@ -87,6 +90,7 @@ def load_prompts( | |
prompt_datasets = [] | ||
prompters = [] | ||
raw_datasets = [] | ||
train_datasets = [] | ||
rng = Random(seed) | ||
|
||
# First load the datasets and prompters. We need to know the minimum number of | ||
|
@@ -101,9 +105,10 @@ def load_prompts( | |
train_name, val_name = select_train_val_splits(ds_dict) | ||
split_name = val_name if split_type == "val" else train_name | ||
raw_datasets.append(assert_type(IterableDataset, ds_dict[split_name])) | ||
train_datasets.append(assert_type(IterableDataset, ds_dict[train_name])) | ||
|
||
num_variants = min(len(prompter.templates) for prompter in prompters) | ||
for ds, prompter in zip(raw_datasets, prompters): | ||
for ds, train_ds, prompter in zip(raw_datasets, train_datasets, prompters): | ||
label_column = infer_label_column(ds.features) | ||
num_classes = infer_num_classes(ds.features[label_column]) | ||
|
||
|
@@ -113,6 +118,15 @@ def load_prompts( | |
|
||
if label_column != "label": | ||
ds = ds.rename_column(label_column, "label") | ||
if num_shots > 0: | ||
fewshot = FewShotSampler( | ||
train_ds, | ||
num_shots=num_shots, | ||
rng=rng, | ||
) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The function signature is `split_dataset_by_node(dataset, rank, world_size)`,
so we were passing world_size to rank and vice versa. |
||
fewshot_iter = iter(fewshot) | ||
else: | ||
fewshot_iter = None | ||
|
||
# Canonicalize the name and dtype of the label column | ||
ds = ds.map( | ||
|
@@ -123,6 +137,7 @@ def load_prompts( | |
num_variants=num_variants, | ||
prompter=prompter, | ||
rng=rng, | ||
fewshot_iter=fewshot_iter, | ||
), | ||
remove_columns=extra_cols, | ||
).map( | ||
|
@@ -155,6 +170,7 @@ def load_prompts( | |
if max_examples > 0: | ||
master_ds = master_ds.take(max_examples) | ||
if world_size > 1: | ||
# This prints to stdout which is slightly annoying | ||
master_ds = split_dataset_by_node(master_ds, rank, world_size) | ||
|
||
return master_ds | ||
|
@@ -167,13 +183,19 @@ def _convert_to_prompts( | |
num_classes: int, | ||
num_variants: int, | ||
rng: Random, | ||
fewshot_iter: Optional[Iterator[list[dict]]] = None, | ||
) -> dict[str, Any]: | ||
"""Prompt-generating function to pass to `IterableDataset.map`.""" | ||
prompts = [] | ||
templates = list(prompter.templates.values()) | ||
if num_variants < len(templates): | ||
templates = rng.sample(templates, num_variants) | ||
|
||
def qa_cat(q: str, a: str) -> str: | ||
# if the jinja template already adds whitespace, don't add more | ||
sep = "" if not q or q[-1].isspace() or not a or a[0].isspace() else " " | ||
return f"{q}{sep}{a}" if a and not a.isspace() else q | ||
|
||
new_label = rng.choice([0, 1]) if num_classes > 2 else example[label_column] | ||
|
||
for template in templates: | ||
|
@@ -189,10 +211,16 @@ def _convert_to_prompts( | |
fake_example[label_column] = answer_idx | ||
|
||
q, a = template.apply(fake_example) | ||
text = qa_cat(q, a) | ||
|
||
if fewshot_iter is not None: | ||
# Infinite iterator so we don't need to worry about StopIteration | ||
fewshot_examples = next(fewshot_iter) | ||
fewshot_texts = [ | ||
qa_cat(q, a) for q, a in map(template.apply, fewshot_examples) | ||
] | ||
text = "\n\n".join(fewshot_texts) + "\n\n" + text | ||
|
||
# if the jinja template already adds whitespace, don't add more | ||
sep = "" if not q or q[-1].isspace() or not a or a[0].isspace() else " " | ||
text = f"{q}{sep}{a}" if a and not a.isspace() else q | ||
choices.append( | ||
dict( | ||
# Strip whitespace from the answer to make it easier to | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Any reason why you used `@dataclass` for the above class, but not here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I haven't had time yet to review the pull request in more detail and to test it. I can do that tomorrow evening, though.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, this is a little bit inconsistent, but I don't know how much it matters.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Arguably `FewShotSampler` would be better if it were not an iterator, because infinite iterators are weird.