diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 38fff767..42e9a5f4 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -24,4 +24,4 @@ repos: hooks: - id: codespell # The promptsource templates spuriously get flagged without this - args: ["--skip=*.yaml"] + args: ["-L fpr", "--skip=*.yaml"] diff --git a/README.md b/README.md index 96d51ee1..f7165573 100644 --- a/README.md +++ b/README.md @@ -32,6 +32,12 @@ The following command will evaluate the probe from the run naughty-northcutt on elk eval naughty-northcutt microsoft/deberta-v2-xxlarge-mnli imdb ``` +The following runs `elicit` on the Cartesian product of the listed models and datasets, storing it in a special folder ELK_DIR/sweeps/. Moreover, `--add_pooled` adds an additional dataset that pools all of the datasets together. + +```bash +elk sweep --models gpt2-{medium,large,xl} --datasets imdb amazon_polarity --add_pooled +``` + ## Caching The hidden states resulting from `elk elicit` are cached as a HuggingFace dataset to avoid having to recompute them every time we want to train a probe. The cache is stored in the same place as all other HuggingFace datasets, which is usually `~/.cache/huggingface/datasets`. diff --git a/elk/__main__.py b/elk/__main__.py index 5304f5aa..a477e799 100644 --- a/elk/__main__.py +++ b/elk/__main__.py @@ -5,7 +5,7 @@ from simple_parsing import ArgumentParser from elk.evaluation.evaluate import Eval -from elk.extraction.extraction import Extract +from elk.training.sweep import Sweep from elk.training.train import Elicit @@ -13,7 +13,7 @@ class Command: """Some top-level command""" - command: Elicit | Eval | Extract + command: Elicit | Eval | Sweep def execute(self): return self.command.execute() diff --git a/elk/calibration.py b/elk/calibration.py index db56fa02..23a48b8e 100644 --- a/elk/calibration.py +++ b/elk/calibration.py @@ -34,8 +34,8 @@ def update(self, labels: Tensor, probs: Tensor) -> "CalibrationError": assert labels.shape == probs.shape assert torch.is_floating_point(probs) - self.labels.append(probs) - self.pred_probs.append(labels) + self.labels.append(labels) + self.pred_probs.append(probs) return self def compute(self, p: int = 2) -> CalibrationEstimate: diff --git a/elk/debug_logging.py b/elk/debug_logging.py new file mode 100644 index 00000000..b43650ab --- /dev/null +++ b/elk/debug_logging.py @@ -0,0 +1,57 @@ +import logging +from pathlib import Path + +from datasets import DatasetDict + +from .utils import get_dataset_name, select_train_val_splits + + +def save_debug_log(datasets: list[DatasetDict], out_dir: Path) -> None: + """ + Save a debug log to the output directory. This is useful for debugging + training issues. 
+ """ + + logging.basicConfig( + level=logging.DEBUG, + format="%(asctime)s %(levelname)s:\n%(message)s", + filename=out_dir / "debug.log", + filemode="w", + ) + + for ds in datasets: + logging.info( + "=========================================\n" + f"Dataset: {get_dataset_name(ds)}\n" + "=========================================" + ) + + train_split, val_split = select_train_val_splits(ds) + text_inputs = ds[val_split][0]["text_inputs"] + template_ids = ds[val_split][0]["variant_ids"] + label = ds[val_split][0]["label"] + + # log the train size and val size + logging.info(f"Train size: {len(ds[train_split])}") + logging.info(f"Val size: {len(ds[val_split])}") + + templates_text = f"{len(text_inputs)} templates used:\n" + trailing_whitespace = False + for (text0, text1), id in zip(text_inputs, template_ids): + templates_text += ( + f'***---TEMPLATE "{id}"---***\n' + f"{'false' if label else 'true'}:\n" + f'"""{text0}"""\n' + f"{'true' if label else 'false'}:\n" + f'"""{text1}"""\n\n\n' + ) + if text0[-1].isspace() or text1[-1].isspace(): + trailing_whitespace = True + if trailing_whitespace: + logging.warning( + "Some inputs to the model have trailing whitespace! " + "Check that the jinja templates are not adding " + "trailing whitespace. If `token_loc` is 'last', this " + "will extract hidden states from the whitespace token." + ) + logging.info(templates_text) diff --git a/elk/evaluation/evaluate.py b/elk/evaluation/evaluate.py index 9cfdd6eb..8b01eb89 100644 --- a/elk/evaluation/evaluate.py +++ b/elk/evaluation/evaluate.py @@ -1,7 +1,7 @@ from dataclasses import dataclass from functools import partial from pathlib import Path -from typing import Callable, Literal, Optional +from typing import Callable import pandas as pd import torch @@ -11,7 +11,7 @@ from ..files import elk_reporter_dir from ..run import Run from ..training import Reporter -from ..training.baseline import evaluate_baseline, load_baseline +from ..training.supervised import evaluate_supervised from ..utils import select_usable_devices @@ -28,19 +28,20 @@ class Eval(Serializable): `elk.training.preprocessing.normalize()` for details. num_gpus: The number of GPUs to use. Defaults to -1, which means "use all available GPUs". + skip_supervised: Whether to skip evaluation of the supervised classifier. debug: When in debug mode, a useful log file is saved to the memorably-named output directory. Defaults to False. 
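+
+        Example (an illustrative CLI sketch; the flag spellings are assumed
+        from simple_parsing's default field-to-flag mapping and mirror the
+        `elk eval` command shown in the README):
+
+            elk eval naughty-northcutt microsoft/deberta-v2-xxlarge-mnli imdb \
+                --num_gpus 2 --skip_supervised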
""" data: Extract source: str = field(positional=True) - normalization: Literal["legacy", "none", "elementwise", "meanonly"] = "meanonly" + concatenated_layer_offset: int = 0 debug: bool = False - out_dir: Optional[Path] = None + min_gpu_mem: int | None = None num_gpus: int = -1 - skip_baseline: bool = False - concatenated_layer_offset: int = 0 + out_dir: Path | None = None + skip_supervised: bool = False combine_evals: bool = False def execute(self): @@ -50,7 +51,6 @@ def execute(self): if self.combine_evals: run = Evaluate(cfg=self, out_dir=transfer_dir / ", ".join(datasets)) - run.evaluate() else: # eval on each dataset separately for dataset in datasets: @@ -65,14 +65,10 @@ class Evaluate(Run): def evaluate_reporter( self, layer: int, devices: list[str], world_size: int = 1 - ) -> pd.Series: + ) -> pd.DataFrame: """Evaluate a single reporter on a single layer.""" device = self.get_device(devices, world_size) - - _, _, test_x0, test_x1, _, test_labels, _ = self.prepare_data( - device, - layer, - ) + val_output = self.prepare_data(device, layer, "val") experiment_dir = elk_reporter_dir() / self.cfg.source @@ -80,40 +76,41 @@ def evaluate_reporter( reporter: Reporter = torch.load(reporter_path, map_location=device) reporter.eval() - test_result = reporter.score( - test_labels, - test_x0, - test_x1, - ) + row_buf = [] + for ds_name, (val_h, val_gt, _) in val_output.items(): + val_result = reporter.score(val_gt, val_h) - stats_row = pd.Series( - { - "layer": layer, - **test_result._asdict(), - } - ) - - lr_dir = experiment_dir / "lr_models" - if not self.cfg.skip_baseline and lr_dir.exists(): - lr_model = load_baseline(lr_dir, layer) - lr_model.eval() - lr_auroc, lr_acc = evaluate_baseline( - lr_model.cuda(), test_x0.cuda(), test_x1.cuda(), test_labels + stats_row = pd.Series( + { + "dataset": ds_name, + "layer": layer, + **val_result._asdict(), + } ) - stats_row["lr_auroc"] = lr_auroc - stats_row["lr_acc"] = lr_acc + lr_dir = experiment_dir / "lr_models" + if not self.cfg.skip_supervised and lr_dir.exists(): + with open(lr_dir / f"layer_{layer}.pt", "rb") as f: + lr_model = torch.load(f, map_location=device).eval() + + lr_auroc_res, lr_acc = evaluate_supervised(lr_model, val_h, val_gt) + stats_row["lr_auroc"] = lr_auroc_res.estimate + stats_row["lr_auroc_lower"] = lr_auroc_res.lower + stats_row["lr_auroc_upper"] = lr_auroc_res.upper + stats_row["lr_acc"] = lr_acc + + row_buf.append(stats_row) - return stats_row + return pd.DataFrame(row_buf) def evaluate(self): """Evaluate the reporter on all layers.""" devices = select_usable_devices( - self.cfg.num_gpus, min_memory=self.cfg.data.min_gpu_mem + self.cfg.num_gpus, min_memory=self.cfg.min_gpu_mem ) num_devices = len(devices) - func: Callable[[int], pd.Series] = partial( + func: Callable[[int], pd.DataFrame] = partial( self.evaluate_reporter, devices=devices, world_size=num_devices ) self.apply_to_layers(func=func, num_devices=num_devices) diff --git a/elk/extraction/balanced_sampler.py b/elk/extraction/balanced_sampler.py index 2ea4815e..3fd8a739 100644 --- a/elk/extraction/balanced_sampler.py +++ b/elk/extraction/balanced_sampler.py @@ -1,4 +1,5 @@ from collections import deque +from dataclasses import dataclass, field from itertools import cycle from random import Random from typing import Iterable, Iterator, Optional @@ -6,44 +7,53 @@ from datasets import Features, IterableDataset from torch.utils.data import IterableDataset as TorchIterableDataset -from ..math_util import stochastic_round_constrained from ..utils import 
infer_label_column +from ..utils.math_util import stochastic_round_constrained from ..utils.typing import assert_type +@dataclass class BalancedSampler(TorchIterableDataset): """ - Approximately balances a binary classification dataset in a streaming fashion. - - Args: - dataset (IterableDataset): The HuggingFace IterableDataset to balance. - label_col (Optional[str], optional): The name of the column containing the - binary label. If not provided, the label column will be inferred from - the dataset features. Defaults to None. - buffer_size (int, optional): The total buffer size to use for balancing the - dataset. This value should be divisible by 2, as it will be equally - divided between the two binary label values (0 and 1). Defaults to 1000. + A sampler that approximately balances a multi-class classification dataset in a + streaming fashion. + + Attributes: + data: The input dataset to balance. + num_classes: The total number of classes expected in the data. + buffer_size: The total buffer size to use for balancing the dataset. Each class + will have its own buffer with this size. """ - def __init__(self, data: Iterable[dict], buffer_size: int = 1000): - self.data = data + data: Iterable[dict] + num_classes: int + buffer_size: int = 1000 + buffers: dict[int, deque[dict]] = field(default_factory=dict, init=False) + label_col: str = "label" - self.neg_buffer = deque(maxlen=buffer_size) - self.pos_buffer = deque(maxlen=buffer_size) + def __post_init__(self): + # Initialize empty buffers + self.buffers = { + label: deque(maxlen=self.buffer_size) for label in range(self.num_classes) + } def __iter__(self): for sample in self.data: - label = sample["label"] + label = sample[self.label_col] - # Add the sample to the appropriate buffer - if label == 0: - self.neg_buffer.append(sample) - else: - self.pos_buffer.append(sample) + # This whole class is a no-op if the label is not an integer + if not isinstance(label, int): + yield sample + continue + + # Add the sample to the buffer for its class label + self.buffers[label].append(sample) - while self.neg_buffer and self.pos_buffer: - yield self.neg_buffer.popleft() - yield self.pos_buffer.popleft() + # Check if all buffers have at least one sample + while all(len(buffer) > 0 for buffer in self.buffers.values()): + # Yield one sample from each buffer in a round-robin fashion + for buf in self.buffers.values(): + yield buf.popleft() class FewShotSampler: diff --git a/elk/extraction/extraction.py b/elk/extraction/extraction.py index a9abd6a6..2f9bda09 100644 --- a/elk/extraction/extraction.py +++ b/elk/extraction/extraction.py @@ -4,12 +4,12 @@ from copy import copy from dataclasses import InitVar, dataclass from itertools import islice -from typing import Any, Iterable, Literal, Optional +from typing import Any, Iterable, Literal import torch from datasets import ( + Array2D, Array3D, - ClassLabel, DatasetDict, Features, Sequence, @@ -23,16 +23,18 @@ from transformers import AutoConfig, AutoTokenizer from transformers.modeling_outputs import Seq2SeqLMOutput +from ..promptsource import DatasetTemplates from ..utils import ( assert_type, convert_span, float32_to_int16, + infer_label_column, + infer_num_classes, instantiate_model, is_autoregressive, select_train_val_splits, select_usable_devices, ) -from .balanced_sampler import BalancedSampler from .generator import _GeneratorBuilder from .prompt_loading import PromptConfig, load_prompts @@ -48,7 +50,6 @@ class Extract(Serializable): layer_stride: Shortcut for setting `layers` to `range(0, 
num_layers, stride)`. token_loc: The location of the token to extract hidden states from. Can be either "first", "last", or "mean". Defaults to "last". - min_gpu_mem: Minimum amount of free memory (in bytes) required to select a GPU. """ prompts: PromptConfig @@ -57,8 +58,6 @@ class Extract(Serializable): layers: tuple[int, ...] = () layer_stride: InitVar[int] = 1 token_loc: Literal["first", "last", "mean"] = "last" - min_gpu_mem: Optional[int] = None - num_gpus: int = -1 def __post_init__(self, layer_stride: int): if self.layers and layer_stride > 1: @@ -74,8 +73,16 @@ def __post_init__(self, layer_stride: int): ) self.layers = tuple(range(0, config.num_hidden_layers, layer_stride)) - def execute(self): - extract(cfg=self, num_gpus=self.num_gpus) + def explode(self) -> list["Extract"]: + """Explode this config into a list of configs, one for each layer.""" + copies = [] + + for prompt_cfg in self.prompts.explode(): + cfg = copy(self) + cfg.prompts = prompt_cfg + copies.append(cfg) + + return copies @torch.no_grad() @@ -94,10 +101,16 @@ def extract_hiddens( if rank != 0: logging.disable(logging.CRITICAL) + p_cfg = cfg.prompts + ds_names = p_cfg.datasets + assert len(ds_names) == 1, "Can only extract hiddens from one dataset at a time." + prompt_ds = load_prompts( - *cfg.prompts.datasets, + ds_names[0], + label_column=p_cfg.label_columns[0] if p_cfg.label_columns else None, + num_classes=p_cfg.num_classes, split_type=split_type, - stream=cfg.prompts.stream, + stream=p_cfg.stream, rank=rank, world_size=world_size, ) # this dataset is already sharded, but hasn't been truncated to max_examples @@ -115,19 +128,21 @@ def extract_hiddens( # Iterating over questions layer_indices = cfg.layers or tuple(range(model.config.num_hidden_layers)) - global_max_examples = cfg.prompts.max_examples[0 if split_type == "train" else 1] + global_max_examples = p_cfg.max_examples[0 if split_type == "train" else 1] # break `max_examples` among the processes roughly equally max_examples = global_max_examples // world_size # the last process gets the remainder (which is usually small) if rank == world_size - 1: max_examples += global_max_examples % world_size - for example in islice(BalancedSampler(prompt_ds), max_examples): + for example in islice(prompt_ds, max_examples): num_variants = len(example["prompts"]) + num_choices = len(example["prompts"][0]) + hidden_dict = { f"hidden_{layer_idx}": torch.empty( num_variants, - 2, # contrast pair + num_choices, model.config.hidden_size, device=device, dtype=torch.int16, @@ -136,7 +151,7 @@ def extract_hiddens( } lm_preds = torch.empty( num_variants, - 2, # contrast pair + num_choices, device=device, dtype=torch.float32, ) @@ -229,8 +244,7 @@ def extract_hiddens( **hidden_dict, ) if has_lm_preds: - # We only need the probability of the positive example since this is binary - out_record["model_preds"] = lm_preds.softmax(dim=-1)[..., 1] + out_record["model_preds"] = lm_preds.softmax(dim=-1) yield out_record @@ -240,14 +254,19 @@ def _extraction_worker(**kwargs): yield from extract_hiddens(**{k: v[0] for k, v in kwargs.items()}) -def extract(cfg: "Extract", num_gpus: int = -1) -> DatasetDict: +def extract( + cfg: "Extract", num_gpus: int = -1, min_gpu_mem: int | None = None +) -> DatasetDict: """Extract hidden states from a model and return a `DatasetDict` containing them.""" def get_splits() -> SplitDict: available_splits = assert_type(SplitDict, info.splits) train_name, val_name = select_train_val_splits(available_splits) - print(f"Using '{train_name}' for training and 
'{val_name}' for validation") - + print( + # Cyan color for dataset name + f"\033[36m{info.builder_name}\033[0m: using '{train_name}' for training and" + f" '{val_name}' for validation" + ) limit_list = cfg.prompts.max_examples return SplitDict( @@ -263,15 +282,26 @@ def get_splits() -> SplitDict: ) model_cfg = AutoConfig.from_pretrained(cfg.model) - num_variants = cfg.prompts.num_variants ds_name, _, config_name = cfg.prompts.datasets[0].partition(" ") info = get_dataset_config_info(ds_name, config_name or None) + ds_features = assert_type(Features, info.features) + label_col = ( + cfg.prompts.label_columns[0] + if cfg.prompts.label_columns + else infer_label_column(ds_features) + ) + num_classes = cfg.prompts.num_classes or infer_num_classes(ds_features[label_col]) + num_variants = cfg.prompts.num_variants + if num_variants < 0: + prompter = DatasetTemplates(ds_name, config_name) + num_variants = len(prompter.templates) + layer_cols = { f"hidden_{layer}": Array3D( dtype="int16", - shape=(num_variants, 2, model_cfg.hidden_size), + shape=(num_variants, num_classes, model_cfg.hidden_size), ) for layer in cfg.layers or range(model_cfg.num_hidden_layers) } @@ -280,11 +310,10 @@ def get_splits() -> SplitDict: Value(dtype="string"), length=num_variants, ), - "label": ClassLabel(names=["neg", "pos"]), + "label": Value(dtype="int64"), "text_inputs": Sequence( Sequence( Value(dtype="string"), - length=2, ), length=num_variants, ), @@ -292,27 +321,23 @@ def get_splits() -> SplitDict: # Only add model_preds if the model is an autoregressive model if is_autoregressive(model_cfg): - other_cols["model_preds"] = Sequence( - Value(dtype="float32"), - length=num_variants, + other_cols["model_preds"] = Array2D( + shape=(num_variants, num_classes), + dtype="float32", ) - devices = select_usable_devices(num_gpus, min_memory=cfg.min_gpu_mem) - - # Prevent the GPU-related config options from invalidating the cache - _cfg = copy(cfg) - _cfg.min_gpu_mem = None - _cfg.num_gpus = -1 - + devices = select_usable_devices(num_gpus, min_memory=min_gpu_mem) builders = { split_name: _GeneratorBuilder( + builder_name=info.builder_name, + config_name=info.config_name, cache_dir=None, features=Features({**layer_cols, **other_cols}), generator=_extraction_worker, split_name=split_name, split_info=split_info, gen_kwargs=dict( - cfg=[_cfg] * len(devices), + cfg=[cfg] * len(devices), device=devices, rank=list(range(len(devices))), split_type=[split_name] * len(devices), diff --git a/elk/extraction/generator.py b/elk/extraction/generator.py index e3cad0e5..86e65e08 100644 --- a/elk/extraction/generator.py +++ b/elk/extraction/generator.py @@ -1,17 +1,22 @@ from copy import deepcopy from dataclasses import dataclass, field -from typing import Any, Callable, Dict, Optional - -import datasets -from datasets import Features +from typing import Any, Callable + +from datasets import ( + BuilderConfig, + DatasetInfo, + Features, + GeneratorBasedBuilder, + SplitInfo, +) from datasets.splits import NamedSplit @dataclass -class _GeneratorConfig(datasets.BuilderConfig): - generator: Optional[Callable] = None +class _GeneratorConfig(BuilderConfig): + generator: Callable | None = None gen_kwargs: dict[str, Any] = field(default_factory=dict) - features: Optional[datasets.Features] = None + features: Features | None = None def create_config_id( self, config_kwargs: dict, custom_features: Features | None @@ -37,28 +42,41 @@ class _SplitGenerator: """ name: str - split_info: datasets.SplitInfo - gen_kwargs: Dict = field(default_factory=dict) + 
split_info: SplitInfo + gen_kwargs: dict = field(default_factory=dict) def __post_init__(self): self.name = str(self.name) # Make sure we convert NamedSplits in strings NamedSplit(self.name) # check that it's a valid split name -class _GeneratorBuilder(datasets.GeneratorBasedBuilder): +class _GeneratorBuilder(GeneratorBasedBuilder): """Patched version of `datasets.Generator` allowing for splits besides `train`""" BUILDER_CONFIG_CLASS = _GeneratorConfig config: _GeneratorConfig - def __init__(self, split_name: str, split_info: datasets.SplitInfo, **kwargs): + def __init__( + self, + builder_name: str | None, + config_name: str | None, + split_name: str, + split_info: SplitInfo, + **kwargs, + ): self.split_name = split_name self.split_info = split_info super().__init__(**kwargs) + # Weirdly we need to set DatasetInfo.builder_name and DatasetInfo.config_name + # here, not in _info, because super().__init__ modifies them + self.info.builder_name = builder_name + self.info.config_name = config_name + def _info(self): - return datasets.DatasetInfo(features=self.config.features) + # Use the same builder and config name as the original builder + return DatasetInfo(features=self.config.features) def _split_generators(self, dl_manager): return [ diff --git a/elk/extraction/prompt_loading.py b/elk/extraction/prompt_loading.py index 2250a3c2..b4a94617 100644 --- a/elk/extraction/prompt_loading.py +++ b/elk/extraction/prompt_loading.py @@ -1,5 +1,7 @@ from collections import Counter +from copy import deepcopy from dataclasses import dataclass +from itertools import zip_longest from random import Random from typing import Any, Iterator, Literal, Optional @@ -14,21 +16,19 @@ from ..promptsource import DatasetTemplates from ..utils import ( assert_type, - binarize, infer_label_column, infer_num_classes, select_train_val_splits, ) -from .balanced_sampler import FewShotSampler +from .balanced_sampler import BalancedSampler, FewShotSampler @dataclass class PromptConfig(Serializable): """ Args: - dataset: Space-delimited name of the HuggingFace dataset to use, e.g. + dataset: List of space-delimited names of the HuggingFace dataset to use, e.g. `"super_glue boolq"` or `"imdb"`. - balance: Whether to force class balance in the dataset using undersampling. data_dir: The directory to use for caching the dataset. Defaults to `~/.cache/huggingface/datasets`. label_column: The column containing the labels. 
By default, we infer this from @@ -47,10 +47,10 @@ class PromptConfig(Serializable): """ datasets: list[str] = field(positional=True) - balance: bool = False - data_dir: Optional[str] = None - label_column: Optional[str] = None + data_dirs: list[str] = field(default_factory=list) + label_columns: list[str] = field(default_factory=list) max_examples: list[int] = field(default_factory=lambda: [750, 250]) + num_classes: int = 0 num_shots: int = 0 num_variants: int = -1 seed: int = 42 @@ -69,9 +69,43 @@ def __post_init__(self): if len(self.max_examples) == 1: self.max_examples *= 2 + # Broadcast the dataset name to all data_dirs and label_columns + if len(self.data_dirs) == 1: + self.data_dirs *= len(self.datasets) + elif self.data_dirs and len(self.data_dirs) != len(self.datasets): + raise ValueError( + "data_dirs should be a list of length 0, 1, or len(datasets)," + f" but got {len(self.data_dirs)}" + ) + + if len(self.label_columns) == 1: + self.label_columns *= len(self.datasets) + elif self.label_columns and len(self.label_columns) != len(self.datasets): + raise ValueError( + "label_columns should be a list of length 0, 1, or len(datasets)," + f" but got {len(self.label_columns)}" + ) + + def explode(self) -> list["PromptConfig"]: + """Explode the config into a list of configs, one for each dataset.""" + copies = [] + + for ds, data_dir, col in zip_longest( + self.datasets, self.data_dirs, self.label_columns + ): + copy = deepcopy(self) + copy.datasets = [ds] + copy.data_dirs = [data_dir] if data_dir else [] + copy.label_columns = [col] if col else [] + copies.append(copy) + + return copies + def load_prompts( - *dataset_strings: str, + ds_string: str, + label_column: Optional[str] = None, + num_classes: int = 0, num_shots: int = 0, num_variants: int = -1, seed: int = 42, @@ -80,10 +114,10 @@ def load_prompts( rank: int = 0, world_size: int = 1, ) -> Iterator[dict]: - """Load a dataset full of prompts generated from the specified datasets. + """Load a dataset full of prompts generated from the specified dataset. Args: - dataset_strings: Space-delimited names of the HuggingFace datasets to use, + ds_string: Space-delimited name of the HuggingFace dataset to use, e.g. `"super_glue boolq"` or `"imdb"`. num_shots: The number of examples to use in few-shot prompts. If zero, prompts are zero-shot. @@ -94,103 +128,66 @@ def load_prompts( world_size: The number of processes. Defaults to 1. Returns: - An iterable dataset of prompts. + An iterable of prompt dictionaries. """ - prompters = [] - raw_datasets = [] - train_datasets = [] - rng = Random(seed) + ds_name, _, config_name = ds_string.partition(" ") + prompter = DatasetTemplates(ds_name, config_name) - # First load the datasets and prompters. We need to know the minimum number of - # templates for any dataset in order to make sure we don't run out of prompts. - for ds_string in dataset_strings: - ds_name, _, config_name = ds_string.partition(" ") - prompters.append(DatasetTemplates(ds_name, config_name)) + ds_dict = assert_type( + dict, load_dataset(ds_name, config_name or None, streaming=stream) + ) + train_name, val_name = select_train_val_splits(ds_dict) + split_name = val_name if split_type == "val" else train_name - ds_dict = assert_type( - dict, load_dataset(ds_name, config_name or None, streaming=stream) - ) - train_name, val_name = select_train_val_splits(ds_dict) - split_name = val_name if split_type == "val" else train_name - - # Note that when streaming we can only approximately shuffle the dataset - # using a buffer. 
Streaming shuffling is NOT an adequate shuffle for - # datasets like IMDB, which are sorted by label. - bad_streaming_datasets = ["imdb"] - assert not ( - stream and ds_name in bad_streaming_datasets - ), f"Streaming is not supported for {ds_name}." - split = ds_dict[split_name].shuffle(seed=seed) - train_ds = ds_dict[train_name].shuffle(seed=seed) - if not stream: - split = assert_type(Dataset, split) - split = split.to_iterable_dataset().cast(split.features) - - # only keep the datapoints relevant to the current process + ds = ds_dict[split_name].shuffle(seed=seed) + train_ds = ds_dict[train_name].shuffle(seed=seed) + if not stream: + ds = assert_type(Dataset, ds) if world_size > 1: - # This prints to stdout which is slightly annoying - split = split_dataset_by_node( - dataset=split, rank=rank, world_size=world_size - ) + ds = ds.shard(world_size, rank) - raw_datasets.append(split) - train_datasets.append(train_ds) + ds = ds.to_iterable_dataset().cast(ds.features) - min_num_templates = min(len(prompter.templates) for prompter in prompters) + elif world_size > 1: + # This prints to stdout which is slightly annoying + ds = split_dataset_by_node(dataset=ds, rank=rank, world_size=world_size) + + num_templates = len(prompter.templates) num_variants = ( - min_num_templates - if num_variants == -1 - else min(num_variants, min_num_templates) + num_templates if num_variants == -1 else min(num_variants, num_templates) ) assert num_variants > 0 if rank == 0: print(f"Using {num_variants} variants of each prompt") - ds_iterators = [iter(ds) for ds in raw_datasets] - while True: # terminates when the first dataset runs out of examples - for ds_iterator, ds, train_ds, prompter in zip( - ds_iterators, raw_datasets, train_datasets, prompters - ): - label_column = infer_label_column(ds.features) - num_classes = infer_num_classes(ds.features[label_column]) - - # Remove everything except the label column - extra_cols = list(assert_type(Features, ds.features)) - extra_cols.remove(label_column) - - if label_column != "label": - ds = ds.rename_column(label_column, "label") - if num_shots > 0: - fewshot = FewShotSampler( - train_ds, # TODO: not iterator - num_shots=num_shots, - rng=rng, - ) - fewshot_iter = iter(fewshot) - else: - fewshot_iter = None - - try: - example = next(ds_iterator) - except StopIteration: - return - - example = _convert_to_prompts( - example, - label_column=label_column, - num_classes=num_classes, - num_variants=num_variants, - prompter=prompter, - rng=rng, - fewshot_iter=fewshot_iter, - ) - - # Add the builder and config name to the records directly to make - # sure we don't forget what dataset they came from. 
- example["builder_name"] = ds.info.builder_name - example["config_name"] = ds.info.config_name + label_column = label_column or infer_label_column(ds.features) + num_classes = num_classes or infer_num_classes(ds.features[label_column]) + rng = Random(seed) - yield example + if num_shots > 0: + fewshot = FewShotSampler( + train_ds, # TODO: not iterator + num_shots=num_shots, + rng=rng, + ) + fewshot_iter = iter(fewshot) + else: + fewshot_iter = None + + # Remove everything except the label column + extra_cols = list(assert_type(Features, ds.features)) + extra_cols.remove(label_column) + + for example in BalancedSampler(ds, num_classes, label_col=label_column): + yield _convert_to_prompts( + example, + label_column=label_column, + num_classes=num_classes, + num_variants=num_variants, + prompter=prompter, + rng=rng, + fewshot_iter=fewshot_iter, + ) def _convert_to_prompts( @@ -203,7 +200,7 @@ def _convert_to_prompts( fewshot_iter: Optional[Iterator[list[dict]]] = None, ) -> dict[str, Any]: """Prompt-generating function to pass to `IterableDataset.map`.""" - assert_type(int, example[label_column]) + labels_are_strings = isinstance(example[label_column], str) prompts = [] templates = list(prompter.templates.values()) if num_variants < len(templates): @@ -216,22 +213,24 @@ def qa_cat(q: str, a: str) -> str: # For sanity checking that prompts are unique prompt_counter = Counter() - new_label = rng.choice([0, 1]) if num_classes > 2 else example[label_column] + label_indices = set() for template in templates: choices = [] + string_choices = template.get_answer_choices_list(example) - if num_classes > 2: - template = binarize( - template, example[label_column], assert_type(int, new_label), rng - ) + label = example[label_column] + label_indices.add(string_choices.index(label) if labels_are_strings else label) - for answer_idx in range(2): + for answer_idx in range(num_classes): fake_example = example.copy() - fake_example[label_column] = answer_idx + if labels_are_strings: + fake_example[label_column] = string_choices[answer_idx] + else: + fake_example[label_column] = answer_idx q, a = template.apply(fake_example) - text = qa_cat(q, a) + text = qa_cat(q, a or string_choices[answer_idx]) prompt_counter[text] += 1 if fewshot_iter is not None: @@ -258,8 +257,14 @@ def qa_cat(q: str, a: str) -> str: if dup_count > 1: raise ValueError(f'Prompt duplicated {dup_count} times! 
"{maybe_dup}"') + # Sanity check: label should be the same across all variants + if len(label_indices) > 1: + raise ValueError( + f"Label index should be the same all variants, but got {label_indices}" + ) + return dict( - label=new_label, + label=label_indices.pop(), prompts=prompts, - template_names=prompter.all_template_names, + template_names=[template.name for template in templates], ) diff --git a/elk/files.py b/elk/files.py index a9da71e9..4435dee1 100644 --- a/elk/files.py +++ b/elk/files.py @@ -5,9 +5,6 @@ import random from pathlib import Path -import yaml -from simple_parsing import Serializable - def elk_reporter_dir() -> Path: """Return the directory where reporter checkpoints and logs are stored.""" @@ -41,28 +38,3 @@ def memorably_named_dir(parent: Path): out_dir = parent / sub_dir out_dir.mkdir(parents=True, exist_ok=True) return out_dir - - -def save_config(cfg: Serializable, out_dir: Path): - """Save the config to a file""" - - path = out_dir / "cfg.yaml" - with open(path, "w") as f: - cfg.dump_yaml(f) - - return path - - -def save_meta(dataset, out_dir: Path): - """Save the meta data to a file""" - - meta = { - "dataset_fingerprints": { - split: dataset[split]._fingerprint for split in dataset.keys() - } - } - path = out_dir / "metadata.yaml" - with open(path, "w") as meta_f: - yaml.dump(meta, meta_f) - - return path diff --git a/elk/logging.py b/elk/logging.py deleted file mode 100644 index 706055bd..00000000 --- a/elk/logging.py +++ /dev/null @@ -1,47 +0,0 @@ -import logging - -from .utils import select_train_val_splits - - -def save_debug_log(ds, out_dir): - """ - Save a debug log to the output directory. This is useful for debugging - training issues. - """ - - logging.basicConfig( - level=logging.DEBUG, - format="%(asctime)s %(levelname)s:\n%(message)s", - filename=out_dir / "debug.log", - filemode="w", - ) - - train_split, val_split = select_train_val_splits(ds) - text_inputs = ds[val_split][0]["text_inputs"] - template_ids = ds[val_split][0]["variant_ids"] - label = ds[val_split][0]["label"] - - # log the train size and val size - logging.info(f"Train size: {len(ds[train_split])}") - logging.info(f"Val size: {len(ds[val_split])}") - - templates_text = f"{len(text_inputs)} templates used:\n" - trailing_whitespace = False - for (text0, text1), id in zip(text_inputs, template_ids): - templates_text += ( - f'***---TEMPLATE "{id}"---***\n' - f"{'false' if label else 'true'}:\n" - f'"""{text0}"""\n' - f"{'true' if label else 'false'}:\n" - f'"""{text1}"""\n\n\n' - ) - if text0[-1].isspace() or text1[-1].isspace(): - trailing_whitespace = True - if trailing_whitespace: - logging.warning( - "Some inputs to the model have trailing whitespace! " - "Check that the jinja templates are not adding " - "trailing whitespace. If `token_loc` is 'last', this " - "will extract hidden states from the whitespace token." - ) - logging.info(templates_text) diff --git a/elk/metrics.py b/elk/metrics.py new file mode 100644 index 00000000..0150f02f --- /dev/null +++ b/elk/metrics.py @@ -0,0 +1,158 @@ +from typing import NamedTuple + +import torch +from torch import Tensor + + +def to_one_hot(labels: Tensor, n_classes: int) -> Tensor: + """ + Convert a tensor of class labels to a one-hot representation. + + Args: + labels (Tensor): A tensor of class labels of shape (N,). + n_classes (int): The total number of unique classes. + + Returns: + Tensor: A one-hot representation tensor of shape (N, n_classes). 
+ """ + one_hot_labels = labels.new_zeros(labels.size(0), n_classes) + return one_hot_labels.scatter_(1, labels.unsqueeze(1).long(), 1) + + +def accuracy(y_true: Tensor, y_pred: Tensor) -> float: + """ + Compute the accuracy of a classification model. + + Args: + y_true: Ground truth tensor of shape (N,). + y_pred: Predicted class tensor of shape (N,) or (N, n_classes). + + Returns: + float: Accuracy of the model. + """ + # Check if binary or multi-class classification + if len(y_pred.shape) == 1: + hard_preds = y_pred > 0.5 + else: + hard_preds = y_pred.argmax(-1) + + return hard_preds.cpu().eq(y_true.cpu()).float().mean().item() + + +class RocAucResult(NamedTuple): + """Named tuple for storing ROC AUC results.""" + + estimate: float + """Point estimate of the ROC AUC computed on this sample.""" + lower: float + """Lower bound of the bootstrap confidence interval.""" + upper: float + """Upper bound of the bootstrap confidence interval.""" + + +def roc_auc(y_true: Tensor, y_pred: Tensor) -> Tensor: + """Area under the receiver operating characteristic curve (ROC AUC). + + Unlike scikit-learn's implementation, this function supports batched inputs of + shape `(N, n)` where `N` is the number of datasets and `n` is the number of samples + within each dataset. This is primarily useful for efficiently computing bootstrap + confidence intervals. + + Args: + y_true: Ground truth tensor of shape `(N,)` or `(N, n)`. + y_pred: Predicted class tensor of shape `(N,)` or `(N, n)`. + + Returns: + Tensor: If the inputs are 1D, a scalar containing the ROC AUC. If they're 2D, + a tensor of shape (N,) containing the ROC AUC for each dataset. + """ + if y_true.shape != y_pred.shape: + raise ValueError( + f"y_true and y_pred should have the same shape; " + f"got {y_true.shape} and {y_pred.shape}" + ) + if y_true.dim() not in (1, 2): + raise ValueError("y_true and y_pred should be 1D or 2D tensors") + + # Sort y_pred in descending order and get indices + indices = y_pred.argsort(descending=True, dim=-1) + + # Reorder y_true based on sorted y_pred indices + y_true_sorted = y_true.gather(-1, indices) + + # Calculate number of positive and negative samples + num_positives = y_true.sum(dim=-1) + num_negatives = y_true.shape[-1] - num_positives + + # Calculate cumulative sum of true positive counts (TPs) + tps = torch.cumsum(y_true_sorted, dim=-1) + + # Calculate cumulative sum of false positive counts (FPs) + fps = torch.cumsum(1 - y_true_sorted, dim=-1) + + # Calculate true positive rate (TPR) and false positive rate (FPR) + tpr = tps / num_positives.view(-1, 1) + fpr = fps / num_negatives.view(-1, 1) + + # Calculate differences between consecutive FPR values (widths of trapezoids) + fpr_diffs = torch.cat( + [fpr[..., 1:] - fpr[..., :-1], torch.zeros_like(fpr[..., :1])], dim=-1 + ) + + # Calculate area under the ROC curve for each dataset using trapezoidal rule + return torch.sum(tpr * fpr_diffs, dim=-1).squeeze() + + +def roc_auc_ci( + y_true: Tensor, + y_pred: Tensor, + *, + num_samples: int = 1000, + level: float = 0.95, + seed: int = 42, +) -> RocAucResult: + """Bootstrap confidence interval for the ROC AUC. + + Args: + y_true: Ground truth tensor of shape `(N,)`. + y_pred: Predicted class tensor of shape `(N,)`. + num_samples (int): Number of bootstrap samples to use. + level (float): Confidence level of the confidence interval. + seed (int): Random seed for reproducibility. 
+ + Returns: + RocAucResult: Named tuple containing the lower and upper bounds of the + confidence interval, along with the point estimate. + """ + if y_true.shape != y_pred.shape: + raise ValueError( + f"y_true and y_pred should have the same shape; " + f"got {y_true.shape} and {y_pred.shape}" + ) + if y_true.dim() != 1: + raise ValueError("y_true and y_pred should be 1D tensors") + + device = y_true.device + N = y_true.shape[0] + + # Generate random indices for bootstrap samples (shape: [num_bootstraps, N]) + rng = torch.Generator(device=device).manual_seed(seed) + indices = torch.randint(0, N, (num_samples, N), device=device, generator=rng) + + # Create bootstrap samples of true labels and predicted probabilities + y_true_bootstraps = y_true[indices] + y_pred_bootstraps = y_pred[indices] + + # Compute ROC AUC scores for bootstrap samples + bootstrap_aucs = roc_auc(y_true_bootstraps, y_pred_bootstraps) + + # Calculate the lower and upper bounds of the confidence interval. We use + # nanquantile instead of quantile because some bootstrap samples may have + # NaN values due to the fact that they have only one class. + alpha = (1 - level) / 2 + q = y_pred.new_tensor([alpha, 1 - alpha]) + lower, upper = bootstrap_aucs.nanquantile(q).tolist() + + # Compute the point estimate + estimate = roc_auc(y_true, y_pred).item() + return RocAucResult(estimate, lower, upper) diff --git a/elk/promptsource/templates/dbpedia_14/templates.yaml b/elk/promptsource/templates/dbpedia_14/templates.yaml index c16b0e70..4a19bb61 100644 --- a/elk/promptsource/templates/dbpedia_14/templates.yaml +++ b/elk/promptsource/templates/dbpedia_14/templates.yaml @@ -53,25 +53,6 @@ templates: original_task: true name: burns_3 reference: Burns et al. - 03fa401f-3329-48fa-be4a-1b6725292ee6: !Template - answer_choices: Company ||| Educational Institution ||| Artist ||| Athlete ||| - Office Holder ||| Mean Of Transportation ||| Building ||| Natural Place ||| - Village ||| Animal ||| Plant ||| Album ||| Film ||| Written Work - id: 03fa401f-3329-48fa-be4a-1b6725292ee6 - jinja: 'Consider the following example: '''''' {{content}} '''''' - - Which is the topic of this example, choice 1: {{answer_choices[label]}}, or - choice 2: {{answer_choices[1 - label]}} ||| - - {{answer_choices[label]}}' - metadata: !TemplateMetadata - choices_in_prompt: true - languages: - - en - metrics: [] - original_task: true - name: burns_4 - reference: Burns et al. 04fa401f-3329-48fa-be4a-1b6725292ee6: !Template answer_choices: Company ||| Educational Institution ||| Artist ||| Athlete ||| Office Holder ||| Mean Of Transportation ||| Building ||| Natural Place ||| diff --git a/elk/promptsource/templates/sst2/templates.yaml b/elk/promptsource/templates/sst2/templates.yaml new file mode 100644 index 00000000..92d82916 --- /dev/null +++ b/elk/promptsource/templates/sst2/templates.yaml @@ -0,0 +1,224 @@ +# Taken from "Overthinking the Truth" by Halawi et al. 
(2023) +# Translated into Promptsource YAML by GPT-4 +dataset: sst-2 +subset: default +templates: + a1: !Template + answer_choices: Negative ||| Positive + id: a1 + jinja: 'Review: {{sentence}} + + Answer: |||{{answer_choices[label]}}' + metadata: !TemplateMetadata + choices_in_prompt: true + languages: + - en + metrics: + - Accuracy + original_task: true + name: review_a1 + reference: '' + a2: !Template + answer_choices: bad ||| good + id: a2 + jinja: 'Review: {{sentence}} + + Answer: |||{{answer_choices[label]}}' + metadata: !TemplateMetadata + choices_in_prompt: true + languages: + - en + metrics: + - Accuracy + original_task: true + name: review_a2 + reference: '' + a3: !Template + answer_choices: bad ||| good + id: a3 + jinja: 'My review for last night''s film: {{sentence}} The critics agreed that this movie was |||{{answer_choices[label]}}.' + metadata: !TemplateMetadata + choices_in_prompt: true + languages: + - en + metrics: + - Accuracy + original_task: true + name: last_night_film + reference: '' + a4: !Template + answer_choices: negative ||| positive + id: a4 + jinja: 'One of our critics wrote "{{sentence}}". Her sentiment towards the film was |||{{answer_choices[label]}}.' + metadata: !TemplateMetadata + choices_in_prompt: true + languages: + - en + metrics: + - Accuracy + original_task: true + name: critic_sentiment + reference: '' + a5: !Template + answer_choices: bad ||| good + id: a5 + jinja: 'In a contemporary review, Roger Ebert wrote: "{{sentence}}". Entertainment Weekly agreed, and the overall critical reception of the film was |||{{answer_choices[label]}}.' + metadata: !TemplateMetadata + choices_in_prompt: true + languages: + - en + metrics: + - Accuracy + original_task: true + name: roger_ebert + reference: '' + a6: !Template + answer_choices: No ||| Yes + id: a6 + jinja: 'Review: {{sentence}} + + Positive Review? |||{{answer_choices[label]}}' + metadata: !TemplateMetadata + choices_in_prompt: true + languages: + - en + metrics: + - Accuracy + original_task: true + name: review_positive + reference: '' + a7: !Template + answer_choices: Negative ||| Positive + id: a7 + jinja: 'Review: {{sentence}} + + Question: Is the sentiment of the above review Positive or Negative? + + Answer: |||{{answer_choices[label]}}' + metadata: !TemplateMetadata + choices_in_prompt: true + languages: + - en + metrics: + - Accuracy + original_task: true + name: review_sentiment_question + reference: '' + a8: !Template + answer_choices: bad ||| good + id: a8 + jinja: 'Review: {{sentence}} + + Question: Did the author think that the movie was good or bad? + + Answer: |||{{answer_choices[label]}}' + metadata: !TemplateMetadata + choices_in_prompt: true + languages: + - en + metrics: + - Accuracy + original_task: true + name: author_opinion_question + reference: '' + a9: !Template + answer_choices: bad ||| good + id: a9 + jinja: 'Question: Did the author of the following tweet think that the movie was good or bad? + + Tweet: {{sentence}} + + Answer: |||{{answer_choices[label]}}' + metadata: !TemplateMetadata + choices_in_prompt: true + languages: + - en + metrics: + - Accuracy + original_task: true + name: tweet_opinion_question + reference: '' + a10: !Template + answer_choices: bad ||| good + id: a10 + jinja: '{{sentence}} My overall feeling was that the movie was |||{{answer_choices[label]}}.' 
+ metadata: !TemplateMetadata + choices_in_prompt: true + languages: + - en + metrics: + - Accuracy + original_task: true + name: overall_feeling + reference: '' + a11: !Template + answer_choices: liked ||| hated + id: a11 + jinja: '{{sentence}} I |||{{answer_choices[label]}}.' + metadata: !TemplateMetadata + choices_in_prompt: true + languages: + - en + metrics: + - Accuracy + original_task: true + name: liked_or_hated + reference: '' + a12: !Template + answer_choices: 0 ||| 5 + id: a12 + jinja: '{{sentence}} My friend asked me if I would give the movie 0 or 5 stars, I said |||{{answer_choices[label]}}.' + metadata: !TemplateMetadata + choices_in_prompt: true + languages: + - en + metrics: + - Accuracy + original_task: true + name: movie_stars + reference: '' + a13: !Template + answer_choices: Negative ||| Positive + id: a13 + jinja: 'Input: {{sentence}} + + Sentiment: |||{{answer_choices[label]}}' + metadata: !TemplateMetadata + choices_in_prompt: true + languages: + - en + metrics: + - Accuracy + original_task: true + name: input_sentiment + reference: '' + a14: !Template + answer_choices: False ||| True + id: a14 + jinja: 'Review: {{sentence}} + + Positive: |||{{answer_choices[label]}}' + metadata: !TemplateMetadata + choices_in_prompt: true + languages: + - en + metrics: + - Accuracy + original_task: true + name: review_positive_boolean + reference: '' + a15: !Template + answer_choices: 0 ||| 5 + id: a15 + jinja: 'Review: {{sentence}} + + Stars: |||{{answer_choices[label]}}' + metadata: !TemplateMetadata + choices_in_prompt: true + languages: + - en + metrics: + - Accuracy + original_task: true + name: review_stars + reference: '' diff --git a/elk/run.py b/elk/run.py index af75c597..bc889baa 100644 --- a/elk/run.py +++ b/elk/run.py @@ -3,27 +3,27 @@ from abc import ABC from dataclasses import dataclass, field from pathlib import Path -from typing import ( - TYPE_CHECKING, - Callable, - Optional, - Union, -) +from typing import TYPE_CHECKING, Callable, Literal, Union import numpy as np import pandas as pd import torch import torch.multiprocessing as mp +import yaml from datasets import DatasetDict from torch import Tensor from tqdm import tqdm +from .debug_logging import save_debug_log from .extraction import extract -from .files import elk_reporter_dir, memorably_named_dir, save_config, save_meta -from .logging import save_debug_log -from .training.preprocessing import normalize -from .utils import assert_type, int16_to_float32 -from .utils.data_utils import get_layers, select_train_val_splits +from .files import elk_reporter_dir, memorably_named_dir +from .utils import ( + assert_type, + get_dataset_name, + get_layers, + int16_to_float32, + select_train_val_splits, +) if TYPE_CHECKING: from .evaluation.evaluate import Eval @@ -33,12 +33,14 @@ @dataclass class Run(ABC): cfg: Union["Elicit", "Eval"] - out_dir: Optional[Path] = None - dataset: DatasetDict = field(init=False) + out_dir: Path | None = None + datasets: list[DatasetDict] = field(init=False) def __post_init__(self): - # Extract the hidden states first if necessary - self.dataset = extract(self.cfg.data, num_gpus=self.cfg.num_gpus) + self.datasets = [ + extract(cfg, num_gpus=self.cfg.num_gpus, min_gpu_mem=self.cfg.min_gpu_mem) + for cfg in self.cfg.data.explode() + ] if self.out_dir is None: # Save in a memorably-named directory inside of @@ -52,8 +54,21 @@ def __post_init__(self): print(f"Output directory at \033[1m{self.out_dir}\033[0m") self.out_dir.mkdir(parents=True, exist_ok=True) - save_config(self.cfg, self.out_dir) - 
save_meta(self.dataset, self.out_dir) + path = self.out_dir / "cfg.yaml" + with open(path, "w") as f: + self.cfg.dump_yaml(f) + + path = self.out_dir / "fingerprints.yaml" + with open(path, "w") as meta_f: + yaml.dump( + { + get_dataset_name(ds): { + split: ds[split]._fingerprint for split in ds.keys() + } + for ds in self.datasets + }, + meta_f, + ) def make_reproducible(self, seed: int): """Make the run reproducible by setting the random seed.""" @@ -70,55 +85,41 @@ def get_device(self, devices, world_size: int) -> str: return device def prepare_data( - self, - device: str, - layer: int, - ) -> tuple: - """Prepare the data for training and validation.""" - - with self.dataset.formatted_as("torch", device=device, dtype=torch.int16): - train_split, val_split = select_train_val_splits(self.dataset) - train, val = self.dataset[train_split], self.dataset[val_split] - - train_labels = assert_type(Tensor, train["label"]) - val_labels = assert_type(Tensor, val["label"]) - - # Note: currently we're just upcasting to float32 - # so we don't have to deal with - # grad scaling (which isn't supported for LBFGS), - # while the hidden states are - # saved in float16 to save disk space. - # In the future we could try to use mixed - # precision training in at least some cases. - train_h, val_h = normalize( - int16_to_float32(assert_type(torch.Tensor, train[f"hidden_{layer}"])), - int16_to_float32(assert_type(torch.Tensor, val[f"hidden_{layer}"])), - method=self.cfg.normalization, - ) + self, device: str, layer: int, split_type: Literal["train", "val"] + ) -> dict[str, tuple[Tensor, Tensor, Tensor | None]]: + """Prepare data for the specified layer and split type.""" + out = {} + + for ds in self.datasets: + train_name, val_name = select_train_val_splits(ds) + key = train_name if split_type == "train" else val_name - x0, x1 = train_h.unbind(dim=-2) - val_x0, val_x1 = val_h.unbind(dim=-2) + split = ds[key].with_format("torch", device=device, dtype=torch.int16) + labels = assert_type(Tensor, split["label"]) + val_h = int16_to_float32(assert_type(Tensor, split[f"hidden_{layer}"])) - with self.dataset.formatted_as("numpy"): - has_preds = "model_preds" in val.features - val_lm_preds = val["model_preds"] if has_preds else None + with split.formatted_as("torch", device=device): + has_preds = "model_preds" in split.features + lm_preds = split["model_preds"] if has_preds else None - return x0, x1, val_x0, val_x1, train_labels, val_labels, val_lm_preds + ds_name = get_dataset_name(ds) + out[ds_name] = (val_h, labels, lm_preds) + + return out def concatenate(self, layers): """Concatenate hidden states from a previous layer.""" for layer in range(self.cfg.concatenated_layer_offset, len(layers)): - layers[layer] = layers[layer] + [ - layers[layer][0] - self.cfg.concatenated_layer_offset - ] + layers[layer] += [layers[layer][0] - self.cfg.concatenated_layer_offset] + return layers def apply_to_layers( self, - func: Callable[[int], pd.Series], + func: Callable[[int], pd.DataFrame], num_devices: int, ): - """Apply a function to each layer of the dataset in parallel + """Apply a function to each layer of the datasets in parallel and writes the results to a CSV file. 
Args: @@ -128,7 +129,8 @@ def apply_to_layers( """ self.out_dir = assert_type(Path, self.out_dir) - layers: list[int] = get_layers(self.dataset) + layers, *rest = [get_layers(ds) for ds in self.datasets] + assert all(x == layers for x in rest), "All datasets must have the same layers" if self.cfg.concatenated_layer_offset > 0: layers = self.concatenate(layers) @@ -137,14 +139,15 @@ def apply_to_layers( ctx = mp.get_context("spawn") with ctx.Pool(num_devices) as pool, open(self.out_dir / "eval.csv", "w") as f: mapper = pool.imap_unordered if num_devices > 1 else map - row_buf = [] + df_buf = [] try: - for row in tqdm(mapper(func, layers), total=len(layers)): - row_buf.append(row) + for df in tqdm(mapper(func, layers), total=len(layers)): + df_buf.append(df) finally: # Make sure the CSV is written even if we crash or get interrupted - df = pd.DataFrame(row_buf).sort_values(by="layer") - df.to_csv(f, index=False) + if df_buf: + df = pd.concat(df_buf).sort_values(by="layer") + df.to_csv(f, index=False) if self.cfg.debug: - save_debug_log(self.dataset, self.out_dir) + save_debug_log(self.datasets, self.out_dir) diff --git a/elk/training/__init__.py b/elk/training/__init__.py index 41264179..ce6e4d48 100644 --- a/elk/training/__init__.py +++ b/elk/training/__init__.py @@ -1,13 +1,17 @@ from .ccs_reporter import CcsReporter, CcsReporterConfig +from .classifier import Classifier from .eigen_reporter import EigenReporter, EigenReporterConfig +from .normalizer import Normalizer from .reporter import OptimConfig, Reporter, ReporterConfig __all__ = [ - "Reporter", - "ReporterConfig", "CcsReporter", "CcsReporterConfig", + "Classifier", "EigenReporter", "EigenReporterConfig", + "Normalizer", "OptimConfig", + "Reporter", + "ReporterConfig", ] diff --git a/elk/training/baseline.py b/elk/training/baseline.py deleted file mode 100644 index 2c9542a6..00000000 --- a/elk/training/baseline.py +++ /dev/null @@ -1,62 +0,0 @@ -import pickle -from pathlib import Path -from typing import Tuple - -import torch -from sklearn.metrics import accuracy_score, roc_auc_score -from torch import Tensor - -from ..utils.typing import assert_type -from .classifier import Classifier - -# TODO: Create class for baseline? 
- - -def evaluate_baseline( - lr_model: Classifier, val_x0: Tensor, val_x1: Tensor, val_labels: Tensor -) -> Tuple[float, float]: - X = torch.cat([val_x0, val_x1]) - d = X.shape[-1] - X_val = X.view(-1, d) - with torch.no_grad(): - lr_preds = lr_model(X_val).sigmoid().cpu() - - val_labels_aug = ( - torch.cat([val_labels, 1 - val_labels]).repeat_interleave(val_x0.shape[1]) - ).cpu() - - lr_acc = accuracy_score(val_labels_aug, lr_preds > 0.5) - lr_auroc = roc_auc_score(val_labels_aug, lr_preds) - - return assert_type(float, lr_auroc), assert_type(float, lr_acc) - - -def train_baseline( - x0: Tensor, - x1: Tensor, - train_labels: Tensor, - device: str, -) -> Classifier: - # repeat_interleave makes `num_variants` copies of each label, all within a - # single dimension of size `num_variants * 2 * n`, such that the labels align - # with X.view(-1, X.shape[-1]) - train_labels_aug = torch.cat([train_labels, 1 - train_labels]).repeat_interleave( - x0.shape[1] - ) - - X = torch.cat([x0, x1]).squeeze() - d = X.shape[-1] - lr_model = Classifier(d, device=device) - lr_model.fit_cv(X.view(-1, d), train_labels_aug) - - return lr_model - - -def save_baseline(lr_dir: Path, layer: int, lr_model: Classifier): - with open(lr_dir / f"layer_{layer}.pt", "wb") as file: - pickle.dump(lr_model, file) - - -def load_baseline(lr_dir: Path, layer: int) -> Classifier: - with open(lr_dir / f"layer_{layer}.pt", "rb") as file: - return pickle.load(file) diff --git a/elk/training/ccs_reporter.py b/elk/training/ccs_reporter.py index 258f2bc4..b839e3fd 100644 --- a/elk/training/ccs_reporter.py +++ b/elk/training/ccs_reporter.py @@ -10,9 +10,12 @@ from torch import Tensor from torch.nn.functional import binary_cross_entropy as bce +from ..metrics import roc_auc from ..parsing import parse_loss from ..utils.typing import assert_type +from .classifier import Classifier from .losses import LOSSES +from .normalizer import Normalizer from .reporter import Reporter, ReporterConfig @@ -34,6 +37,7 @@ class CcsReporterConfig(ReporterConfig): Example: --loss 1.0*consistency_squared 0.5*prompt_var corresponds to the loss function 1.0*consistency_squared + 0.5*prompt_var. Defaults to "ccs_prompt_var". + normalization: The kind of normalization to apply to the hidden states. num_layers: The number of layers in the MLP. Defaults to 1. pre_ln: Whether to include a LayerNorm module before the first linear layer. Defaults to False. @@ -83,15 +87,19 @@ class CcsReporter(Reporter): def __init__( self, - in_features: int, cfg: CcsReporterConfig, - device: Optional[str] = None, - dtype: Optional[torch.dtype] = None, + in_features: int, + device: str | torch.device | None = None, + dtype: torch.dtype | None = None, ): - super().__init__(in_features, cfg, device=device, dtype=dtype) + super().__init__() + self.config = cfg hidden_size = cfg.hidden_size or 4 * in_features // 3 + self.neg_norm = Normalizer((in_features,), device=device, dtype=dtype) + self.pos_norm = Normalizer((in_features,), device=device, dtype=dtype) + self.probe = nn.Sequential( nn.Linear( in_features, @@ -120,6 +128,56 @@ def __init__( ) ) + def check_separability( + self, + train_pair: tuple[Tensor, Tensor], + val_pair: tuple[Tensor, Tensor], + ) -> float: + """Measure how linearly separable the pseudo-labels are for a contrast pair. + + Args: + train_pair: A tuple of tensors, (x0, x1), where x0 and x1 are the + contrastive representations. Used for training the classifier. + val_pair: A tuple of tensors, (x0, x1), where x0 and x1 are the + contrastive representations. 
Used for evaluating the classifier. + + Returns: + The AUROC of a linear classifier fit on the pseudo-labels. + """ + _x0, _x1 = train_pair + _val_x0, _val_x1 = val_pair + + x0, x1 = self.neg_norm(_x0), self.pos_norm(_x1) + val_x0, val_x1 = self.neg_norm(_val_x0), self.pos_norm(_val_x1) + + pseudo_clf = Classifier(x0.shape[-1], device=x0.device) # type: ignore + pseudo_train_labels = torch.cat( + [ + x0.new_zeros(x0.shape[0]), + x0.new_ones(x0.shape[0]), + ] + ).repeat_interleave( + x0.shape[1] + ) # make num_variants copies of each pseudo-label + pseudo_val_labels = torch.cat( + [ + val_x0.new_zeros(val_x0.shape[0]), + val_x0.new_ones(val_x0.shape[0]), + ] + ).repeat_interleave(val_x0.shape[1]) + + pseudo_clf.fit( + # b v d -> (b v) d + torch.cat([x0, x1]).flatten(0, 1), + pseudo_train_labels, + ) + with torch.no_grad(): + pseudo_preds = pseudo_clf( + # b v d -> (b v) d + torch.cat([val_x0, val_x1]).flatten(0, 1) + ).squeeze(-1) + return roc_auc(pseudo_val_labels, pseudo_preds).item() + def unsupervised_loss(self, logit0: Tensor, logit1: Tensor) -> Tensor: loss = sum( LOSSES[name](logit0, logit1, coef) @@ -164,9 +222,6 @@ def forward(self, x: Tensor) -> Tensor: """Return the raw score output of the probe on `x`.""" return self.probe(x).squeeze(-1) - def predict(self, x_pos: Tensor, x_neg: Tensor) -> Tensor: - return 0.5 * (self(x_pos).sigmoid() + (1 - self(x_neg).sigmoid())) - def loss( self, logit0: Tensor, @@ -213,14 +268,13 @@ def loss( def fit( self, - x_pos: Tensor, - x_neg: Tensor, + hiddens: Tensor, labels: Optional[Tensor] = None, ) -> float: - """Fit the probe to the contrast pair (x0, x1). + """Fit the probe to the contrast pair (neg, pos). Args: - contrast_pair: A tuple of tensors, (x0, x1), where x0 and x1 are the + contrast_pair: A tuple of tensors, (neg, pos), where x0 and x1 are the contrastive representations. labels: The labels of the contrast pair. Defaults to None. @@ -231,8 +285,10 @@ def fit( ValueError: If `optimizer` is not "adam" or "lbfgs". RuntimeError: If the best loss is not finite. """ - # TODO: Implement normalization here to fix issue #96 - # self.update(x_pos, x_neg) + x_pos, x_neg = hiddens.unbind(2) + # Fit normalizers + self.pos_norm.fit(x_pos) + self.neg_norm.fit(x_neg) # Record the best acc, loss, and params found so far best_loss = torch.inf @@ -266,8 +322,8 @@ def fit( def train_loop_adam( self, - x_pos: Tensor, x_neg: Tensor, + x_pos: Tensor, labels: Optional[Tensor] = None, ) -> float: """Adam train loop, returning the final loss. Modifies params in-place.""" @@ -288,8 +344,8 @@ def train_loop_adam( def train_loop_lbfgs( self, - x_pos: Tensor, x_neg: Tensor, + x_pos: Tensor, labels: Optional[Tensor] = None, ) -> float: """LBFGS train loop, returning the final loss. 
Modifies params in-place.""" diff --git a/elk/training/classifier.py b/elk/training/classifier.py index 726cae7a..a140af44 100644 --- a/elk/training/classifier.py +++ b/elk/training/classifier.py @@ -1,5 +1,4 @@ from dataclasses import dataclass -from typing import Optional import torch from torch import Tensor @@ -35,14 +34,14 @@ class Classifier(torch.nn.Module): def __init__( self, input_dim: int, - num_classes: int = 1, - device: Optional[str] = None, - dtype: Optional[torch.dtype] = None, + num_classes: int = 2, + device: str | torch.device | None = None, + dtype: torch.dtype | None = None, ): super().__init__() self.linear = torch.nn.Linear( - input_dim, num_classes, device=device, dtype=dtype + input_dim, num_classes if num_classes > 2 else 1, device=device, dtype=dtype ) self.linear.bias.data.zero_() self.linear.weight.data.zero_() @@ -85,7 +84,9 @@ def fit( num_classes = self.linear.out_features loss_fn = bce_with_logits if num_classes == 1 else cross_entropy loss = torch.inf - y = y.float() + y = y.to( + torch.get_default_dtype() if num_classes == 1 else torch.long, + ) def closure(): nonlocal loss @@ -148,11 +149,13 @@ def fit_cv( indices = torch.randperm(num_samples, device=x.device, generator=rng) l2_penalties = torch.logspace(-4, 4, num_penalties).tolist() - y = y.float() num_classes = self.linear.out_features loss_fn = bce_with_logits if num_classes == 1 else cross_entropy losses = x.new_zeros((k, num_penalties)) + y = y.to( + torch.get_default_dtype() if num_classes == 1 else torch.long, + ) for i in range(k): start, end = i * fold_size, (i + 1) * fold_size diff --git a/elk/training/eigen_reporter.py b/elk/training/eigen_reporter.py index 821891cc..e45b0215 100644 --- a/elk/training/eigen_reporter.py +++ b/elk/training/eigen_reporter.py @@ -2,13 +2,14 @@ from dataclasses import dataclass from typing import Optional -from warnings import warn import torch +from einops import rearrange, repeat from torch import Tensor, nn, optim -from ..math_util import cov_mean_fused -from ..truncated_eigh import ConvergenceError, truncated_eigh +from ..metrics import to_one_hot +from ..truncated_eigh import truncated_eigh +from ..utils.math_util import cov_mean_fused from .reporter import Reporter, ReporterConfig @@ -18,37 +19,48 @@ class EigenReporterConfig(ReporterConfig): Args: var_weight: The weight of the variance term in the loss. - inv_weight: The weight of the invariance term in the loss. neg_cov_weight: The weight of the negative covariance term in the loss. num_heads: The number of reporter heads to fit. In other words, the number of eigenvectors to compute from the VINC matrix. """ - var_weight: float = 1.0 - inv_weight: float = 5.0 - neg_cov_weight: float = 5.0 + var_weight: float = 0.1 + neg_cov_weight: float = 0.5 num_heads: int = 1 + def __post_init__(self): + if not (0 <= self.neg_cov_weight <= 1): + raise ValueError("neg_cov_weight must be in [0, 1]") + if self.num_heads <= 0: + raise ValueError("num_heads must be positive") + class EigenReporter(Reporter): """A linear reporter whose weights are computed via eigendecomposition. Args: - in_features: The number of input features. cfg: The reporter configuration. + in_features: The number of input features. + num_classes: The number of classes for tracking the running means. If `None`, + we don't track the running means at all, and the semantics of `update()` + are a bit different. In particular, each call to `update()` is treated as a + new dataset, with a potentially different number of classes. 
The covariance + matrices are simply averaged over each batch of data passed to `update()`, + instead of being updated with Welford's algorithm. This is useful for + training a single reporter on multiple datasets, where the number of + classes may vary. Attributes: config: The reporter configuration. - intercluster_cov_M2: The running sum of the covariance matrices of the - centroids of the positive and negative clusters. + intercluster_cov_M2: The unnormalized covariance matrix averaged over all + classes. intracluster_cov: The running mean of the covariance matrices within each cluster. This doesn't need to be a running sum because it's doesn't use Welford's algorithm. - contrastive_xcov_M2: The running sum of the cross-covariance between the - centroids of the positive and negative clusters. - n: The running sum of the number of samples in the positive and negative - clusters. + contrastive_xcov_M2: Average of the unnormalized cross-covariance matrices + across all pairs of classes (k, k'). + n: The running sum of the number of clusters processed by `update()`. weight: The reporter weight matrix. Guaranteed to always be orthogonal, and the columns are sorted in descending order of eigenvalue magnitude. """ @@ -59,21 +71,36 @@ class EigenReporter(Reporter): intracluster_cov: Tensor # invariance contrastive_xcov_M2: Tensor # negative covariance n: Tensor + class_means: Tensor | None weight: Tensor def __init__( self, - in_features: int, cfg: EigenReporterConfig, - device: Optional[str] = None, - dtype: Optional[torch.dtype] = None, + in_features: int, + num_classes: int | None = 2, + *, + device: str | torch.device | None = None, + dtype: torch.dtype | None = None, ): - super().__init__(in_features, cfg, device=device, dtype=dtype) + super().__init__() + self.config = cfg # Learnable Platt scaling parameters self.bias = nn.Parameter(torch.zeros(cfg.num_heads, device=device, dtype=dtype)) self.scale = nn.Parameter(torch.ones(cfg.num_heads, device=device, dtype=dtype)) + # Running statistics + self.register_buffer("n", torch.zeros((), device=device, dtype=torch.long)) + self.register_buffer( + "class_means", + ( + torch.zeros(num_classes, in_features, device=device, dtype=dtype) + if num_classes is not None + else None + ), + ) + self.register_buffer( "contrastive_xcov_M2", torch.zeros(in_features, in_features, device=device, dtype=dtype), @@ -86,20 +113,18 @@ def __init__( "intracluster_cov", torch.zeros(in_features, in_features, device=device, dtype=dtype), ) + + # Reporter weights self.register_buffer( "weight", torch.zeros(cfg.num_heads, in_features, device=device, dtype=dtype), ) - def forward(self, x: Tensor) -> Tensor: + def forward(self, hiddens: Tensor) -> Tensor: """Return the predicted log odds on input `x`.""" - raw_scores = x @ self.weight.mT + raw_scores = hiddens @ self.weight.mT return raw_scores.mul(self.scale).add(self.bias).squeeze(-1) - def predict(self, x_pos: Tensor, x_neg: Tensor) -> Tensor: - """Return the predicted log odds on the contrast pair `(x_pos, x_neg)`.""" - return 0.5 * (self(x_pos) - self(x_neg)) - @property def contrastive_xcov(self) -> Tensor: return self.contrastive_xcov_M2 / self.n @@ -110,15 +135,15 @@ def intercluster_cov(self) -> Tensor: @property def confidence(self) -> Tensor: - return self.weight.mT @ self.intercluster_cov @ self.weight + return self.weight @ self.intercluster_cov @ self.weight.mT @property def invariance(self) -> Tensor: - return -self.weight.mT @ self.intracluster_cov @ self.weight + return -self.weight @ 
self.intracluster_cov @ self.weight.mT @property def consistency(self) -> Tensor: - return -self.weight.mT @ self.contrastive_xcov @ self.weight + return -self.weight @ self.contrastive_xcov @ self.weight.mT def clear(self) -> None: """Clear the running statistics of the reporter.""" @@ -128,64 +153,68 @@ def clear(self) -> None: self.n.zero_() @torch.no_grad() - def update(self, x_pos: Tensor, x_neg: Tensor) -> None: + def update(self, hiddens: Tensor) -> None: + (n, _, k, d) = hiddens.shape + # Sanity checks - assert x_pos.ndim == 3, "x_pos must be of shape [batch, num_variants, d]" - assert x_pos.shape == x_neg.shape, "x_pos and x_neg must have the same shape" - - # Average across variants inside each cluster, computing the centroids. - pos_centroids, neg_centroids = x_pos.mean(1), x_neg.mean(1) - - # We don't actually call super because we need access to the earlier estimate - # of the population mean in order to update (cross-)covariances properly - # super().update(x_pos, x_neg) - - sample_n = pos_centroids.shape[0] - self.n += sample_n - - # Update the running means; super().update() does this usually - neg_delta = neg_centroids - self.neg_mean - pos_delta = pos_centroids - self.pos_mean - self.neg_mean += neg_delta.sum(dim=0) / self.n - self.pos_mean += pos_delta.sum(dim=0) / self.n - - # *** Variance (inter-cluster) *** - # See code at https://bit.ly/3YC9BhH, as well as "Welford's online algorithm" - # in https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance. - # Post-mean update deltas are used to update the (co)variance - neg_delta2 = neg_centroids - self.neg_mean # [n, d] - pos_delta2 = pos_centroids - self.pos_mean # [n, d] - self.intercluster_cov_M2.addmm_(neg_delta.mT, neg_delta2) - self.intercluster_cov_M2.addmm_(pos_delta.mT, pos_delta2) + assert k > 1, "Must provide at least two hidden states" + assert hiddens.ndim == 4, "Must be of shape [batch, variants, choices, dim]" + + self.n += n # *** Invariance (intra-cluster) *** # This is just a standard online *mean* update, since we're computing the # mean of covariance matrices, not the covariance matrix of means. - sample_invar = cov_mean_fused(x_pos) + cov_mean_fused(x_neg) - self.intracluster_cov += (sample_n / self.n) * ( - sample_invar - self.intracluster_cov - ) - - # *** Negative covariance *** - self.contrastive_xcov_M2.addmm_(neg_delta.mT, pos_delta2) - self.contrastive_xcov_M2.addmm_(pos_delta.mT, neg_delta2) - - def fit_streaming(self) -> float: + intra_cov = cov_mean_fused(rearrange(hiddens, "n v k d -> (n k) v d")) + self.intracluster_cov += (n / self.n) * (intra_cov - self.intracluster_cov) + + # [n, v, k, d] -> [n, k, d] + centroids = hiddens.mean(1) + deltas, deltas2 = [], [] + + # Iterating over classes + for i, h in enumerate(centroids.unbind(1)): + # Update the running means if needed + if self.class_means is not None: + delta = h - self.class_means[i] + self.class_means[i] += delta.sum(dim=0) / self.n + + # Post-mean update deltas are used to update the (co)variance + delta2 = h - self.class_means[i] # [n, d] + else: + delta = h - h.mean(dim=0) + delta2 = delta + + # *** Variance (inter-cluster) *** + # See code at https://bit.ly/3YC9BhH and "Welford's online algorithm" + # in https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance. 
+ self.intercluster_cov_M2.addmm_(delta.mT, delta2, alpha=1 / k) + deltas.append(delta) + deltas2.append(delta2) + + # *** Negative covariance (contrastive) *** + # Iterating over pairs of classes (k, k') where k != k' + for i, d in enumerate(deltas): + for j, d_ in enumerate(deltas2): + # Compare to the other classes only + if i == j: + continue + + scale = 1 / (k * (k - 1)) + self.contrastive_xcov_M2.addmm_(d.mT, d_, alpha=scale) + + def fit_streaming(self, truncated: bool = False) -> float: """Fit the probe using the current streaming statistics.""" + inv_weight = 1 - self.config.neg_cov_weight A = ( self.config.var_weight * self.intercluster_cov - - self.config.inv_weight * self.intracluster_cov + - inv_weight * self.intracluster_cov - self.config.neg_cov_weight * self.contrastive_xcov ) - try: + if truncated: L, Q = truncated_eigh(A, k=self.config.num_heads) - except (ConvergenceError, RuntimeError): - warn( - "Truncated eigendecomposition failed to converge. Falling back on " - "PyTorch's dense eigensolver." - ) - + else: L, Q = torch.linalg.eigh(A) L, Q = L[-self.config.num_heads :], Q[:, -self.config.num_heads :] @@ -194,50 +223,51 @@ def fit_streaming(self) -> float: def fit( self, - x_pos: Tensor, - x_neg: Tensor, + hiddens: Tensor, labels: Optional[Tensor] = None, - *, - platt_scale: bool = True, ) -> float: - """Fit the probe to the contrast pair (x_pos, x_neg). + """Fit the probe to the contrast set `hiddens`. Args: - x_pos: The positive examples. - x_neg: The negative examples. + hiddens: The contrast set of shape [batch, variants, choices, dim]. labels: The ground truth labels if available. - platt_scale: Whether to fit the scale and bias terms to data with LBFGS. - This is only used if labels are available. Returns: loss: Negative eigenvalue associated with the VINC direction. """ - assert x_pos.shape == x_neg.shape - self.update(x_pos, x_neg) + self.update(hiddens) loss = self.fit_streaming() - if labels is not None and platt_scale: - self.platt_scale(labels, x_pos, x_neg) + + if labels is not None: + (_, v, k, _) = hiddens.shape + hiddens = rearrange(hiddens, "n v k d -> (n v k) d") + labels = to_one_hot(repeat(labels, "n -> (n v)", v=v), k).flatten() + + self.platt_scale(labels, hiddens) return loss - def platt_scale( - self, labels: Tensor, x_pos: Tensor, x_neg: Tensor, max_iter: int = 100 - ): - """Fit the scale and bias terms to data with LBFGS.""" + def platt_scale(self, labels: Tensor, hiddens: Tensor, max_iter: int = 100): + """Fit the scale and bias terms to data with LBFGS. + Args: + labels: Binary labels of shape [batch]. + hiddens: Hidden states of shape [batch, dim]. + max_iter: Maximum number of iterations for LBFGS. 
+ """ opt = optim.LBFGS( [self.bias, self.scale], line_search_fn="strong_wolfe", max_iter=max_iter, - tolerance_change=torch.finfo(x_pos.dtype).eps, - tolerance_grad=torch.finfo(x_pos.dtype).eps, + tolerance_change=torch.finfo(hiddens.dtype).eps, + tolerance_grad=torch.finfo(hiddens.dtype).eps, ) - labels = labels.repeat_interleave(x_pos.shape[1]).float() def closure(): opt.zero_grad() - logits = self.predict(x_pos, x_neg).flatten() - loss = nn.functional.binary_cross_entropy_with_logits(logits, labels) + loss = nn.functional.binary_cross_entropy_with_logits( + self(hiddens), labels.float() + ) loss.backward() return float(loss) diff --git a/elk/training/normalizer.py b/elk/training/normalizer.py new file mode 100644 index 00000000..7cb04b97 --- /dev/null +++ b/elk/training/normalizer.py @@ -0,0 +1,63 @@ +from typing import Literal + +import torch +from torch import Tensor, nn + + +class Normalizer(nn.Module): + """Basically `BatchNorm` with a less annoying default axis ordering.""" + + mean: Tensor + std: Tensor | None + + def __init__( + self, + normalized_shape: tuple[int, ...], + *, + device: str | torch.device | None = None, + dtype: torch.dtype | None = None, + eps: float = 1e-5, + mode: Literal["none", "meanonly", "full"] = "full", + ): + super().__init__() + + self.eps = eps + self.mode = mode + self.normalized_shape = normalized_shape + + self.register_buffer( + "mean", torch.zeros(*normalized_shape, device=device, dtype=dtype) + ) + self.register_buffer( + "std", torch.ones_like(self.mean) if mode == "full" else None + ) + + def forward(self, x: Tensor) -> Tensor: + """Normalize `x` using the stored mean and standard deviation.""" + if self.mode == "none": + return x + elif self.std is None: + return x - self.mean + else: + return (x - self.mean) / self.std + + def fit(self, x: Tensor) -> None: + """Update the stored mean and standard deviation.""" + + # Check shape + num_dims = len(self.normalized_shape) + if x.shape[-num_dims:] != self.normalized_shape: + raise ValueError( + f"Expected trailing sizes {self.normalized_shape} but got " + f"{x.shape[-num_dims:]}" + ) + + if self.mode == "none": + return + + dims = [i for i in range(x.ndim - num_dims)] + if self.std is None: + torch.mean(x, dim=dims, out=self.mean) + else: + variance, self.mean = torch.var_mean(x, dim=dims) + torch.sqrt(variance + self.eps, out=self.std) diff --git a/elk/training/preprocessing.py b/elk/training/preprocessing.py deleted file mode 100644 index 6081dcbb..00000000 --- a/elk/training/preprocessing.py +++ /dev/null @@ -1,55 +0,0 @@ -"""Preprocessing functions for training.""" - -from typing import Literal - -import torch - - -def normalize( - train_hiddens: torch.Tensor, - val_hiddens: torch.Tensor, - method: Literal["legacy", "none", "elementwise", "meanonly"] = "legacy", -) -> tuple[torch.Tensor, torch.Tensor]: - """Normalize the hidden states. - - Normalize the hidden states with the specified method. - The "legacy" method is the same as the original ELK implementation. - The "elementwise" method normalizes each element. - The "meanonly" method normalizes the mean. - - Args: - train_hiddens: The hidden states for the training set. - val_hiddens: The hidden states for the validation set. - method: The normalization method to use. - - Returns: - tuple containing the training and validation hidden states. 
- """ - if method == "none": - return train_hiddens, val_hiddens - elif method == "legacy": - master = torch.cat([train_hiddens, val_hiddens], dim=0).float() - means = master.mean(dim=0) - - train_hiddens -= means - val_hiddens -= means - - scale = master.shape[-1] ** 0.5 / master.norm(dim=-1).mean() - train_hiddens *= scale - val_hiddens *= scale - else: - means = train_hiddens.float().mean(dim=0) - train_hiddens -= means - val_hiddens -= means - - if method == "elementwise": - scale = 1 / train_hiddens.norm(dim=0, keepdim=True) - elif method == "meanonly": - scale = 1 - else: - raise NotImplementedError(f"Scale method '{method}' is not supported.") - - train_hiddens *= scale - val_hiddens *= scale - - return train_hiddens, val_hiddens diff --git a/elk/training/reporter.py b/elk/training/reporter.py index 9cdfb145..5e2767f5 100644 --- a/elk/training/reporter.py +++ b/elk/training/reporter.py @@ -7,12 +7,12 @@ import torch import torch.nn as nn +from einops import rearrange, repeat from simple_parsing.helpers import Serializable -from sklearn.metrics import roc_auc_score from torch import Tensor from ..calibration import CalibrationError -from .classifier import Classifier +from ..metrics import accuracy, roc_auc_ci, to_one_hot class EvalResult(NamedTuple): @@ -22,9 +22,12 @@ class EvalResult(NamedTuple): which contains the loss, accuracy, calibrated accuracy, and AUROC. """ + auroc: float + auroc_lower: float + auroc_upper: float + acc: float cal_acc: float - auroc: float ece: float @@ -58,97 +61,11 @@ class OptimConfig(Serializable): class Reporter(nn.Module, ABC): - """An ELK reporter network. - - Args: - in_features: The number of input features. - cfg: The reporter configuration. - """ - - n: Tensor - neg_mean: Tensor - pos_mean: Tensor - - def __init__( - self, - in_features: int, - cfg: ReporterConfig, - device: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - ): - super().__init__() - - self.config = cfg - self.register_buffer("n", torch.zeros((), device=device, dtype=torch.long)) - self.register_buffer( - "neg_mean", torch.zeros(in_features, device=device, dtype=dtype) - ) - self.register_buffer( - "pos_mean", torch.zeros(in_features, device=device, dtype=dtype) - ) - - @classmethod - def check_separability( - cls, - train_pair: tuple[Tensor, Tensor], - val_pair: tuple[Tensor, Tensor], - ) -> float: - """Measure how linearly separable the pseudo-labels are for a contrast pair. - - Args: - train_pair: A tuple of tensors, (x0, x1), where x0 and x1 are the - contrastive representations. Used for training the classifier. - val_pair: A tuple of tensors, (x0, x1), where x0 and x1 are the - contrastive representations. Used for evaluating the classifier. - - Returns: - The AUROC of a linear classifier fit on the pseudo-labels. 
- """ - x0, x1 = train_pair - val_x0, val_x1 = val_pair - - pseudo_clf = Classifier(x0.shape[-1], device=x0.device) # type: ignore - pseudo_train_labels = torch.cat( - [ - x0.new_zeros(x0.shape[0]), - x0.new_ones(x0.shape[0]), - ] - ).repeat_interleave( - x0.shape[1] - ) # make num_variants copies of each pseudo-label - pseudo_val_labels = torch.cat( - [ - val_x0.new_zeros(val_x0.shape[0]), - val_x0.new_ones(val_x0.shape[0]), - ] - ).repeat_interleave(val_x0.shape[1]) - - pseudo_clf.fit( - # b v d -> (b v) d - torch.cat([x0, x1]).flatten(0, 1), - pseudo_train_labels, - ) - with torch.no_grad(): - pseudo_preds = pseudo_clf( - # b v d -> (b v) d - torch.cat([val_x0, val_x1]).flatten(0, 1) - ) - return float(roc_auc_score(pseudo_val_labels.cpu(), pseudo_preds.cpu())) + """An ELK reporter network.""" def reset_parameters(self): """Reset the parameters of the probe.""" - @torch.no_grad() - def update(self, x_pos: Tensor, x_neg: Tensor) -> None: - """Update the running mean of the positive and negative examples.""" - - x_pos, x_neg = x_pos.flatten(0, -2), x_neg.flatten(0, -2) - self.n += x_pos.shape[0] - - # Update the running means - self.neg_mean += (x_neg.sum(dim=0) - self.neg_mean) / self.n - self.pos_mean += (x_pos.sum(dim=0) - self.pos_mean) / self.n - # TODO: These methods will do something fancier in the future @classmethod def load(cls, path: Path | str): @@ -162,55 +79,59 @@ def save(self, path: Path | str): @abstractmethod def fit( self, - x_pos: Tensor, - x_neg: Tensor, + hiddens: Tensor, labels: Optional[Tensor] = None, ) -> float: ... - @abstractmethod - def predict(self, x_pos: Tensor, x_neg: Tensor) -> Tensor: - """Pool the probe output on the contrast pair (x_pos, x_neg).""" - @torch.no_grad() - def score(self, labels: Tensor, x_pos: Tensor, x_neg: Tensor) -> EvalResult: - """Score the probe on the contrast pair (x_pos, x1). + def score(self, labels: Tensor, hiddens: Tensor) -> EvalResult: + """Score the probe on the contrast set `hiddens`. Args: - x_pos: The positive examples. - x_neg: The negative examples. - labels: The labels of the contrast pair. + labels: The labels of the contrast pair. + hiddens: Contrast set of shape [n, v, k, d]. Returns: an instance of EvalResult containing the loss, accuracy, calibrated - accuracy, and AUROC of the probe on the contrast pair (x0, x1). + accuracy, and AUROC of the probe on `contrast_set`. + Accuracy: top-1 accuracy averaged over questions and variants. + Calibrated accuracy: top-1 accuracy averaged over questions and + variants, calibrated so that x% of the predictions are `True`, + where x is the proprtion of examples with ground truth label `True`. 
+ AUROC: averaged over the n * v * c binary questions + ECE: Expected Calibration Error """ - - pred_probs = self.predict(x_pos, x_neg) - - # makes `num_variants` copies of each label, all within a single - # dimension of size `num_variants * n`, such that the labels align - # with pred_probs.flatten() - broadcast_labels = labels.repeat_interleave(pred_probs.shape[1]).float() - cal_err = ( - CalibrationError() - .update(broadcast_labels.cpu(), pred_probs.cpu()) - .compute() - ) - - # Calibrated accuracy - cal_thresh = pred_probs.float().quantile(labels.float().mean()) - cal_preds = pred_probs.gt(cal_thresh).squeeze(1).to(torch.int) - raw_preds = pred_probs.gt(0.5).squeeze(1).to(torch.int) - - # roc_auc_score only takes flattened input - auroc = float(roc_auc_score(broadcast_labels.cpu(), pred_probs.cpu().flatten())) - cal_acc = cal_preds.flatten().eq(broadcast_labels).float().mean() - raw_acc = raw_preds.flatten().eq(broadcast_labels).float().mean() - + logits = self(hiddens) + (_, v, c) = logits.shape + + # makes `num_variants` copies of each label + logits = rearrange(logits, "n v c -> (n v) c") + Y = repeat(labels, "n -> (n v)", v=v).float() + + if c == 2: + pos_probs = logits[..., 1].flatten().sigmoid() + cal_err = CalibrationError().update(Y.cpu(), pos_probs.cpu()).compute().ece + + # Calibrated accuracy + cal_thresh = pos_probs.float().quantile(labels.float().mean()) + cal_preds = pos_probs.gt(cal_thresh).to(torch.int) + cal_acc = cal_preds.flatten().eq(Y).float().mean().item() + else: + # TODO: Implement calibration error for k > 2? + cal_acc = 0.0 + cal_err = 0.0 + + Y_one_hot = to_one_hot(Y, c).long().flatten() + auroc_result = roc_auc_ci(Y_one_hot, logits.flatten()) + + raw_preds = logits.argmax(dim=-1).long() + raw_acc = accuracy(Y, raw_preds.flatten()) return EvalResult( - acc=torch.max(raw_acc, 1 - raw_acc).item(), - cal_acc=torch.max(cal_acc, 1 - cal_acc).item(), - auroc=max(auroc, 1 - auroc), - ece=cal_err.ece, + auroc=auroc_result.estimate, + auroc_lower=auroc_result.lower, + auroc_upper=auroc_result.upper, + acc=float(raw_acc), + cal_acc=cal_acc, + ece=cal_err, ) diff --git a/elk/training/supervised.py b/elk/training/supervised.py new file mode 100644 index 00000000..300d6f0e --- /dev/null +++ b/elk/training/supervised.py @@ -0,0 +1,48 @@ +import torch +from einops import rearrange, repeat +from torch import Tensor + +from ..metrics import RocAucResult, accuracy, roc_auc_ci, to_one_hot +from ..utils import assert_type +from .classifier import Classifier + + +def evaluate_supervised( + lr_model: Classifier, val_h: Tensor, val_labels: Tensor +) -> tuple[RocAucResult, float]: + (_, v, k, _) = val_h.shape + + with torch.no_grad(): + logits = rearrange(lr_model(val_h).squeeze(), "n v k -> (n v) k") + raw_preds = to_one_hot(logits.argmax(dim=-1), k).long() + + labels = repeat(val_labels, "n -> (n v)", v=v) + labels = to_one_hot(labels, k).flatten() + + lr_acc = accuracy(labels, raw_preds.flatten()) + lr_auroc = roc_auc_ci(labels, logits.flatten()) + + return lr_auroc, assert_type(float, lr_acc) + + +def train_supervised(data: dict[str, tuple], device: str, cv: bool) -> Classifier: + Xs, train_labels = [], [] + + for train_h, labels, _ in data.values(): + (_, v, k, _) = train_h.shape + train_h = rearrange(train_h, "n v k d -> (n v k) d") + + labels = repeat(labels, "n -> (n v)", v=v) + labels = to_one_hot(labels, k).flatten() + + Xs.append(train_h) + train_labels.append(labels) + + X, train_labels = torch.cat(Xs), torch.cat(train_labels) + lr_model = Classifier(X.shape[-1], 
device=device) + if cv: + lr_model.fit_cv(X, train_labels) + else: + lr_model.fit(X, train_labels) + + return lr_model diff --git a/elk/training/sweep.py b/elk/training/sweep.py new file mode 100644 index 00000000..7f2bef2a --- /dev/null +++ b/elk/training/sweep.py @@ -0,0 +1,60 @@ +from dataclasses import InitVar, dataclass + +from ..extraction import Extract, PromptConfig +from ..files import elk_reporter_dir, memorably_named_dir +from .train import Elicit + + +@dataclass +class Sweep: + models: list[str] + """List of Huggingface model strings to sweep over.""" + datasets: list[str] + """List of dataset strings to sweep over. Each dataset string can contain + multiple datasets, separated by plus signs. For example, "sst2+imdb" will + pool SST-2 and IMDB together.""" + add_pooled: InitVar[bool] = False + """Whether to add a dataset that pools all of the other datasets together.""" + + name: str | None = None + + def __post_init__(self, add_pooled: bool): + if not self.datasets: + raise ValueError("No datasets specified") + if not self.models: + raise ValueError("No models specified") + + # Add an additional dataset that pools all of the datasets together. + if add_pooled: + self.datasets.append("+".join(self.datasets)) + + def execute(self): + M, D = len(self.models), len(self.datasets) + print(f"Starting sweep over {M} models and {D} datasets ({M * D} runs)") + print(f"Models: {self.models}") + print(f"Datasets: {self.datasets}") + + root_dir = elk_reporter_dir() / "sweeps" + sweep_dir = root_dir / self.name if self.name else memorably_named_dir(root_dir) + print(f"Saving sweep results to \033[1m{sweep_dir}\033[0m") # bold + + for i, model_str in enumerate(self.models): + # Magenta color for the model name + print(f"\n\033[35m===== {model_str} ({i + 1} of {M}) =====\033[0m") + + for dataset_str in self.datasets: + out_dir = sweep_dir / model_str / dataset_str + + # Allow for multiple datasets to be specified in a single string with + # plus signs. This means we can pool datasets together inside of a + # single sweep. + datasets = [ds.strip() for ds in dataset_str.split("+")] + Elicit( + data=Extract( + model=model_str, + prompts=PromptConfig( + datasets=datasets, + ), + ), + out_dir=out_dir, + ).execute() diff --git a/elk/training/train.py b/elk/training/train.py index 9be8c589..cd9dcb7c 100644 --- a/elk/training/train.py +++ b/elk/training/train.py @@ -1,25 +1,24 @@ """Main training loop.""" -import warnings from dataclasses import dataclass from functools import partial from pathlib import Path -from typing import Callable, Literal, Optional +from typing import Callable, Literal import pandas as pd import torch +from einops import rearrange, repeat from simple_parsing import Serializable, field, subgroups -from sklearn.metrics import accuracy_score, roc_auc_score -from torch import Tensor from ..extraction.extraction import Extract +from ..metrics import accuracy, roc_auc_ci, to_one_hot from ..run import Run -from ..training.baseline import evaluate_baseline, save_baseline, train_baseline +from ..training.supervised import evaluate_supervised, train_supervised from ..utils import select_usable_devices from ..utils.typing import assert_type from .ccs_reporter import CcsReporter, CcsReporterConfig from .eigen_reporter import EigenReporter, EigenReporterConfig -from .reporter import OptimConfig, Reporter, ReporterConfig +from .reporter import OptimConfig, ReporterConfig @dataclass @@ -34,8 +33,9 @@ class Elicit(Serializable): "use all available GPUs". 
normalization: The normalization method to use. Defaults to "meanonly". See `elk.training.preprocessing.normalize()` for details. - skip_baseline: Whether to skip training the baseline classifier. Defaults to - False. + supervised: Whether to train a supervised classifier, and if so, whether to + use cross-validation. Defaults to "single", which means to train a single + classifier on the training data. "cv" means to use cross-validation. debug: When in debug mode, a useful log file is saved to the memorably-named output directory. Defaults to False. """ @@ -46,17 +46,15 @@ class Elicit(Serializable): ) optim: OptimConfig = field(default_factory=OptimConfig) - num_gpus: int = -1 - normalization: Literal["legacy", "none", "elementwise", "meanonly"] = "meanonly" - skip_baseline: bool = False concatenated_layer_offset: int = 0 - # if nonzero, appends the hidden states of layer concatenated_layer_offset before debug: bool = False - out_dir: Optional[Path] = None + min_gpu_mem: int | None = None + num_gpus: int = -1 + out_dir: Path | None = None + supervised: Literal["none", "single", "cv"] = "single" def execute(self): - train_run = Train(cfg=self, out_dir=self.out_dir) - train_run.train() + Train(cfg=self, out_dir=self.out_dir).train() @dataclass @@ -78,88 +76,117 @@ def train_reporter( layer: int, devices: list[str], world_size: int = 1, - ) -> pd.Series: + ) -> pd.DataFrame: """Train a single reporter on a single layer.""" self.make_reproducible(seed=self.cfg.net.seed + layer) - device = self.get_device(devices, world_size) - x0, x1, val_x0, val_x1, train_gt, val_gt, val_lm_preds = self.prepare_data( - device, layer - ) - pseudo_auroc = self.get_pseudo_auroc(layer, x0, x1, val_x0, val_x1) - - if isinstance(self.cfg.net, CcsReporterConfig): - reporter = CcsReporter(x0.shape[-1], self.cfg.net, device=device) - elif isinstance(self.cfg.net, EigenReporterConfig): - reporter = EigenReporter(x0.shape[-1], self.cfg.net, device=device) - else: - raise ValueError(f"Unknown reporter config type: {type(self.cfg.net)}") + train_dict = self.prepare_data(device, layer, "train") + val_dict = self.prepare_data(device, layer, "val") - train_loss = reporter.fit(x0, x1, train_gt) - val_result = reporter.score( - val_gt, - val_x0, - val_x1, - ) + (first_train_h, train_labels, _), *rest = train_dict.values() + d = first_train_h.shape[-1] + if not all(other_h.shape[-1] == d for other_h, _, _ in rest): + raise ValueError("All datasets must have the same hidden state size") reporter_dir, lr_dir = self.create_models_dir(assert_type(Path, self.out_dir)) - if val_lm_preds is not None: - val_gt_cpu = val_gt.repeat_interleave(val_lm_preds.shape[1]).float().cpu() - val_lm_auroc = float(roc_auc_score(val_gt_cpu, val_lm_preds.flatten())) - val_lm_acc = float(accuracy_score(val_gt_cpu, val_lm_preds.flatten() > 0.5)) - else: - val_lm_auroc = None - val_lm_acc = None - - row = pd.Series( - { - "layer": layer, - "pseudo_auroc": pseudo_auroc, - "train_loss": train_loss, - **val_result._asdict(), - "lm_auroc": val_lm_auroc, - "lm_acc": val_lm_acc, - } - ) + if isinstance(self.cfg.net, CcsReporterConfig): + assert len(train_dict) == 1, "CCS only supports single-task training" + + reporter = CcsReporter(self.cfg.net, d, device=device) + train_loss = reporter.fit(first_train_h, train_labels) - if not self.cfg.skip_baseline: - lr_model = train_baseline(x0, x1, train_gt, device=device) + (val_h, val_gt, _) = next(iter(val_dict.values())) + x0, x1 = first_train_h.unbind(2) + val_x0, val_x1 = val_h.unbind(2) + pseudo_auroc = 
reporter.check_separability( + train_pair=(x0, x1), val_pair=(val_x0, val_x1) + ) - lr_auroc, lr_acc = evaluate_baseline(lr_model, val_x0, val_x1, val_gt) + elif isinstance(self.cfg.net, EigenReporterConfig): + # We set num_classes to None to enable training on datasets with different + # numbers of classes. Under the hood, this causes the covariance statistics + # to be simply averaged across all batches passed to update(). + reporter = EigenReporter(self.cfg.net, d, num_classes=None, device=device) + + hidden_list, label_list = [], [] + for ds_name, (train_h, train_labels, _) in train_dict.items(): + (_, v, k, _) = train_h.shape + + # Datasets can have different numbers of variants and different numbers + # of classes, so we need to flatten them here before concatenating + hidden_list.append(rearrange(train_h, "n v k d -> (n v k) d")) + label_list.append( + to_one_hot(repeat(train_labels, "n -> (n v)", v=v), k).flatten() + ) + reporter.update(train_h) - row["lr_auroc"] = lr_auroc - row["lr_acc"] = lr_acc - save_baseline(lr_dir, layer, lr_model) + pseudo_auroc = None + train_loss = reporter.fit_streaming() + reporter.platt_scale( + torch.cat(label_list), + torch.cat(hidden_list), + ) + else: + raise ValueError(f"Unknown reporter config type: {type(self.cfg.net)}") + # Save reporter checkpoint to disk with open(reporter_dir / f"layer_{layer}.pt", "wb") as file: torch.save(reporter, file) - return row + # Fit supervised logistic regression model + if self.cfg.supervised != "none": + lr_model = train_supervised( + train_dict, device=device, cv=self.cfg.supervised == "cv" + ) + with open(lr_dir / f"layer_{layer}.pt", "wb") as file: + torch.save(lr_model, file) + else: + lr_model = None + + row_buf = [] + for ds_name, (val_h, val_gt, val_lm_preds) in val_dict.items(): + val_result = reporter.score(val_gt, val_h) + row = pd.Series( + { + "dataset": ds_name, + "layer": layer, + "pseudo_auroc": pseudo_auroc, + "train_loss": train_loss, + **val_result._asdict(), + } + ) - def get_pseudo_auroc( - self, layer: int, x0: Tensor, x1: Tensor, val_x0: Tensor, val_x1: Tensor - ): - """Check the separability of the pseudo-labels at a given layer.""" + if val_lm_preds is not None: + (_, v, k, _) = val_h.shape - with torch.no_grad(): - pseudo_auroc = Reporter.check_separability( - train_pair=(x0, x1), val_pair=(val_x0, val_x1) - ) - if pseudo_auroc > 0.6: - warnings.warn( - f"The pseudo-labels at layer {layer} are linearly separable with " - f"an AUROC of {pseudo_auroc:.3f}. This may indicate that the " - f"algorithm will not converge to a good solution." + val_gt_rep = repeat(val_gt, "n -> (n v)", v=v) + val_lm_preds = rearrange(val_lm_preds, "n v ... 
-> (n v) ...") + val_lm_auroc_res = roc_auc_ci( + to_one_hot(val_gt_rep, k).long().flatten(), val_lm_preds.flatten() + ) + row["lm_auroc"] = val_lm_auroc_res.estimate + row["lm_auroc_lower"] = val_lm_auroc_res.lower + row["lm_auroc_upper"] = val_lm_auroc_res.upper + row["lm_acc"] = accuracy(val_gt_rep, val_lm_preds) + + if lr_model is not None: + lr_auroc_res, row["lr_acc"] = evaluate_supervised( + lr_model, val_h, val_gt ) + row["lr_auroc"] = lr_auroc_res.estimate + row["lr_auroc_lower"] = lr_auroc_res.lower + row["lr_auroc_upper"] = lr_auroc_res.upper + + row_buf.append(row) - return pseudo_auroc + return pd.DataFrame(row_buf) def train(self): """Train a reporter on each layer of the network.""" devices = select_usable_devices(self.cfg.num_gpus) num_devices = len(devices) - func: Callable[[int], pd.Series] = partial( + func: Callable[[int], pd.DataFrame] = partial( self.train_reporter, devices=devices, world_size=num_devices ) self.apply_to_layers(func=func, num_devices=num_devices) diff --git a/elk/utils/__init__.py b/elk/utils/__init__.py index 1400a98d..13656933 100644 --- a/elk/utils/__init__.py +++ b/elk/utils/__init__.py @@ -2,27 +2,37 @@ binarize, convert_span, get_columns_all_equal, + get_dataset_name, + get_layers, + has_multiple_configs, infer_label_column, infer_num_classes, select_train_val_splits, ) from .gpu_utils import select_usable_devices from .hf_utils import instantiate_model, is_autoregressive +from .math_util import batch_cov, cov_mean_fused, stochastic_round_constrained from .tree_utils import pytree_map from .typing import assert_type, float32_to_int16, int16_to_float32 __all__ = [ + "assert_type", + "batch_cov", "binarize", "convert_span", + "cov_mean_fused", + "float32_to_int16", "get_columns_all_equal", + "get_dataset_name", + "get_layers", + "has_multiple_configs", "infer_label_column", "infer_num_classes", "instantiate_model", - "is_autoregressive", - "float32_to_int16", "int16_to_float32", + "is_autoregressive", + "pytree_map", "select_train_val_splits", "select_usable_devices", - "pytree_map", - "assert_type", + "stochastic_round_constrained", ] diff --git a/elk/utils/data_utils.py b/elk/utils/data_utils.py index a98a7aae..c9e8b01c 100644 --- a/elk/utils/data_utils.py +++ b/elk/utils/data_utils.py @@ -1,8 +1,9 @@ import copy from bisect import bisect_left, bisect_right +from functools import cache from operator import itemgetter from random import Random -from typing import Any, Iterable, List +from typing import Any, Iterable from datasets import ( ClassLabel, @@ -10,6 +11,7 @@ Features, Split, Value, + get_dataset_config_names, ) from ..promptsource.templates import Template @@ -44,6 +46,30 @@ def get_columns_all_equal(dataset: DatasetDict) -> list[str]: return pivot +def get_dataset_name(dataset: DatasetDict) -> str: + """Get the name of a `DatasetDict`.""" + builder_name, *rest = [ds.builder_name for ds in dataset.values()] + if not all(name == builder_name for name in rest): + raise ValueError( + f"All splits must have the same name; got {[builder_name, *rest]}" + ) + + config_name, *rest = [ds.config_name for ds in dataset.values()] + if not all(name == config_name for name in rest): + raise ValueError( + f"All splits must have the same config name; got {[config_name, *rest]}" + ) + + include_config = config_name and has_multiple_configs(builder_name) + return builder_name + " " + config_name if include_config else builder_name + + +@cache +def has_multiple_configs(ds_name: str) -> bool: + """Return whether a dataset has multiple configs.""" + return 
len(get_dataset_config_names(ds_name)) > 1 + + def select_train_val_splits(raw_splits: Iterable[str]) -> tuple[str, str]: """Return splits to use for train and validation, given an Iterable of splits.""" @@ -101,11 +127,12 @@ def infer_num_classes(label_feature: Any) -> int: ) -def get_layers(ds: DatasetDict) -> List[int]: +def get_layers(ds: DatasetDict) -> list[int]: """Get a list of indices of hidden layers given a `DatasetDict`.""" + train, _ = select_train_val_splits(ds.keys()) layers = [ int(feat[len("hidden_") :]) - for feat in ds["train"].features + for feat in ds[train].features if feat.startswith("hidden_") ] return layers diff --git a/elk/utils/hf_utils.py b/elk/utils/hf_utils.py index 4c3ab331..4e97b6ee 100644 --- a/elk/utils/hf_utils.py +++ b/elk/utils/hf_utils.py @@ -1,8 +1,6 @@ import transformers from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel -from .typing import assert_type - # Ordered by preference _AUTOREGRESSIVE_SUFFIXES = [ # Encoder-decoder models @@ -16,7 +14,9 @@ def instantiate_model(model_str: str, **kwargs) -> PreTrainedModel: """Instantiate a model string with the appropriate `Auto` class.""" model_cfg = AutoConfig.from_pretrained(model_str) - archs = assert_type(list, model_cfg.architectures) + archs = model_cfg.architectures + if not isinstance(archs, list): + return AutoModel.from_pretrained(model_str, **kwargs) for suffix in _AUTOREGRESSIVE_SUFFIXES: # Check if any of the architectures in the config end with the suffix. @@ -31,7 +31,10 @@ def instantiate_model(model_str: str, **kwargs) -> PreTrainedModel: def is_autoregressive(model_cfg: PretrainedConfig) -> bool: """Check if a model config is autoregressive.""" - archs = assert_type(list, model_cfg.architectures) + archs = model_cfg.architectures + if not isinstance(archs, list): + return False + return any( arch_str.endswith(suffix) for arch_str in archs diff --git a/elk/math_util.py b/elk/utils/math_util.py similarity index 100% rename from elk/math_util.py rename to elk/utils/math_util.py diff --git a/pyproject.toml b/pyproject.toml index 16edd58e..6575e57a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,6 +12,7 @@ license = {text = "MIT License"} dependencies = [ # Added distributed.split_dataset_by_node for IterableDatasets "datasets>=2.9.0", + "einops", # Introduced numpy.typing module "numpy>=1.20.0", # This version is old, but it's needed for certain HF tokenizers to work. 
@@ -20,8 +21,6 @@ dependencies = [ "pandas", # Basically any version should work as long as it supports the user's CUDA version "pynvml", - # Doesn't really matter but before 1.0.0 there might be weird breaking changes - "scikit-learn>=1.0.0", # Needed for certain HF tokenizers "sentencepiece==0.1.97", # We upstreamed bugfixes for Literal types in 0.1.1 @@ -42,7 +41,8 @@ dev = [ "hypothesis", "pre-commit", "pytest", - "pyright" + "pyright", + "scikit-learn", ] [project.scripts] diff --git a/requirements.txt b/requirements.txt deleted file mode 100644 index 081f4a8c..00000000 --- a/requirements.txt +++ /dev/null @@ -1,11 +0,0 @@ -transformers==4.25.1 -git+https://github.com/bigscience-workshop/promptsource.git ---extra-index-url https://download.pytorch.org/whl/cu113 -torch==1.12.0+cu113 -torchvision==0.13.0+cu113 -sentencepiece==0.1.97 -protobuf==3.20.* -scikit-learn==1.2.0 -umap-learn==0.5.3 -tqdm==4.64.1 -pytest==7.2.1 diff --git a/tests/test_eigen_reporter.py b/tests/test_eigen_reporter.py index 58dd6c13..8977de01 100644 --- a/tests/test_eigen_reporter.py +++ b/tests/test_eigen_reporter.py @@ -1,40 +1,69 @@ +import pytest import torch -from elk.math_util import batch_cov, cov_mean_fused from elk.training import EigenReporter, EigenReporterConfig +from elk.utils import batch_cov, cov_mean_fused -def test_eigen_reporter(): +@pytest.mark.parametrize("track_class_means", [True, False]) +def test_eigen_reporter(track_class_means: bool): cluster_size = 5 hidden_size = 10 - num_clusters = 100 + N = 100 - x_pos = torch.randn(num_clusters, cluster_size, hidden_size, dtype=torch.float64) - x_neg = torch.randn(num_clusters, cluster_size, hidden_size, dtype=torch.float64) - x_pos1, x_pos2 = x_pos.chunk(2, dim=0) - x_neg1, x_neg2 = x_neg.chunk(2, dim=0) + x = torch.randn(N, cluster_size, 2, hidden_size, dtype=torch.float64) + x1, x2 = x.chunk(2, dim=0) + x_neg, x_pos = x.unbind(2) - reporter = EigenReporter(hidden_size, EigenReporterConfig(), dtype=torch.float64) - reporter.update(x_pos1, x_neg1) - reporter.update(x_pos2, x_neg2) + reporter = EigenReporter( + EigenReporterConfig(), + hidden_size, + dtype=torch.float64, + num_classes=2 if track_class_means else None, + ) + reporter.update(x1) + reporter.update(x2) - # Check that the streaming mean is correct - pos_mu, neg_mu = x_pos.mean(dim=(0, 1)), x_neg.mean(dim=(0, 1)) - torch.testing.assert_close(reporter.pos_mean, pos_mu) - torch.testing.assert_close(reporter.neg_mean, neg_mu) + if track_class_means: + # Check that the streaming mean is correct + neg_mu, pos_mu = x_neg.mean(dim=(0, 1)), x_pos.mean(dim=(0, 1)) - # Check that the streaming covariance is correct - pos_centroids, neg_centroids = x_pos.mean(dim=1), x_neg.mean(dim=1) - expected_var = batch_cov(pos_centroids) + batch_cov(neg_centroids) - torch.testing.assert_close(reporter.intercluster_cov, expected_var) + assert reporter.class_means is not None + torch.testing.assert_close(reporter.class_means[0], neg_mu) + torch.testing.assert_close(reporter.class_means[1], pos_mu) - # Check that the streaming invariance (intra-cluster variance) is correct - expected_invariance = cov_mean_fused(x_pos) + cov_mean_fused(x_neg) - torch.testing.assert_close(reporter.intracluster_cov, expected_invariance) + # Check that the streaming covariance is correct + neg_centroids, pos_centroids = x_neg.mean(dim=1), x_pos.mean(dim=1) + true_cov = 0.5 * (batch_cov(neg_centroids) + batch_cov(pos_centroids)) + torch.testing.assert_close(reporter.intercluster_cov, true_cov) + + # Check that the streaming negative 
covariance is correct + true_xcov = (neg_centroids - neg_mu).mT @ (pos_centroids - pos_mu) / N + true_xcov = 0.5 * (true_xcov + true_xcov.mT) + torch.testing.assert_close(reporter.contrastive_xcov, true_xcov) + else: + assert reporter.class_means is None + + # Check that the covariance matrices are correct. When we don't track class + # means, we expect intercluster_cov and contrastive_xcov to simply be averaged + # over each batch passed to update(). + true_xcov = 0.0 + true_cov = 0.0 + for x_i in (x1, x2): + x_neg_i, x_pos_i = x_i.unbind(2) + neg_centroids, pos_centroids = x_neg_i.mean(dim=1), x_pos_i.mean(dim=1) + true_cov += 0.5 * (batch_cov(neg_centroids) + batch_cov(pos_centroids)) - # Check that the streaming negative covariance is correct - cross_cov = (pos_centroids - pos_mu).mT @ (neg_centroids - neg_mu) / num_clusters - cross_cov = cross_cov + cross_cov.mT - torch.testing.assert_close(reporter.contrastive_xcov, cross_cov) + neg_mu_i, pos_mu_i = x_neg_i.mean(dim=(0, 1)), x_pos_i.mean(dim=(0, 1)) + xcov_asym = (neg_centroids - neg_mu_i).mT @ (pos_centroids - pos_mu_i) + true_xcov += 0.5 * (xcov_asym + xcov_asym.mT) + + torch.testing.assert_close(reporter.intercluster_cov, true_cov / 2) + torch.testing.assert_close(reporter.contrastive_xcov, true_xcov / N) + + # Check that the streaming invariance (intra-cluster variance) is correct. + # This is actually the same whether or not we track class means. + expected_invariance = 0.5 * (cov_mean_fused(x_neg) + cov_mean_fused(x_pos)) + torch.testing.assert_close(reporter.intracluster_cov, expected_invariance) - assert reporter.n == num_clusters + assert reporter.n == N diff --git a/tests/test_load_prompts.py b/tests/test_load_prompts.py index a5d238fd..2e03c379 100644 --- a/tests/test_load_prompts.py +++ b/tests/test_load_prompts.py @@ -1,4 +1,4 @@ -from itertools import cycle, islice +from itertools import islice from typing import Literal import pytest @@ -10,26 +10,22 @@ @pytest.mark.filterwarnings("ignore:Unable to find a decoding function") def test_load_prompts(): def test_single_split(cfg: PromptConfig, split_type: Literal["train", "val"]): - prompt_ds = load_prompts( - *cfg.datasets, - split_type=split_type, - ) - prompters = [] - - for ds in cfg.datasets: - ds_name, _, config_name = ds.partition(" ") + for cfg in cfg.explode(): + ds_string = cfg.datasets[0] + prompt_ds = load_prompts(ds_string, split_type=split_type) + + ds_name, _, config_name = ds_string.partition(" ") prompter = DatasetTemplates(ds_name, config_name or None) - prompters.append(prompter) - limit = cfg.max_examples[0 if split_type == "train" else 1] - for prompter, record in zip(cycle(prompters), islice(prompt_ds, limit)): - true_template_names = prompter.all_template_names - returned_template_names = record["template_names"] + limit = cfg.max_examples[0 if split_type == "train" else 1] + for record in islice(prompt_ds, limit): + true_template_names = prompter.all_template_names + returned_template_names = record["template_names"] - # check for using the same templates - assert set(true_template_names) == set(returned_template_names) - # check for them being in the same order - assert true_template_names == true_template_names + # check for using the same templates + assert set(true_template_names) == set(returned_template_names) + # check for them being in the same order + assert true_template_names == true_template_names # the case where the dataset has 2 classes # this dataset is small diff --git a/tests/test_math.py b/tests/test_math.py index 
ee81914e..34984d8f 100644 --- a/tests/test_math.py +++ b/tests/test_math.py @@ -6,7 +6,7 @@ from hypothesis import given from hypothesis import strategies as st -from elk.math_util import batch_cov, cov_mean_fused, stochastic_round_constrained +from elk.utils import batch_cov, cov_mean_fused, stochastic_round_constrained def test_cov_mean_fused(): diff --git a/tests/test_roc_auc.py b/tests/test_roc_auc.py new file mode 100644 index 00000000..244bdb88 --- /dev/null +++ b/tests/test_roc_auc.py @@ -0,0 +1,53 @@ +import numpy as np +import torch +from sklearn.datasets import make_classification +from sklearn.linear_model import LogisticRegression +from sklearn.metrics import roc_auc_score + +from elk.metrics import roc_auc + + +def test_roc_auc_score(): + # Generate 1D binary classification dataset + X_1d, y_true_1d = make_classification(n_samples=1000, random_state=42) + + # Generate 2D matrix of binary classification datasets + X_2d_1, y_true_2d_1 = make_classification(n_samples=1000, random_state=43) + X_2d_2, y_true_2d_2 = make_classification(n_samples=1000, random_state=44) + + # Fit LR models and get predicted probabilities for 1D and 2D cases + lr_1d = LogisticRegression(random_state=42).fit(X_1d, y_true_1d) + y_scores_1d = lr_1d.predict_proba(X_1d)[:, 1] + + lr_2d_1 = LogisticRegression(random_state=42).fit(X_2d_1, y_true_2d_1) + y_scores_2d_1 = lr_2d_1.predict_proba(X_2d_1)[:, 1] + + lr_2d_2 = LogisticRegression(random_state=42).fit(X_2d_2, y_true_2d_2) + y_scores_2d_2 = lr_2d_2.predict_proba(X_2d_2)[:, 1] + + # Stack the datasets into 2D matrices + y_true_2d = np.vstack((y_true_2d_1, y_true_2d_2)) + y_scores_2d = np.vstack((y_scores_2d_1, y_scores_2d_2)) + + # Convert to PyTorch tensors + y_true_1d_torch = torch.tensor(y_true_1d) + y_scores_1d_torch = torch.tensor(y_scores_1d) + y_true_2d_torch = torch.tensor(y_true_2d) + y_scores_2d_torch = torch.tensor(y_scores_2d) + + # Calculate ROC AUC score using batch_roc_auc_score function for 1D and 2D cases + roc_auc_1d_torch = roc_auc(y_true_1d_torch, y_scores_1d_torch).item() + roc_auc_2d_torch = roc_auc(y_true_2d_torch, y_scores_2d_torch).numpy() + + # Calculate ROC AUC score with sklearn's roc_auc_score function for 1D and 2D cases + roc_auc_1d_sklearn = roc_auc_score(y_true_1d, y_scores_1d) + roc_auc_2d_sklearn = np.array( + [ + roc_auc_score(y_true_2d_1, y_scores_2d_1), + roc_auc_score(y_true_2d_2, y_scores_2d_2), + ] + ) + + # Assert that the results from the two implementations are almost equal + np.testing.assert_almost_equal(roc_auc_1d_torch, roc_auc_1d_sklearn) + np.testing.assert_almost_equal(roc_auc_2d_torch, roc_auc_2d_sklearn) diff --git a/tests/test_samplers.py b/tests/test_samplers.py index 87c1ac0c..c5a142c1 100644 --- a/tests/test_samplers.py +++ b/tests/test_samplers.py @@ -41,7 +41,7 @@ def test_output_is_roughly_balanced(): ) col = infer_label_column(dataset.features) - reservoir = BalancedSampler(dataset) + reservoir = BalancedSampler(dataset, 2) # Count the number of samples for each label counter = Counter() diff --git a/tests/test_smoke_elicit.py b/tests/test_smoke_elicit.py index 5bbcfef3..e61568ba 100644 --- a/tests/test_smoke_elicit.py +++ b/tests/test_smoke_elicit.py @@ -14,10 +14,10 @@ def test_smoke_elicit_run_tiny_gpt2_ccs(tmp_path: Path): data=Extract( model=model_path, prompts=PromptConfig(datasets=[dataset_name], max_examples=[10]), - min_gpu_mem=min_mem, # run on all layers, tiny-gpt only has 2 layers ), num_gpus=2, + min_gpu_mem=min_mem, net=CcsReporterConfig(), out_dir=tmp_path, ) @@ -25,7 +25,13 @@ 
def test_smoke_elicit_run_tiny_gpt2_ccs(tmp_path: Path): # get the files in the tmp_path files: list[Path] = list(tmp_path.iterdir()) created_file_names = {file.name for file in files} - expected_files = ["cfg.yaml", "metadata.yaml", "lr_models", "reporters", "eval.csv"] + expected_files = [ + "cfg.yaml", + "fingerprints.yaml", + "lr_models", + "reporters", + "eval.csv", + ] for file in expected_files: assert file in created_file_names @@ -38,10 +44,10 @@ def test_smoke_elicit_run_tiny_gpt2_eigen(tmp_path: Path): data=Extract( model=model_path, prompts=PromptConfig(datasets=[dataset_name], max_examples=[10]), - min_gpu_mem=min_mem, # run on all layers, tiny-gpt only has 2 layers ), num_gpus=2, + min_gpu_mem=min_mem, net=EigenReporterConfig(), out_dir=tmp_path, ) @@ -49,6 +55,12 @@ def test_smoke_elicit_run_tiny_gpt2_eigen(tmp_path: Path): # get the files in the tmp_path files: list[Path] = list(tmp_path.iterdir()) created_file_names = {file.name for file in files} - expected_files = ["cfg.yaml", "metadata.yaml", "lr_models", "reporters", "eval.csv"] + expected_files = [ + "cfg.yaml", + "fingerprints.yaml", + "lr_models", + "reporters", + "eval.csv", + ] for file in expected_files: assert file in created_file_names
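Taken together, the reporter changes in this patch replace the old `(x_pos, x_neg)` interface with a single contrast tensor of shape `[batch, variants, choices, dim]`. The sketch below shows how the streaming `EigenReporter` API is exercised after the change, pieced together from the updated `train.py` and `tests/test_eigen_reporter.py`; the tensor sizes and random data are illustrative only, and it assumes the `elk.metrics` helpers used by `Reporter.score` behave as they do in `train.py`.

```python
import torch

from elk.training import EigenReporter, EigenReporterConfig

n, v, k, d = 128, 5, 2, 64  # questions, prompt variants, choices, hidden size
hiddens = torch.randn(n, v, k, d, dtype=torch.float64)  # stand-in contrast set
labels = torch.randint(0, 2, (n,))  # stand-in binary ground truth

reporter = EigenReporter(EigenReporterConfig(), d, num_classes=k, dtype=torch.float64)

# Covariance statistics are accumulated online, so the contrast set can be
# streamed in batches before the VINC eigenproblem is solved.
for batch in hiddens.chunk(4):
    reporter.update(batch)
train_loss = reporter.fit_streaming()

# Reporter.score now returns an AUROC point estimate plus a confidence interval.
# (train.py additionally Platt-scales the logits with flattened one-hot labels.)
result = reporter.score(labels, hiddens)
print(result.auroc, result.auroc_lower, result.auroc_upper, result.acc)
```

Passing `num_classes=None` instead, as `train.py` does when pooling datasets, switches `update()` to the per-batch averaging described in the new class docstring.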
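The `Classifier` default likewise moves from `num_classes=1` to `num_classes=2`: the binary case still collapses to a single logit trained with BCE, while anything above two classes switches to cross-entropy on integer labels. A hypothetical binary example (data and sizes are made up; `fit`'s optimizer internals are not shown in this diff):

```python
import torch

from elk.training.classifier import Classifier

x = torch.randn(256, 16)         # made-up features
y = torch.randint(0, 2, (256,))  # made-up binary labels

clf = Classifier(input_dim=16, num_classes=2)  # one logit head under the hood
clf.fit(x, y)  # labels are cast to float here because out_features == 1
probs = clf(x).squeeze(-1).sigmoid()
```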
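Finally, the new `Normalizer` module that replaces `elk/training/preprocessing.py` fits its statistics once and then applies them to any split, which is how `CcsReporter.fit` now normalizes the two halves of the contrast pair. A minimal usage sketch (the sizes are illustrative):

```python
import torch

from elk.training.normalizer import Normalizer

d = 64
norm = Normalizer((d,), mode="full")  # "meanonly" and "none" are also supported

train_h = torch.randn(1000, d)  # illustrative hidden states
val_h = torch.randn(200, d)

norm.fit(train_h)         # estimate mean and std over the leading dimensions
val_h_norm = norm(val_h)  # (x - mean) / std, reusing the training statistics
```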