
Support multiple choice datasets #179

Merged 54 commits on Apr 16, 2023
Commits (54)
d292c7c
LM output evaluation for autoregressive models
norabelrose Apr 4, 2023
7ed5ccd
move to own baseline file
lauritowal Apr 4, 2023
ba1d3b2
cleanup
lauritowal Apr 4, 2023
a20d4ca
Support encoder-decoder model LM output
norabelrose Apr 5, 2023
088758e
Merge remote-tracking branch 'origin/main' into lm-output
norabelrose Apr 5, 2023
77d7418
isort
norabelrose Apr 5, 2023
5bf63f4
Bug fixes
norabelrose Apr 5, 2023
819cfed
Merge branch 'main' into lm-output
norabelrose Apr 5, 2023
d3d9a8d
Merge branch 'main' into lm-output
norabelrose Apr 5, 2023
b89e23c
Remove test_log_csv_elements
norabelrose Apr 5, 2023
9aef842
Remove Python 3.9 support
norabelrose Apr 5, 2023
0851d4f
Add Pandas to pyproject.toml
norabelrose Apr 5, 2023
207a375
add code (contains still same device cuda error)
lauritowal Apr 5, 2023
e7efcce
fix multiple cuda error, save evals to right folder + cleanup
lauritowal Apr 7, 2023
b5fa54c
Merge branch 'main' into eval_lr
lauritowal Apr 7, 2023
4f8bdc5
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 7, 2023
9ca72ba
Fix bug noticed by Waree
norabelrose Apr 7, 2023
d7e4893
Merge remote-tracking branch 'origin/eval_lr' into lm-output
norabelrose Apr 7, 2023
bcdca8a
Merge remote-tracking branch 'origin/main' into lm-output
norabelrose Apr 7, 2023
713a251
Add sanity check to load_prompts and refactor binarize
norabelrose Apr 7, 2023
0c35bc7
Changing a ton of stuff
norabelrose Apr 8, 2023
f6a762a
Merge remote-tracking branch 'origin/main' into lm-output
norabelrose Apr 10, 2023
f547744
Revert changes to binarize
norabelrose Apr 10, 2023
ab1909f
Stupid prompt_counter bug
norabelrose Apr 10, 2023
f58290f
Merge remote-tracking branch 'origin/main' into lm-output
norabelrose Apr 10, 2023
f912ee6
Remove stupid second set_start_method call
norabelrose Apr 10, 2023
606dcad
Merge remote-tracking branch 'origin/lm-output' into multiclass
norabelrose Apr 10, 2023
0038792
Merge remote-tracking branch 'origin/main' into multiclass
norabelrose Apr 10, 2023
83b480b
Fix bugs in binary case
norabelrose Apr 11, 2023
3e66262
Various little refactors
norabelrose Apr 11, 2023
a8c21a6
Remove .predict and .predict_prob on Reporter; trying to get SciQ to …
norabelrose Apr 11, 2023
5f478b1
Bugfix for Reporter.score on binary tasks
norabelrose Apr 11, 2023
97b26ac
Fix bug where cached hidden states aren’t used when num_gpus is diffe…
norabelrose Apr 12, 2023
11fda87
Actually works now
norabelrose Apr 12, 2023
da4c72f
Refactor handling of multiple datasets
norabelrose Apr 13, 2023
e1675f7
Various fixes
norabelrose Apr 13, 2023
8cc325b
Merge remote-tracking branch 'origin/main' into multi-ds-eval
norabelrose Apr 13, 2023
14987e1
Fix math tests
norabelrose Apr 13, 2023
88683fa
Fix smoke tests
norabelrose Apr 13, 2023
a6c382e
All tests working ostensibly
norabelrose Apr 13, 2023
ecc53cb
Make CCS normalization customizable
norabelrose Apr 13, 2023
18c7f4c
log each dataset individually
AlexTMallen Apr 13, 2023
94a900c
Merge branch 'multi-ds-eval' into multiclass
norabelrose Apr 13, 2023
5173649
Fix label_column bug
norabelrose Apr 13, 2023
3e6c39c
GLUE MNLI works on Deberta
norabelrose Apr 14, 2023
1e9ce06
Move pseudo AUROC stuff to CcsReporter
norabelrose Apr 14, 2023
35a8f34
Make 'datasets' and 'label_columns' config options more opinionated
norabelrose Apr 14, 2023
615bbb1
tiny spacing change
norabelrose Apr 14, 2023
f021404
Allow for toggling CV
norabelrose Apr 14, 2023
f6629ec
Merge branch 'multi-ds-eval' into multiclass
norabelrose Apr 14, 2023
99f01c3
Remove duplicate dbpedia template
norabelrose Apr 14, 2023
f415f8d
Merge branch 'main' into multiclass
norabelrose Apr 14, 2023
d16c96b
Training on datasets with different numbers of classes now works
norabelrose Apr 15, 2023
6e205a7
Update extraction.py
lauritowal Apr 16, 2023
10 changes: 3 additions & 7 deletions elk/evaluation/evaluate.py
@@ -69,12 +69,8 @@ def evaluate_reporter(
reporter.eval()

row_buf = []
for ds_name, (val_x0, val_x1, val_gt, _) in val_output.items():
val_result = reporter.score(
val_gt,
val_x0,
val_x1,
)
for ds_name, (val_h, val_gt, _) in val_output.items():
val_result = reporter.score(val_gt, val_h)

stats_row = pd.Series(
{
@@ -89,7 +85,7 @@
with open(lr_dir / f"layer_{layer}.pt", "rb") as f:
lr_model = torch.load(f, map_location=device).eval()

lr_auroc, lr_acc = evaluate_supervised(lr_model, val_x0, val_x1, val_gt)
lr_auroc, lr_acc = evaluate_supervised(lr_model, val_h, val_gt)

stats_row["lr_auroc"] = lr_auroc
stats_row["lr_acc"] = lr_acc
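Not part of the diff: an illustrative shape sketch of the new evaluation call. Reporters now score a single hidden-state tensor with an explicit answer-choice axis instead of the old (val_x0, val_x1) contrast pair. The batched shape below is an assumption inferred from the Array3D features declared later in extraction.py, and every size is made up.

import torch

n, num_variants, num_choices, d = 128, 4, 3, 768      # illustrative sizes
val_gt = torch.randint(0, num_choices, (n,))           # integer class labels
val_h = torch.randn(n, num_variants, num_choices, d)   # one hidden slice per answer choice

# Hypothetical calls mirroring the new signatures above:
# val_result = reporter.score(val_gt, val_h)
# lr_auroc, lr_acc = evaluate_supervised(lr_model, val_h, val_gt)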
56 changes: 33 additions & 23 deletions elk/extraction/balanced_sampler.py
@@ -1,4 +1,5 @@
from collections import deque
from dataclasses import dataclass, field
from itertools import cycle
from random import Random
from typing import Iterable, Iterator, Optional
@@ -11,39 +12,48 @@
from ..utils.typing import assert_type


@dataclass
class BalancedSampler(TorchIterableDataset):
"""
Approximately balances a binary classification dataset in a streaming fashion.

Args:
dataset (IterableDataset): The HuggingFace IterableDataset to balance.
label_col (Optional[str], optional): The name of the column containing the
binary label. If not provided, the label column will be inferred from
the dataset features. Defaults to None.
buffer_size (int, optional): The total buffer size to use for balancing the
dataset. This value should be divisible by 2, as it will be equally
divided between the two binary label values (0 and 1). Defaults to 1000.
A sampler that approximately balances a multi-class classification dataset in a
streaming fashion.

Attributes:
data: The input dataset to balance.
num_classes: The total number of classes expected in the data.
buffer_size: The total buffer size to use for balancing the dataset. Each class
will have its own buffer with this size.
"""

def __init__(self, data: Iterable[dict], buffer_size: int = 1000):
self.data = data
data: Iterable[dict]
num_classes: int
buffer_size: int = 1000
buffers: dict[int, deque[dict]] = field(default_factory=dict, init=False)
label_col: str = "label"

self.neg_buffer = deque(maxlen=buffer_size)
self.pos_buffer = deque(maxlen=buffer_size)
def __post_init__(self):
# Initialize empty buffers
self.buffers = {
label: deque(maxlen=self.buffer_size) for label in range(self.num_classes)
}

def __iter__(self):
for sample in self.data:
label = sample["label"]
label = sample[self.label_col]

# Add the sample to the appropriate buffer
if label == 0:
self.neg_buffer.append(sample)
else:
self.pos_buffer.append(sample)
# This whole class is a no-op if the label is not an integer
if not isinstance(label, int):
yield sample
continue

# Add the sample to the buffer for its class label
self.buffers[label].append(sample)

while self.neg_buffer and self.pos_buffer:
yield self.neg_buffer.popleft()
yield self.pos_buffer.popleft()
# Check if all buffers have at least one sample
while all(len(buffer) > 0 for buffer in self.buffers.values()):
# Yield one sample from each buffer in a round-robin fashion
for buf in self.buffers.values():
yield buf.popleft()


class FewShotSampler:
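Not part of the diff: a minimal usage sketch of the new multi-class BalancedSampler, assuming the dataclass fields shown above. The toy stream and numbers are invented for illustration.

from itertools import islice

from elk.extraction.balanced_sampler import BalancedSampler

# Hypothetical unbalanced stream: class 0 appears twice as often as classes 1 and 2.
stream = ({"label": i % 4 % 3, "text": f"example {i}"} for i in range(1000))

sampler = BalancedSampler(data=stream, num_classes=3, buffer_size=100)

# Samples are only yielded once every class buffer holds at least one example,
# then one sample per class is emitted in round-robin order.
labels = [ex["label"] for ex in islice(sampler, 12)]
print(labels)  # [0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2]

Note that non-integer labels bypass the buffers entirely (the sampler yields them unchanged), so string-labeled streams pass through until the caller maps them to class indices.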
43 changes: 27 additions & 16 deletions elk/extraction/extraction.py
@@ -8,8 +8,8 @@

import torch
from datasets import (
Array2D,
Array3D,
ClassLabel,
DatasetDict,
Features,
Sequence,
@@ -28,12 +28,13 @@
assert_type,
convert_span,
float32_to_int16,
infer_label_column,
infer_num_classes,
instantiate_model,
is_autoregressive,
select_train_val_splits,
select_usable_devices,
)
from .balanced_sampler import BalancedSampler
from .generator import _GeneratorBuilder
from .prompt_loading import PromptConfig, load_prompts

@@ -100,13 +101,16 @@ def extract_hiddens(
if rank != 0:
logging.disable(logging.CRITICAL)

ds_names = cfg.prompts.datasets
p_cfg = cfg.prompts
ds_names = p_cfg.datasets
assert len(ds_names) == 1, "Can only extract hiddens from one dataset at a time."

prompt_ds = load_prompts(
ds_names[0],
label_column=p_cfg.label_columns[0] if p_cfg.label_columns else None,
num_classes=p_cfg.num_classes,
split_type=split_type,
stream=cfg.prompts.stream,
stream=p_cfg.stream,
rank=rank,
world_size=world_size,
) # this dataset is already sharded, but hasn't been truncated to max_examples
@@ -124,19 +128,21 @@
# Iterating over questions
layer_indices = cfg.layers or tuple(range(model.config.num_hidden_layers))

global_max_examples = cfg.prompts.max_examples[0 if split_type == "train" else 1]
global_max_examples = p_cfg.max_examples[0 if split_type == "train" else 1]
# break `max_examples` among the processes roughly equally
max_examples = global_max_examples // world_size
# the last process gets the remainder (which is usually small)
if rank == world_size - 1:
max_examples += global_max_examples % world_size

for example in islice(BalancedSampler(prompt_ds), max_examples):
for example in islice(prompt_ds, max_examples):
num_variants = len(example["prompts"])
num_choices = len(example["prompts"][0])

hidden_dict = {
f"hidden_{layer_idx}": torch.empty(
num_variants,
2, # contrast pair
num_choices,
model.config.hidden_size,
device=device,
dtype=torch.int16,
Expand All @@ -145,7 +151,7 @@ def extract_hiddens(
}
lm_preds = torch.empty(
num_variants,
2, # contrast pair
num_choices,
device=device,
dtype=torch.float32,
)
@@ -238,8 +244,7 @@ def extract_hiddens(
**hidden_dict,
)
if has_lm_preds:
# We only need the probability of the positive example since this is binary
out_record["model_preds"] = lm_preds.softmax(dim=-1)[..., 1]
out_record["model_preds"] = lm_preds.softmax(dim=-1)

yield out_record

@@ -281,6 +286,13 @@ def get_splits() -> SplitDict:
ds_name, _, config_name = cfg.prompts.datasets[0].partition(" ")
info = get_dataset_config_info(ds_name, config_name or None)

ds_features = assert_type(Features, info.features)
label_col = (
cfg.prompts.label_columns[0]
if cfg.prompts.label_columns
else infer_label_column(ds_features)
)
num_classes = cfg.prompts.num_classes or infer_num_classes(ds_features[label_col])
num_variants = cfg.prompts.num_variants
if num_variants < 0:
prompter = DatasetTemplates(ds_name, config_name)
@@ -289,7 +301,7 @@
layer_cols = {
f"hidden_{layer}": Array3D(
dtype="int16",
shape=(num_variants, 2, model_cfg.hidden_size),
shape=(num_variants, num_classes, model_cfg.hidden_size),
)
for layer in cfg.layers or range(model_cfg.num_hidden_layers)
}
@@ -298,21 +310,20 @@
Value(dtype="string"),
length=num_variants,
),
"label": ClassLabel(names=["neg", "pos"]),
"label": Value(dtype="int64"),
"text_inputs": Sequence(
Sequence(
Value(dtype="string"),
length=2,
),
length=num_variants,
),
}

# Only add model_preds if the model is an autoregressive model
if is_autoregressive(model_cfg):
other_cols["model_preds"] = Sequence(
Value(dtype="float32"),
length=num_variants,
other_cols["model_preds"] = Array2D(
shape=(num_variants, num_classes),
dtype="float32",
)

devices = select_usable_devices(num_gpus, min_memory=min_gpu_mem)
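Not part of the diff: a small shape sketch of how the extraction buffers widen from a hard-coded contrast pair to one slot per answer choice, matching the Array3D/Array2D features declared in get_splits. The layer count, num_variants, num_choices, and hidden_size below are illustrative.

import torch

num_variants, num_choices, hidden_size = 4, 3, 768  # made-up sizes

# Per-example buffers now carry a num_choices axis instead of a fixed size of 2.
hidden_dict = {
    f"hidden_{layer_idx}": torch.empty(
        num_variants, num_choices, hidden_size, dtype=torch.int16
    )
    for layer_idx in range(2)
}
lm_preds = torch.randn(num_variants, num_choices)

# The full softmax over choices is stored, rather than just p(positive).
model_preds = lm_preds.softmax(dim=-1)
assert model_preds.shape == (num_variants, num_choices)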
40 changes: 26 additions & 14 deletions elk/extraction/prompt_loading.py
@@ -16,7 +16,6 @@
from ..promptsource import DatasetTemplates
from ..utils import (
assert_type,
binarize,
infer_label_column,
infer_num_classes,
select_train_val_splits,
@@ -51,6 +50,7 @@ class PromptConfig(Serializable):
data_dirs: list[str] = field(default_factory=list)
label_columns: list[str] = field(default_factory=list)
max_examples: list[int] = field(default_factory=lambda: [750, 250])
num_classes: int = 0
num_shots: int = 0
num_variants: int = -1
seed: int = 42
@@ -104,6 +104,8 @@ def explode(self) -> list["PromptConfig"]:

def load_prompts(
ds_string: str,
label_column: Optional[str] = None,
num_classes: int = 0,
num_shots: int = 0,
num_variants: int = -1,
seed: int = 42,
@@ -141,10 +143,12 @@
train_ds = ds_dict[train_name].shuffle(seed=seed)
if not stream:
ds = assert_type(Dataset, ds)
if world_size > 1:
ds = ds.shard(world_size, rank)

ds = ds.to_iterable_dataset().cast(ds.features)

# only keep the datapoints relevant to the current process
if world_size > 1:
elif world_size > 1:
# This prints to stdout which is slightly annoying
ds = split_dataset_by_node(dataset=ds, rank=rank, world_size=world_size)

@@ -199,7 +203,7 @@ def _convert_to_prompts(
fewshot_iter: Optional[Iterator[list[dict]]] = None,
) -> dict[str, Any]:
"""Prompt-generating function to pass to `IterableDataset.map`."""
assert_type(int, example[label_column])
labels_are_strings = isinstance(example[label_column], str)
prompts = []
templates = list(prompter.templates.values())
if num_variants < len(templates):
@@ -212,22 +216,24 @@ def qa_cat(q: str, a: str) -> str:

# For sanity checking that prompts are unique
prompt_counter = Counter()
new_label = rng.choice([0, 1]) if num_classes > 2 else example[label_column]
label_indices = set()

for template in templates:
choices = []
string_choices = template.get_answer_choices_list(example)

if num_classes > 2:
template = binarize(
template, example[label_column], assert_type(int, new_label), rng
)
label = example[label_column]
label_indices.add(string_choices.index(label) if labels_are_strings else label)

for answer_idx in range(2):
for answer_idx in range(num_classes):
fake_example = example.copy()
fake_example[label_column] = answer_idx
if labels_are_strings:
fake_example[label_column] = string_choices[answer_idx]
else:
fake_example[label_column] = answer_idx

q, a = template.apply(fake_example)
text = qa_cat(q, a)
text = qa_cat(q, a or string_choices[answer_idx])
prompt_counter[text] += 1

if fewshot_iter is not None:
@@ -254,8 +260,14 @@ def qa_cat(q: str, a: str) -> str:
if dup_count > 1:
raise ValueError(f'Prompt duplicated {dup_count} times! "{maybe_dup}"')

# Sanity check: label should be the same across all variants
if len(label_indices) > 1:
raise ValueError(
f"Label index should be the same all variants, but got {label_indices}"
)

return dict(
label=new_label,
label=label_indices.pop(),
prompts=prompts,
template_names=prompter.all_template_names,
template_names=[template.name for template in templates],
)
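Not part of the diff: a hedged sketch of calling load_prompts with the new label_column and num_classes keyword arguments. The dataset string, column name, and split value are placeholders; the call mirrors the one made in extract_hiddens earlier in this diff.

from elk.extraction.prompt_loading import load_prompts

prompt_ds = load_prompts(
    "glue mnli",        # placeholder "dataset config" string
    label_column="label",
    num_classes=3,
    split_type="val",
    stream=True,
    rank=0,
    world_size=1,
)

# Each yielded example carries an integer label index plus num_variants x num_choices prompts.
example = next(iter(prompt_ds))
num_variants = len(example["prompts"])
num_choices = len(example["prompts"][0])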
36 changes: 36 additions & 0 deletions elk/metrics.py
@@ -0,0 +1,36 @@
from torch import Tensor


def to_one_hot(labels: Tensor, n_classes: int) -> Tensor:
    """
    Convert a tensor of class labels to a one-hot representation.

    Args:
        labels (Tensor): A tensor of class labels of shape (N,).
        n_classes (int): The total number of unique classes.

    Returns:
        Tensor: A one-hot representation tensor of shape (N, n_classes).
    """
    one_hot_labels = labels.new_zeros(labels.size(0), n_classes)
    return one_hot_labels.scatter_(1, labels.unsqueeze(1).long(), 1)


def accuracy(y_true: Tensor, y_pred: Tensor) -> float:
    """
    Compute the accuracy of a classification model.

    Args:
        y_true: Ground truth tensor of shape (N,).
        y_pred: Predicted class tensor of shape (N,) or (N, n_classes).

    Returns:
        float: Accuracy of the model.
    """
    # Check if binary or multi-class classification
    if len(y_pred.shape) == 1:
        hard_preds = y_pred > 0.5
    else:
        hard_preds = y_pred.argmax(-1)

    return hard_preds.cpu().eq(y_true.cpu()).float().mean().item()
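Not part of the diff: a quick usage sketch for the new metrics helpers, using toy tensors rather than repo data.

import torch

from elk.metrics import accuracy, to_one_hot

labels = torch.tensor([0, 2, 1, 2])
one_hot = to_one_hot(labels, n_classes=3)
# tensor([[1, 0, 0],
#         [0, 0, 1],
#         [0, 1, 0],
#         [0, 0, 1]])

# Multi-class predictions of shape (N, n_classes): argmax is compared to y_true.
probs = torch.tensor([[0.7, 0.2, 0.1],
                      [0.1, 0.1, 0.8],
                      [0.3, 0.4, 0.3],
                      [0.2, 0.5, 0.3]])
print(accuracy(labels, probs))  # 0.75

# Binary predictions of shape (N,) are thresholded at 0.5 instead.
print(accuracy(torch.tensor([0, 1, 1, 0]), torch.tensor([0.2, 0.9, 0.4, 0.1])))  # 0.75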