Merge branch 'main' into eval_dirs
ChristyKoh committed Apr 19, 2023
2 parents 74551fa + 0b3d3c9 commit d8cee8b
Showing 39 changed files with 1,507 additions and 884 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -24,4 +24,4 @@ repos:
hooks:
- id: codespell
# The promptsource templates spuriously get flagged without this
args: ["--skip=*.yaml"]
args: ["-L fpr", "--skip=*.yaml"]
6 changes: 6 additions & 0 deletions README.md
@@ -32,6 +32,12 @@ The following command will evaluate the probe from the run naughty-northcutt on
elk eval naughty-northcutt microsoft/deberta-v2-xxlarge-mnli imdb
```

The following runs `elicit` on the Cartesian product of the listed models and datasets, storing it in a special folder ELK_DIR/sweeps/<memorable_name>. Moreover, `--add_pooled` adds an additional dataset that pools all of the datasets together.

```bash
elk sweep --models gpt2-{medium,large,xl} --datasets imdb amazon_polarity --add_pooled
```

## Caching

The hidden states resulting from `elk elicit` are cached as a HuggingFace dataset to avoid having to recompute them every time we want to train a probe. The cache is stored in the same place as all other HuggingFace datasets, which is usually `~/.cache/huggingface/datasets`.
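
The following minimal sketch prints that cache location on a given machine; it relies only on the standard `datasets` configuration attribute, nothing elk-specific:

```python
from datasets import config

# Default HuggingFace datasets cache directory (this honours the
# HF_DATASETS_CACHE environment variable when it is set).
print(config.HF_DATASETS_CACHE)
```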
4 changes: 2 additions & 2 deletions elk/__main__.py
@@ -5,15 +5,15 @@
from simple_parsing import ArgumentParser

from elk.evaluation.evaluate import Eval
from elk.extraction.extraction import Extract
from elk.training.sweep import Sweep
from elk.training.train import Elicit


@dataclass
class Command:
"""Some top-level command"""

command: Elicit | Eval | Extract
command: Elicit | Eval | Sweep

def execute(self):
return self.command.execute()
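
For context, here is a hedged sketch of how a dataclass-of-dataclasses command like this is typically dispatched with `simple_parsing`; elk's actual entry-point wrapper is not part of this hunk, so the `Hello` subcommand and the `run`-style wiring below are illustrative assumptions only:

```python
from dataclasses import dataclass

from simple_parsing import ArgumentParser


@dataclass
class Hello:
    """A stand-in subcommand; elk's real options are Elicit, Eval and Sweep."""

    name: str = "world"

    def execute(self):
        print(f"hello, {self.name}")


@dataclass
class Command:
    command: Hello  # elk annotates this field with the union `Elicit | Eval | Sweep`

    def execute(self):
        return self.command.execute()


if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_arguments(Command, dest="run")
    args = parser.parse_args()
    args.run.execute()
```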
4 changes: 2 additions & 2 deletions elk/calibration.py
@@ -34,8 +34,8 @@ def update(self, labels: Tensor, probs: Tensor) -> "CalibrationError":
assert labels.shape == probs.shape
assert torch.is_floating_point(probs)

self.labels.append(probs)
self.pred_probs.append(labels)
self.labels.append(labels)
self.pred_probs.append(probs)
return self

def compute(self, p: int = 2) -> CalibrationEstimate:
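For context, a minimal sketch of how `update` is meant to be called after this fix, with ground-truth labels first and predicted probabilities second. The no-argument constructor and the `compute()` call are assumptions based on the snippet above rather than the full file:

```python
import torch

from elk.calibration import CalibrationError

torch.manual_seed(0)
probs = torch.rand(1000)         # predicted probabilities in [0, 1]
labels = torch.bernoulli(probs)  # synthetic ground-truth labels (float 0/1)

cal = CalibrationError()         # assumed: the class can be built with no arguments
cal.update(labels, probs)        # labels feed self.labels, probs feed self.pred_probs
print(cal.compute(p=2))          # assumed: returns a CalibrationEstimate
```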
57 changes: 57 additions & 0 deletions elk/debug_logging.py
@@ -0,0 +1,57 @@
import logging
from pathlib import Path

from datasets import DatasetDict

from .utils import get_dataset_name, select_train_val_splits


def save_debug_log(datasets: list[DatasetDict], out_dir: Path) -> None:
"""
Save a debug log to the output directory. This is useful for debugging
training issues.
"""

logging.basicConfig(
level=logging.DEBUG,
format="%(asctime)s %(levelname)s:\n%(message)s",
filename=out_dir / "debug.log",
filemode="w",
)

for ds in datasets:
logging.info(
"=========================================\n"
f"Dataset: {get_dataset_name(ds)}\n"
"========================================="
)

train_split, val_split = select_train_val_splits(ds)
text_inputs = ds[val_split][0]["text_inputs"]
template_ids = ds[val_split][0]["variant_ids"]
label = ds[val_split][0]["label"]

# log the train size and val size
logging.info(f"Train size: {len(ds[train_split])}")
logging.info(f"Val size: {len(ds[val_split])}")

templates_text = f"{len(text_inputs)} templates used:\n"
trailing_whitespace = False
for (text0, text1), id in zip(text_inputs, template_ids):
templates_text += (
f'***---TEMPLATE "{id}"---***\n'
f"{'false' if label else 'true'}:\n"
f'"""{text0}"""\n'
f"{'true' if label else 'false'}:\n"
f'"""{text1}"""\n\n\n'
)
if text0[-1].isspace() or text1[-1].isspace():
trailing_whitespace = True
if trailing_whitespace:
logging.warning(
"Some inputs to the model have trailing whitespace! "
"Check that the jinja templates are not adding "
"trailing whitespace. If `token_loc` is 'last', this "
"will extract hidden states from the whitespace token."
)
logging.info(templates_text)
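
The trailing-whitespace warning above matters because, when `token_loc` is `'last'`, a prompt that ends in whitespace makes the probe read the hidden state of a whitespace token. A self-contained sketch of the same check, independent of elk internals (the example prompts are made up):

```python
def has_trailing_whitespace(prompts: list[str]) -> bool:
    """Return True if any rendered prompt ends in whitespace."""
    return any(p and p[-1].isspace() for p in prompts)


prompts = [
    "Is the following review positive? I loved it. Answer:",
    "Is the following review positive? I loved it. Answer: ",  # trailing space
]
print(has_trailing_whitespace(prompts))  # True
```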
69 changes: 33 additions & 36 deletions elk/evaluation/evaluate.py
@@ -1,7 +1,7 @@
from dataclasses import dataclass
from functools import partial
from pathlib import Path
from typing import Callable, Literal, Optional
from typing import Callable

import pandas as pd
import torch
@@ -11,7 +11,7 @@
from ..files import elk_reporter_dir
from ..run import Run
from ..training import Reporter
from ..training.baseline import evaluate_baseline, load_baseline
from ..training.supervised import evaluate_supervised
from ..utils import select_usable_devices


@@ -28,19 +28,20 @@ class Eval(Serializable):
`elk.training.preprocessing.normalize()` for details.
num_gpus: The number of GPUs to use. Defaults to -1, which means
"use all available GPUs".
skip_supervised: Whether to skip evaluation of the supervised classifier.
debug: When in debug mode, a useful log file is saved to the memorably-named
output directory. Defaults to False.
"""

data: Extract
source: str = field(positional=True)
normalization: Literal["legacy", "none", "elementwise", "meanonly"] = "meanonly"

concatenated_layer_offset: int = 0
debug: bool = False
out_dir: Optional[Path] = None
min_gpu_mem: int | None = None
num_gpus: int = -1
skip_baseline: bool = False
concatenated_layer_offset: int = 0
out_dir: Path | None = None
skip_supervised: bool = False
combine_evals: bool = False

def execute(self):
@@ -50,7 +51,6 @@ def execute(self):

if self.combine_evals:
run = Evaluate(cfg=self, out_dir=transfer_dir / ", ".join(datasets))
run.evaluate()
else:
# eval on each dataset separately
for dataset in datasets:
@@ -65,55 +65,52 @@ class Evaluate(Run):

def evaluate_reporter(
self, layer: int, devices: list[str], world_size: int = 1
) -> pd.Series:
) -> pd.DataFrame:
"""Evaluate a single reporter on a single layer."""
device = self.get_device(devices, world_size)

_, _, test_x0, test_x1, _, test_labels, _ = self.prepare_data(
device,
layer,
)
val_output = self.prepare_data(device, layer, "val")

experiment_dir = elk_reporter_dir() / self.cfg.source

reporter_path = experiment_dir / "reporters" / f"layer_{layer}.pt"
reporter: Reporter = torch.load(reporter_path, map_location=device)
reporter.eval()

test_result = reporter.score(
test_labels,
test_x0,
test_x1,
)
row_buf = []
for ds_name, (val_h, val_gt, _) in val_output.items():
val_result = reporter.score(val_gt, val_h)

stats_row = pd.Series(
{
"layer": layer,
**test_result._asdict(),
}
)

lr_dir = experiment_dir / "lr_models"
if not self.cfg.skip_baseline and lr_dir.exists():
lr_model = load_baseline(lr_dir, layer)
lr_model.eval()
lr_auroc, lr_acc = evaluate_baseline(
lr_model.cuda(), test_x0.cuda(), test_x1.cuda(), test_labels
stats_row = pd.Series(
{
"dataset": ds_name,
"layer": layer,
**val_result._asdict(),
}
)

stats_row["lr_auroc"] = lr_auroc
stats_row["lr_acc"] = lr_acc
lr_dir = experiment_dir / "lr_models"
if not self.cfg.skip_supervised and lr_dir.exists():
with open(lr_dir / f"layer_{layer}.pt", "rb") as f:
lr_model = torch.load(f, map_location=device).eval()

lr_auroc_res, lr_acc = evaluate_supervised(lr_model, val_h, val_gt)
stats_row["lr_auroc"] = lr_auroc_res.estimate
stats_row["lr_auroc_lower"] = lr_auroc_res.lower
stats_row["lr_auroc_upper"] = lr_auroc_res.upper
stats_row["lr_acc"] = lr_acc

row_buf.append(stats_row)

return stats_row
return pd.DataFrame(row_buf)

def evaluate(self):
"""Evaluate the reporter on all layers."""
devices = select_usable_devices(
self.cfg.num_gpus, min_memory=self.cfg.data.min_gpu_mem
self.cfg.num_gpus, min_memory=self.cfg.min_gpu_mem
)

num_devices = len(devices)
func: Callable[[int], pd.Series] = partial(
func: Callable[[int], pd.DataFrame] = partial(
self.evaluate_reporter, devices=devices, world_size=num_devices
)
self.apply_to_layers(func=func, num_devices=num_devices)
58 changes: 34 additions & 24 deletions elk/extraction/balanced_sampler.py
@@ -1,49 +1,59 @@
from collections import deque
from dataclasses import dataclass, field
from itertools import cycle
from random import Random
from typing import Iterable, Iterator, Optional

from datasets import Features, IterableDataset
from torch.utils.data import IterableDataset as TorchIterableDataset

from ..math_util import stochastic_round_constrained
from ..utils import infer_label_column
from ..utils.math_util import stochastic_round_constrained
from ..utils.typing import assert_type


@dataclass
class BalancedSampler(TorchIterableDataset):
"""
Approximately balances a binary classification dataset in a streaming fashion.
Args:
dataset (IterableDataset): The HuggingFace IterableDataset to balance.
label_col (Optional[str], optional): The name of the column containing the
binary label. If not provided, the label column will be inferred from
the dataset features. Defaults to None.
buffer_size (int, optional): The total buffer size to use for balancing the
dataset. This value should be divisible by 2, as it will be equally
divided between the two binary label values (0 and 1). Defaults to 1000.
A sampler that approximately balances a multi-class classification dataset in a
streaming fashion.
Attributes:
data: The input dataset to balance.
num_classes: The total number of classes expected in the data.
buffer_size: The total buffer size to use for balancing the dataset. Each class
will have its own buffer with this size.
"""

def __init__(self, data: Iterable[dict], buffer_size: int = 1000):
self.data = data
data: Iterable[dict]
num_classes: int
buffer_size: int = 1000
buffers: dict[int, deque[dict]] = field(default_factory=dict, init=False)
label_col: str = "label"

self.neg_buffer = deque(maxlen=buffer_size)
self.pos_buffer = deque(maxlen=buffer_size)
def __post_init__(self):
# Initialize empty buffers
self.buffers = {
label: deque(maxlen=self.buffer_size) for label in range(self.num_classes)
}

def __iter__(self):
for sample in self.data:
label = sample["label"]
label = sample[self.label_col]

# Add the sample to the appropriate buffer
if label == 0:
self.neg_buffer.append(sample)
else:
self.pos_buffer.append(sample)
# This whole class is a no-op if the label is not an integer
if not isinstance(label, int):
yield sample
continue

# Add the sample to the buffer for its class label
self.buffers[label].append(sample)

while self.neg_buffer and self.pos_buffer:
yield self.neg_buffer.popleft()
yield self.pos_buffer.popleft()
# Check if all buffers have at least one sample
while all(len(buffer) > 0 for buffer in self.buffers.values()):
# Yield one sample from each buffer in a round-robin fashion
for buf in self.buffers.values():
yield buf.popleft()


class FewShotSampler:
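A minimal sketch of the new multi-class behaviour, using a toy in-memory stream; the constructor arguments follow the dataclass fields shown in the diff, while real callers pass a streaming HuggingFace dataset:

```python
from elk.extraction.balanced_sampler import BalancedSampler

# Toy stream with a 3:1 class imbalance. Real usage wraps a streaming
# HuggingFace dataset with far more rows.
stream = [
    {"text": "a", "label": 0},
    {"text": "b", "label": 0},
    {"text": "c", "label": 0},
    {"text": "d", "label": 1},
]

sampler = BalancedSampler(stream, num_classes=2, buffer_size=10)

# Samples are only emitted while every class buffer is non-empty, so the
# surplus class-0 examples stay buffered: this prints label 0, then label 1.
for sample in sampler:
    print(sample["label"])
```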
