From 16dc1ca580fed6333fdbc4e5da14c9311435b830 Mon Sep 17 00:00:00 2001 From: Nora Belrose <39116809+norabelrose@users.noreply.github.com> Date: Fri, 14 Apr 2023 16:43:48 -0700 Subject: [PATCH] Multiple datasets refactor (#189) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Fix bug where cached hidden states aren’t used when num_gpus is different * Actually works now * Refactor handling of multiple datasets * Various fixes * Fix math tests * Fix smoke tests * All tests working ostensibly * Make CCS normalization customizable * log each dataset individually * Move pseudo AUROC stuff to CcsReporter * Make 'datasets' and 'label_columns' config options more opinionated * tiny spacing change * Allow for toggling CV * add typing to logging; rename logging * Fix eval logging bug --------- Co-authored-by: Alex Mallen --- elk/__main__.py | 5 +- elk/debug_logging.py | 57 +++++++++ elk/evaluation/evaluate.py | 75 ++++++------ elk/extraction/balanced_sampler.py | 2 +- elk/extraction/extraction.py | 52 ++++++--- elk/extraction/generator.py | 42 +++++-- elk/extraction/prompt_loading.py | 178 ++++++++++++++--------------- elk/files.py | 28 ----- elk/logging.py | 47 -------- elk/run.py | 124 ++++++++++---------- elk/training/__init__.py | 6 +- elk/training/baseline.py | 62 ---------- elk/training/ccs_reporter.py | 83 ++++++++++++-- elk/training/eigen_reporter.py | 32 ++++-- elk/training/normalizer.py | 63 ++++++++++ elk/training/preprocessing.py | 55 --------- elk/training/reporter.py | 81 +------------ elk/training/supervised.py | 50 ++++++++ elk/training/train.py | 156 +++++++++++++------------ elk/utils/__init__.py | 18 ++- elk/utils/data_utils.py | 30 ++++- elk/utils/hf_utils.py | 11 +- elk/{ => utils}/math_util.py | 0 pyproject.toml | 1 + tests/test_eigen_reporter.py | 2 +- tests/test_load_prompts.py | 32 +++--- tests/test_math.py | 2 +- tests/test_smoke_elicit.py | 20 +++- 28 files changed, 687 insertions(+), 627 deletions(-) create mode 100644 elk/debug_logging.py delete mode 100644 elk/logging.py delete mode 100644 elk/training/baseline.py create mode 100644 elk/training/normalizer.py delete mode 100644 elk/training/preprocessing.py create mode 100644 elk/training/supervised.py rename elk/{ => utils}/math_util.py (100%) diff --git a/elk/__main__.py b/elk/__main__.py index 5304f5aa..faa044f7 100644 --- a/elk/__main__.py +++ b/elk/__main__.py @@ -5,7 +5,6 @@ from simple_parsing import ArgumentParser from elk.evaluation.evaluate import Eval -from elk.extraction.extraction import Extract from elk.training.train import Elicit @@ -13,14 +12,14 @@ class Command: """Some top-level command""" - command: Elicit | Eval | Extract + command: Elicit | Eval def execute(self): return self.command.execute() def run(): - parser = ArgumentParser(add_help=False) + parser = ArgumentParser(add_help=False, add_config_path_arg=True) parser.add_arguments(Command, dest="run") args = parser.parse_args() run: Command = args.run diff --git a/elk/debug_logging.py b/elk/debug_logging.py new file mode 100644 index 00000000..b43650ab --- /dev/null +++ b/elk/debug_logging.py @@ -0,0 +1,57 @@ +import logging +from pathlib import Path + +from datasets import DatasetDict + +from .utils import get_dataset_name, select_train_val_splits + + +def save_debug_log(datasets: list[DatasetDict], out_dir: Path) -> None: + """ + Save a debug log to the output directory. This is useful for debugging + training issues. 
+ """ + + logging.basicConfig( + level=logging.DEBUG, + format="%(asctime)s %(levelname)s:\n%(message)s", + filename=out_dir / "debug.log", + filemode="w", + ) + + for ds in datasets: + logging.info( + "=========================================\n" + f"Dataset: {get_dataset_name(ds)}\n" + "=========================================" + ) + + train_split, val_split = select_train_val_splits(ds) + text_inputs = ds[val_split][0]["text_inputs"] + template_ids = ds[val_split][0]["variant_ids"] + label = ds[val_split][0]["label"] + + # log the train size and val size + logging.info(f"Train size: {len(ds[train_split])}") + logging.info(f"Val size: {len(ds[val_split])}") + + templates_text = f"{len(text_inputs)} templates used:\n" + trailing_whitespace = False + for (text0, text1), id in zip(text_inputs, template_ids): + templates_text += ( + f'***---TEMPLATE "{id}"---***\n' + f"{'false' if label else 'true'}:\n" + f'"""{text0}"""\n' + f"{'true' if label else 'false'}:\n" + f'"""{text1}"""\n\n\n' + ) + if text0[-1].isspace() or text1[-1].isspace(): + trailing_whitespace = True + if trailing_whitespace: + logging.warning( + "Some inputs to the model have trailing whitespace! " + "Check that the jinja templates are not adding " + "trailing whitespace. If `token_loc` is 'last', this " + "will extract hidden states from the whitespace token." + ) + logging.info(templates_text) diff --git a/elk/evaluation/evaluate.py b/elk/evaluation/evaluate.py index 6aca58f5..7ca0716d 100644 --- a/elk/evaluation/evaluate.py +++ b/elk/evaluation/evaluate.py @@ -1,7 +1,7 @@ from dataclasses import dataclass from functools import partial from pathlib import Path -from typing import Callable, Literal, Optional +from typing import Callable import pandas as pd import torch @@ -11,7 +11,7 @@ from ..files import elk_reporter_dir from ..run import Run from ..training import Reporter -from ..training.baseline import evaluate_baseline, load_baseline +from ..training.supervised import evaluate_supervised from ..utils import select_usable_devices @@ -28,26 +28,25 @@ class Eval(Serializable): `elk.training.preprocessing.normalize()` for details. num_gpus: The number of GPUs to use. Defaults to -1, which means "use all available GPUs". + skip_supervised: Whether to skip evaluation of the supervised classifier. debug: When in debug mode, a useful log file is saved to the memorably-named output directory. Defaults to False. 
""" data: Extract source: str = field(positional=True) - normalization: Literal["legacy", "none", "elementwise", "meanonly"] = "meanonly" + concatenated_layer_offset: int = 0 debug: bool = False - out_dir: Optional[Path] = None + min_gpu_mem: int | None = None num_gpus: int = -1 - skip_baseline: bool = False - concatenated_layer_offset: int = 0 + out_dir: Path | None = None + skip_supervised: bool = False def execute(self): - datasets = self.data.prompts.datasets - transfer_dir = elk_reporter_dir() / self.source / "transfer_eval" - for dataset in datasets: + for dataset in self.data.prompts.datasets: run = Evaluate(cfg=self, out_dir=transfer_dir / dataset) run.evaluate() @@ -58,14 +57,10 @@ class Evaluate(Run): def evaluate_reporter( self, layer: int, devices: list[str], world_size: int = 1 - ) -> pd.Series: + ) -> pd.DataFrame: """Evaluate a single reporter on a single layer.""" device = self.get_device(devices, world_size) - - _, _, test_x0, test_x1, _, test_labels, _ = self.prepare_data( - device, - layer, - ) + val_output = self.prepare_data(device, layer, "val") experiment_dir = elk_reporter_dir() / self.cfg.source @@ -73,40 +68,44 @@ def evaluate_reporter( reporter: Reporter = torch.load(reporter_path, map_location=device) reporter.eval() - test_result = reporter.score( - test_labels, - test_x0, - test_x1, - ) - - stats_row = pd.Series( - { - "layer": layer, - **test_result._asdict(), - } - ) + row_buf = [] + for ds_name, (val_x0, val_x1, val_gt, _) in val_output.items(): + val_result = reporter.score( + val_gt, + val_x0, + val_x1, + ) - lr_dir = experiment_dir / "lr_models" - if not self.cfg.skip_baseline and lr_dir.exists(): - lr_model = load_baseline(lr_dir, layer) - lr_model.eval() - lr_auroc, lr_acc = evaluate_baseline( - lr_model.cuda(), test_x0.cuda(), test_x1.cuda(), test_labels + stats_row = pd.Series( + { + "dataset": ds_name, + "layer": layer, + **val_result._asdict(), + } ) - stats_row["lr_auroc"] = lr_auroc - stats_row["lr_acc"] = lr_acc + lr_dir = experiment_dir / "lr_models" + if not self.cfg.skip_supervised and lr_dir.exists(): + with open(lr_dir / f"layer_{layer}.pt", "rb") as f: + lr_model = torch.load(f, map_location=device).eval() + + lr_auroc, lr_acc = evaluate_supervised(lr_model, val_x0, val_x1, val_gt) + + stats_row["lr_auroc"] = lr_auroc + stats_row["lr_acc"] = lr_acc + + row_buf.append(stats_row) - return stats_row + return pd.DataFrame(row_buf) def evaluate(self): """Evaluate the reporter on all layers.""" devices = select_usable_devices( - self.cfg.num_gpus, min_memory=self.cfg.data.min_gpu_mem + self.cfg.num_gpus, min_memory=self.cfg.min_gpu_mem ) num_devices = len(devices) - func: Callable[[int], pd.Series] = partial( + func: Callable[[int], pd.DataFrame] = partial( self.evaluate_reporter, devices=devices, world_size=num_devices ) self.apply_to_layers(func=func, num_devices=num_devices) diff --git a/elk/extraction/balanced_sampler.py b/elk/extraction/balanced_sampler.py index 2ea4815e..4a287320 100644 --- a/elk/extraction/balanced_sampler.py +++ b/elk/extraction/balanced_sampler.py @@ -6,8 +6,8 @@ from datasets import Features, IterableDataset from torch.utils.data import IterableDataset as TorchIterableDataset -from ..math_util import stochastic_round_constrained from ..utils import infer_label_column +from ..utils.math_util import stochastic_round_constrained from ..utils.typing import assert_type diff --git a/elk/extraction/extraction.py b/elk/extraction/extraction.py index a9abd6a6..e022e5ce 100644 --- a/elk/extraction/extraction.py +++ 
b/elk/extraction/extraction.py @@ -4,7 +4,7 @@ from copy import copy from dataclasses import InitVar, dataclass from itertools import islice -from typing import Any, Iterable, Literal, Optional +from typing import Any, Iterable, Literal import torch from datasets import ( @@ -23,6 +23,7 @@ from transformers import AutoConfig, AutoTokenizer from transformers.modeling_outputs import Seq2SeqLMOutput +from ..promptsource import DatasetTemplates from ..utils import ( assert_type, convert_span, @@ -48,7 +49,6 @@ class Extract(Serializable): layer_stride: Shortcut for setting `layers` to `range(0, num_layers, stride)`. token_loc: The location of the token to extract hidden states from. Can be either "first", "last", or "mean". Defaults to "last". - min_gpu_mem: Minimum amount of free memory (in bytes) required to select a GPU. """ prompts: PromptConfig @@ -57,8 +57,6 @@ class Extract(Serializable): layers: tuple[int, ...] = () layer_stride: InitVar[int] = 1 token_loc: Literal["first", "last", "mean"] = "last" - min_gpu_mem: Optional[int] = None - num_gpus: int = -1 def __post_init__(self, layer_stride: int): if self.layers and layer_stride > 1: @@ -74,8 +72,16 @@ def __post_init__(self, layer_stride: int): ) self.layers = tuple(range(0, config.num_hidden_layers, layer_stride)) - def execute(self): - extract(cfg=self, num_gpus=self.num_gpus) + def explode(self) -> list["Extract"]: + """Explode this config into a list of configs, one for each layer.""" + copies = [] + + for prompt_cfg in self.prompts.explode(): + cfg = copy(self) + cfg.prompts = prompt_cfg + copies.append(cfg) + + return copies @torch.no_grad() @@ -94,8 +100,11 @@ def extract_hiddens( if rank != 0: logging.disable(logging.CRITICAL) + ds_names = cfg.prompts.datasets + assert len(ds_names) == 1, "Can only extract hiddens from one dataset at a time." 
+ prompt_ds = load_prompts( - *cfg.prompts.datasets, + ds_names[0], split_type=split_type, stream=cfg.prompts.stream, rank=rank, @@ -240,14 +249,19 @@ def _extraction_worker(**kwargs): yield from extract_hiddens(**{k: v[0] for k, v in kwargs.items()}) -def extract(cfg: "Extract", num_gpus: int = -1) -> DatasetDict: +def extract( + cfg: "Extract", num_gpus: int = -1, min_gpu_mem: int | None = None +) -> DatasetDict: """Extract hidden states from a model and return a `DatasetDict` containing them.""" def get_splits() -> SplitDict: available_splits = assert_type(SplitDict, info.splits) train_name, val_name = select_train_val_splits(available_splits) - print(f"Using '{train_name}' for training and '{val_name}' for validation") - + print( + # Cyan color for dataset name + f"\033[36m{info.builder_name}\033[0m: using '{train_name}' for training and" + f" '{val_name}' for validation" + ) limit_list = cfg.prompts.max_examples return SplitDict( @@ -263,11 +277,15 @@ def get_splits() -> SplitDict: ) model_cfg = AutoConfig.from_pretrained(cfg.model) - num_variants = cfg.prompts.num_variants ds_name, _, config_name = cfg.prompts.datasets[0].partition(" ") info = get_dataset_config_info(ds_name, config_name or None) + num_variants = cfg.prompts.num_variants + if num_variants < 0: + prompter = DatasetTemplates(ds_name, config_name) + num_variants = len(prompter.templates) + layer_cols = { f"hidden_{layer}": Array3D( dtype="int16", @@ -297,22 +315,18 @@ def get_splits() -> SplitDict: length=num_variants, ) - devices = select_usable_devices(num_gpus, min_memory=cfg.min_gpu_mem) - - # Prevent the GPU-related config options from invalidating the cache - _cfg = copy(cfg) - _cfg.min_gpu_mem = None - _cfg.num_gpus = -1 - + devices = select_usable_devices(num_gpus, min_memory=min_gpu_mem) builders = { split_name: _GeneratorBuilder( + builder_name=info.builder_name, + config_name=info.config_name, cache_dir=None, features=Features({**layer_cols, **other_cols}), generator=_extraction_worker, split_name=split_name, split_info=split_info, gen_kwargs=dict( - cfg=[_cfg] * len(devices), + cfg=[cfg] * len(devices), device=devices, rank=list(range(len(devices))), split_type=[split_name] * len(devices), diff --git a/elk/extraction/generator.py b/elk/extraction/generator.py index e3cad0e5..86e65e08 100644 --- a/elk/extraction/generator.py +++ b/elk/extraction/generator.py @@ -1,17 +1,22 @@ from copy import deepcopy from dataclasses import dataclass, field -from typing import Any, Callable, Dict, Optional - -import datasets -from datasets import Features +from typing import Any, Callable + +from datasets import ( + BuilderConfig, + DatasetInfo, + Features, + GeneratorBasedBuilder, + SplitInfo, +) from datasets.splits import NamedSplit @dataclass -class _GeneratorConfig(datasets.BuilderConfig): - generator: Optional[Callable] = None +class _GeneratorConfig(BuilderConfig): + generator: Callable | None = None gen_kwargs: dict[str, Any] = field(default_factory=dict) - features: Optional[datasets.Features] = None + features: Features | None = None def create_config_id( self, config_kwargs: dict, custom_features: Features | None @@ -37,28 +42,41 @@ class _SplitGenerator: """ name: str - split_info: datasets.SplitInfo - gen_kwargs: Dict = field(default_factory=dict) + split_info: SplitInfo + gen_kwargs: dict = field(default_factory=dict) def __post_init__(self): self.name = str(self.name) # Make sure we convert NamedSplits in strings NamedSplit(self.name) # check that it's a valid split name -class 
_GeneratorBuilder(datasets.GeneratorBasedBuilder): +class _GeneratorBuilder(GeneratorBasedBuilder): """Patched version of `datasets.Generator` allowing for splits besides `train`""" BUILDER_CONFIG_CLASS = _GeneratorConfig config: _GeneratorConfig - def __init__(self, split_name: str, split_info: datasets.SplitInfo, **kwargs): + def __init__( + self, + builder_name: str | None, + config_name: str | None, + split_name: str, + split_info: SplitInfo, + **kwargs, + ): self.split_name = split_name self.split_info = split_info super().__init__(**kwargs) + # Weirdly we need to set DatasetInfo.builder_name and DatasetInfo.config_name + # here, not in _info, because super().__init__ modifies them + self.info.builder_name = builder_name + self.info.config_name = config_name + def _info(self): - return datasets.DatasetInfo(features=self.config.features) + # Use the same builder and config name as the original builder + return DatasetInfo(features=self.config.features) def _split_generators(self, dl_manager): return [ diff --git a/elk/extraction/prompt_loading.py b/elk/extraction/prompt_loading.py index 2250a3c2..36b4d552 100644 --- a/elk/extraction/prompt_loading.py +++ b/elk/extraction/prompt_loading.py @@ -1,5 +1,7 @@ from collections import Counter +from copy import deepcopy from dataclasses import dataclass +from itertools import zip_longest from random import Random from typing import Any, Iterator, Literal, Optional @@ -26,9 +28,8 @@ class PromptConfig(Serializable): """ Args: - dataset: Space-delimited name of the HuggingFace dataset to use, e.g. + dataset: List of space-delimited names of the HuggingFace dataset to use, e.g. `"super_glue boolq"` or `"imdb"`. - balance: Whether to force class balance in the dataset using undersampling. data_dir: The directory to use for caching the dataset. Defaults to `~/.cache/huggingface/datasets`. label_column: The column containing the labels. 
By default, we infer this from @@ -47,9 +48,8 @@ class PromptConfig(Serializable): """ datasets: list[str] = field(positional=True) - balance: bool = False - data_dir: Optional[str] = None - label_column: Optional[str] = None + data_dirs: list[str] = field(default_factory=list) + label_columns: list[str] = field(default_factory=list) max_examples: list[int] = field(default_factory=lambda: [750, 250]) num_shots: int = 0 num_variants: int = -1 @@ -69,9 +69,41 @@ def __post_init__(self): if len(self.max_examples) == 1: self.max_examples *= 2 + # Broadcast the dataset name to all data_dirs and label_columns + if len(self.data_dirs) == 1: + self.data_dirs *= len(self.datasets) + elif self.data_dirs and len(self.data_dirs) != len(self.datasets): + raise ValueError( + "data_dirs should be a list of length 0, 1, or len(datasets)," + f" but got {len(self.data_dirs)}" + ) + + if len(self.label_columns) == 1: + self.label_columns *= len(self.datasets) + elif self.label_columns and len(self.label_columns) != len(self.datasets): + raise ValueError( + "label_columns should be a list of length 0, 1, or len(datasets)," + f" but got {len(self.label_columns)}" + ) + + def explode(self) -> list["PromptConfig"]: + """Explode the config into a list of configs, one for each dataset.""" + copies = [] + + for ds, data_dir, col in zip_longest( + self.datasets, self.data_dirs, self.label_columns + ): + copy = deepcopy(self) + copy.datasets = [ds] + copy.data_dirs = [data_dir] if data_dir else [] + copy.label_columns = [col] if col else [] + copies.append(copy) + + return copies + def load_prompts( - *dataset_strings: str, + ds_string: str, num_shots: int = 0, num_variants: int = -1, seed: int = 42, @@ -83,7 +115,7 @@ def load_prompts( """Load a dataset full of prompts generated from the specified datasets. Args: - dataset_strings: Space-delimited names of the HuggingFace datasets to use, + ds_string: Space-delimited name of the HuggingFace datasets to use, e.g. `"super_glue boolq"` or `"imdb"`. num_shots: The number of examples to use in few-shot prompts. If zero, prompts are zero-shot. @@ -96,101 +128,65 @@ def load_prompts( Returns: An iterable dataset of prompts. """ - prompters = [] - raw_datasets = [] - train_datasets = [] - rng = Random(seed) + ds_name, _, config_name = ds_string.partition(" ") + prompter = DatasetTemplates(ds_name, config_name) - # First load the datasets and prompters. We need to know the minimum number of - # templates for any dataset in order to make sure we don't run out of prompts. - for ds_string in dataset_strings: - ds_name, _, config_name = ds_string.partition(" ") - prompters.append(DatasetTemplates(ds_name, config_name)) + ds_dict = assert_type( + dict, load_dataset(ds_name, config_name or None, streaming=stream) + ) + train_name, val_name = select_train_val_splits(ds_dict) + split_name = val_name if split_type == "val" else train_name - ds_dict = assert_type( - dict, load_dataset(ds_name, config_name or None, streaming=stream) - ) - train_name, val_name = select_train_val_splits(ds_dict) - split_name = val_name if split_type == "val" else train_name - - # Note that when streaming we can only approximately shuffle the dataset - # using a buffer. Streaming shuffling is NOT an adequate shuffle for - # datasets like IMDB, which are sorted by label. - bad_streaming_datasets = ["imdb"] - assert not ( - stream and ds_name in bad_streaming_datasets - ), f"Streaming is not supported for {ds_name}." 
- split = ds_dict[split_name].shuffle(seed=seed) - train_ds = ds_dict[train_name].shuffle(seed=seed) - if not stream: - split = assert_type(Dataset, split) - split = split.to_iterable_dataset().cast(split.features) - - # only keep the datapoints relevant to the current process - if world_size > 1: - # This prints to stdout which is slightly annoying - split = split_dataset_by_node( - dataset=split, rank=rank, world_size=world_size - ) + ds = ds_dict[split_name].shuffle(seed=seed) + train_ds = ds_dict[train_name].shuffle(seed=seed) + if not stream: + ds = assert_type(Dataset, ds) + ds = ds.to_iterable_dataset().cast(ds.features) - raw_datasets.append(split) - train_datasets.append(train_ds) + # only keep the datapoints relevant to the current process + if world_size > 1: + # This prints to stdout which is slightly annoying + ds = split_dataset_by_node(dataset=ds, rank=rank, world_size=world_size) - min_num_templates = min(len(prompter.templates) for prompter in prompters) + num_templates = len(prompter.templates) num_variants = ( - min_num_templates - if num_variants == -1 - else min(num_variants, min_num_templates) + num_templates if num_variants == -1 else min(num_variants, num_templates) ) assert num_variants > 0 if rank == 0: print(f"Using {num_variants} variants of each prompt") - ds_iterators = [iter(ds) for ds in raw_datasets] - while True: # terminates when the first dataset runs out of examples - for ds_iterator, ds, train_ds, prompter in zip( - ds_iterators, raw_datasets, train_datasets, prompters - ): - label_column = infer_label_column(ds.features) - num_classes = infer_num_classes(ds.features[label_column]) - - # Remove everything except the label column - extra_cols = list(assert_type(Features, ds.features)) - extra_cols.remove(label_column) - - if label_column != "label": - ds = ds.rename_column(label_column, "label") - if num_shots > 0: - fewshot = FewShotSampler( - train_ds, # TODO: not iterator - num_shots=num_shots, - rng=rng, - ) - fewshot_iter = iter(fewshot) - else: - fewshot_iter = None - - try: - example = next(ds_iterator) - except StopIteration: - return - - example = _convert_to_prompts( - example, - label_column=label_column, - num_classes=num_classes, - num_variants=num_variants, - prompter=prompter, - rng=rng, - fewshot_iter=fewshot_iter, - ) - - # Add the builder and config name to the records directly to make - # sure we don't forget what dataset they came from. 
- example["builder_name"] = ds.info.builder_name - example["config_name"] = ds.info.config_name + label_column = infer_label_column(ds.features) + num_classes = infer_num_classes(ds.features[label_column]) + rng = Random(seed) - yield example + if num_shots > 0: + fewshot = FewShotSampler( + train_ds, # TODO: not iterator + num_shots=num_shots, + rng=rng, + ) + fewshot_iter = iter(fewshot) + else: + fewshot_iter = None + + # Remove everything except the label column + extra_cols = list(assert_type(Features, ds.features)) + extra_cols.remove(label_column) + + if label_column != "label": + ds = ds.rename_column(label_column, "label") + + for example in ds: + yield _convert_to_prompts( + example, + label_column=label_column, + num_classes=num_classes, + num_variants=num_variants, + prompter=prompter, + rng=rng, + fewshot_iter=fewshot_iter, + ) def _convert_to_prompts( diff --git a/elk/files.py b/elk/files.py index a9da71e9..4435dee1 100644 --- a/elk/files.py +++ b/elk/files.py @@ -5,9 +5,6 @@ import random from pathlib import Path -import yaml -from simple_parsing import Serializable - def elk_reporter_dir() -> Path: """Return the directory where reporter checkpoints and logs are stored.""" @@ -41,28 +38,3 @@ def memorably_named_dir(parent: Path): out_dir = parent / sub_dir out_dir.mkdir(parents=True, exist_ok=True) return out_dir - - -def save_config(cfg: Serializable, out_dir: Path): - """Save the config to a file""" - - path = out_dir / "cfg.yaml" - with open(path, "w") as f: - cfg.dump_yaml(f) - - return path - - -def save_meta(dataset, out_dir: Path): - """Save the meta data to a file""" - - meta = { - "dataset_fingerprints": { - split: dataset[split]._fingerprint for split in dataset.keys() - } - } - path = out_dir / "metadata.yaml" - with open(path, "w") as meta_f: - yaml.dump(meta, meta_f) - - return path diff --git a/elk/logging.py b/elk/logging.py deleted file mode 100644 index 706055bd..00000000 --- a/elk/logging.py +++ /dev/null @@ -1,47 +0,0 @@ -import logging - -from .utils import select_train_val_splits - - -def save_debug_log(ds, out_dir): - """ - Save a debug log to the output directory. This is useful for debugging - training issues. - """ - - logging.basicConfig( - level=logging.DEBUG, - format="%(asctime)s %(levelname)s:\n%(message)s", - filename=out_dir / "debug.log", - filemode="w", - ) - - train_split, val_split = select_train_val_splits(ds) - text_inputs = ds[val_split][0]["text_inputs"] - template_ids = ds[val_split][0]["variant_ids"] - label = ds[val_split][0]["label"] - - # log the train size and val size - logging.info(f"Train size: {len(ds[train_split])}") - logging.info(f"Val size: {len(ds[val_split])}") - - templates_text = f"{len(text_inputs)} templates used:\n" - trailing_whitespace = False - for (text0, text1), id in zip(text_inputs, template_ids): - templates_text += ( - f'***---TEMPLATE "{id}"---***\n' - f"{'false' if label else 'true'}:\n" - f'"""{text0}"""\n' - f"{'true' if label else 'false'}:\n" - f'"""{text1}"""\n\n\n' - ) - if text0[-1].isspace() or text1[-1].isspace(): - trailing_whitespace = True - if trailing_whitespace: - logging.warning( - "Some inputs to the model have trailing whitespace! " - "Check that the jinja templates are not adding " - "trailing whitespace. If `token_loc` is 'last', this " - "will extract hidden states from the whitespace token." 
- ) - logging.info(templates_text) diff --git a/elk/run.py b/elk/run.py index af75c597..36ea8ca8 100644 --- a/elk/run.py +++ b/elk/run.py @@ -3,27 +3,27 @@ from abc import ABC from dataclasses import dataclass, field from pathlib import Path -from typing import ( - TYPE_CHECKING, - Callable, - Optional, - Union, -) +from typing import TYPE_CHECKING, Callable, Literal, Union import numpy as np import pandas as pd import torch import torch.multiprocessing as mp +import yaml from datasets import DatasetDict from torch import Tensor from tqdm import tqdm +from .debug_logging import save_debug_log from .extraction import extract -from .files import elk_reporter_dir, memorably_named_dir, save_config, save_meta -from .logging import save_debug_log -from .training.preprocessing import normalize -from .utils import assert_type, int16_to_float32 -from .utils.data_utils import get_layers, select_train_val_splits +from .files import elk_reporter_dir, memorably_named_dir +from .utils import ( + assert_type, + get_dataset_name, + get_layers, + int16_to_float32, + select_train_val_splits, +) if TYPE_CHECKING: from .evaluation.evaluate import Eval @@ -33,12 +33,14 @@ @dataclass class Run(ABC): cfg: Union["Elicit", "Eval"] - out_dir: Optional[Path] = None - dataset: DatasetDict = field(init=False) + out_dir: Path | None = None + datasets: list[DatasetDict] = field(init=False) def __post_init__(self): - # Extract the hidden states first if necessary - self.dataset = extract(self.cfg.data, num_gpus=self.cfg.num_gpus) + self.datasets = [ + extract(cfg, num_gpus=self.cfg.num_gpus, min_gpu_mem=self.cfg.min_gpu_mem) + for cfg in self.cfg.data.explode() + ] if self.out_dir is None: # Save in a memorably-named directory inside of @@ -52,8 +54,21 @@ def __post_init__(self): print(f"Output directory at \033[1m{self.out_dir}\033[0m") self.out_dir.mkdir(parents=True, exist_ok=True) - save_config(self.cfg, self.out_dir) - save_meta(self.dataset, self.out_dir) + path = self.out_dir / "cfg.yaml" + with open(path, "w") as f: + self.cfg.dump_yaml(f) + + path = self.out_dir / "fingerprints.yaml" + with open(path, "w") as meta_f: + yaml.dump( + { + get_dataset_name(ds): { + split: ds[split]._fingerprint for split in ds.keys() + } + for ds in self.datasets + }, + meta_f, + ) def make_reproducible(self, seed: int): """Make the run reproducible by setting the random seed.""" @@ -70,55 +85,42 @@ def get_device(self, devices, world_size: int) -> str: return device def prepare_data( - self, - device: str, - layer: int, - ) -> tuple: - """Prepare the data for training and validation.""" - - with self.dataset.formatted_as("torch", device=device, dtype=torch.int16): - train_split, val_split = select_train_val_splits(self.dataset) - train, val = self.dataset[train_split], self.dataset[val_split] - - train_labels = assert_type(Tensor, train["label"]) - val_labels = assert_type(Tensor, val["label"]) - - # Note: currently we're just upcasting to float32 - # so we don't have to deal with - # grad scaling (which isn't supported for LBFGS), - # while the hidden states are - # saved in float16 to save disk space. - # In the future we could try to use mixed - # precision training in at least some cases. 
- train_h, val_h = normalize( - int16_to_float32(assert_type(torch.Tensor, train[f"hidden_{layer}"])), - int16_to_float32(assert_type(torch.Tensor, val[f"hidden_{layer}"])), - method=self.cfg.normalization, - ) + self, device: str, layer: int, split_type: Literal["train", "val"] + ) -> dict[str, tuple[Tensor, Tensor, Tensor, np.ndarray | None]]: + """Prepare data for the specified layer and split type.""" + out = {} + + for ds in self.datasets: + train_name, val_name = select_train_val_splits(ds) + key = train_name if split_type == "train" else val_name - x0, x1 = train_h.unbind(dim=-2) - val_x0, val_x1 = val_h.unbind(dim=-2) + split = ds[key].with_format("torch", device=device, dtype=torch.int16) + labels = assert_type(Tensor, split["label"]) + val_h = int16_to_float32(assert_type(Tensor, split[f"hidden_{layer}"])) + x0, x1 = val_h.unbind(dim=-2) - with self.dataset.formatted_as("numpy"): - has_preds = "model_preds" in val.features - val_lm_preds = val["model_preds"] if has_preds else None + with split.formatted_as("numpy"): + has_preds = "model_preds" in split.features + lm_preds = split["model_preds"] if has_preds else None - return x0, x1, val_x0, val_x1, train_labels, val_labels, val_lm_preds + ds_name = get_dataset_name(ds) + out[ds_name] = (x0, x1, labels, lm_preds) + + return out def concatenate(self, layers): """Concatenate hidden states from a previous layer.""" for layer in range(self.cfg.concatenated_layer_offset, len(layers)): - layers[layer] = layers[layer] + [ - layers[layer][0] - self.cfg.concatenated_layer_offset - ] + layers[layer] += [layers[layer][0] - self.cfg.concatenated_layer_offset] + return layers def apply_to_layers( self, - func: Callable[[int], pd.Series], + func: Callable[[int], pd.DataFrame], num_devices: int, ): - """Apply a function to each layer of the dataset in parallel + """Apply a function to each layer of the datasets in parallel and writes the results to a CSV file. 
Args: @@ -128,7 +130,8 @@ def apply_to_layers( """ self.out_dir = assert_type(Path, self.out_dir) - layers: list[int] = get_layers(self.dataset) + layers, *rest = [get_layers(ds) for ds in self.datasets] + assert all(x == layers for x in rest), "All datasets must have the same layers" if self.cfg.concatenated_layer_offset > 0: layers = self.concatenate(layers) @@ -137,14 +140,15 @@ def apply_to_layers( ctx = mp.get_context("spawn") with ctx.Pool(num_devices) as pool, open(self.out_dir / "eval.csv", "w") as f: mapper = pool.imap_unordered if num_devices > 1 else map - row_buf = [] + df_buf = [] try: - for row in tqdm(mapper(func, layers), total=len(layers)): - row_buf.append(row) + for df in tqdm(mapper(func, layers), total=len(layers)): + df_buf.append(df) finally: # Make sure the CSV is written even if we crash or get interrupted - df = pd.DataFrame(row_buf).sort_values(by="layer") - df.to_csv(f, index=False) + if df_buf: + df = pd.concat(df_buf).sort_values(by="layer") + df.to_csv(f, index=False) if self.cfg.debug: - save_debug_log(self.dataset, self.out_dir) + save_debug_log(self.datasets, self.out_dir) diff --git a/elk/training/__init__.py b/elk/training/__init__.py index 41264179..6428c3c5 100644 --- a/elk/training/__init__.py +++ b/elk/training/__init__.py @@ -1,13 +1,15 @@ from .ccs_reporter import CcsReporter, CcsReporterConfig from .eigen_reporter import EigenReporter, EigenReporterConfig +from .normalizer import Normalizer from .reporter import OptimConfig, Reporter, ReporterConfig __all__ = [ - "Reporter", - "ReporterConfig", "CcsReporter", "CcsReporterConfig", "EigenReporter", "EigenReporterConfig", + "Normalizer", "OptimConfig", + "Reporter", + "ReporterConfig", ] diff --git a/elk/training/baseline.py b/elk/training/baseline.py deleted file mode 100644 index 2c9542a6..00000000 --- a/elk/training/baseline.py +++ /dev/null @@ -1,62 +0,0 @@ -import pickle -from pathlib import Path -from typing import Tuple - -import torch -from sklearn.metrics import accuracy_score, roc_auc_score -from torch import Tensor - -from ..utils.typing import assert_type -from .classifier import Classifier - -# TODO: Create class for baseline? 
- - -def evaluate_baseline( - lr_model: Classifier, val_x0: Tensor, val_x1: Tensor, val_labels: Tensor -) -> Tuple[float, float]: - X = torch.cat([val_x0, val_x1]) - d = X.shape[-1] - X_val = X.view(-1, d) - with torch.no_grad(): - lr_preds = lr_model(X_val).sigmoid().cpu() - - val_labels_aug = ( - torch.cat([val_labels, 1 - val_labels]).repeat_interleave(val_x0.shape[1]) - ).cpu() - - lr_acc = accuracy_score(val_labels_aug, lr_preds > 0.5) - lr_auroc = roc_auc_score(val_labels_aug, lr_preds) - - return assert_type(float, lr_auroc), assert_type(float, lr_acc) - - -def train_baseline( - x0: Tensor, - x1: Tensor, - train_labels: Tensor, - device: str, -) -> Classifier: - # repeat_interleave makes `num_variants` copies of each label, all within a - # single dimension of size `num_variants * 2 * n`, such that the labels align - # with X.view(-1, X.shape[-1]) - train_labels_aug = torch.cat([train_labels, 1 - train_labels]).repeat_interleave( - x0.shape[1] - ) - - X = torch.cat([x0, x1]).squeeze() - d = X.shape[-1] - lr_model = Classifier(d, device=device) - lr_model.fit_cv(X.view(-1, d), train_labels_aug) - - return lr_model - - -def save_baseline(lr_dir: Path, layer: int, lr_model: Classifier): - with open(lr_dir / f"layer_{layer}.pt", "wb") as file: - pickle.dump(lr_model, file) - - -def load_baseline(lr_dir: Path, layer: int) -> Classifier: - with open(lr_dir / f"layer_{layer}.pt", "rb") as file: - return pickle.load(file) diff --git a/elk/training/ccs_reporter.py b/elk/training/ccs_reporter.py index 258f2bc4..b3043766 100644 --- a/elk/training/ccs_reporter.py +++ b/elk/training/ccs_reporter.py @@ -7,12 +7,15 @@ import torch import torch.nn as nn +from sklearn.metrics import roc_auc_score from torch import Tensor from torch.nn.functional import binary_cross_entropy as bce from ..parsing import parse_loss from ..utils.typing import assert_type +from .classifier import Classifier from .losses import LOSSES +from .normalizer import Normalizer from .reporter import Reporter, ReporterConfig @@ -34,6 +37,7 @@ class CcsReporterConfig(ReporterConfig): Example: --loss 1.0*consistency_squared 0.5*prompt_var corresponds to the loss function 1.0*consistency_squared + 0.5*prompt_var. Defaults to "ccs_prompt_var". + normalization: The kind of normalization to apply to the hidden states. num_layers: The number of layers in the MLP. Defaults to 1. pre_ln: Whether to include a LayerNorm module before the first linear layer. Defaults to False. @@ -85,13 +89,17 @@ def __init__( self, in_features: int, cfg: CcsReporterConfig, - device: Optional[str] = None, - dtype: Optional[torch.dtype] = None, + device: str | torch.device | None = None, + dtype: torch.dtype | None = None, ): - super().__init__(in_features, cfg, device=device, dtype=dtype) + super().__init__() + self.config = cfg hidden_size = cfg.hidden_size or 4 * in_features // 3 + self.neg_norm = Normalizer((in_features,), device=device, dtype=dtype) + self.pos_norm = Normalizer((in_features,), device=device, dtype=dtype) + self.probe = nn.Sequential( nn.Linear( in_features, @@ -120,6 +128,56 @@ def __init__( ) ) + def check_separability( + self, + train_pair: tuple[Tensor, Tensor], + val_pair: tuple[Tensor, Tensor], + ) -> float: + """Measure how linearly separable the pseudo-labels are for a contrast pair. + + Args: + train_pair: A tuple of tensors, (x0, x1), where x0 and x1 are the + contrastive representations. Used for training the classifier. + val_pair: A tuple of tensors, (x0, x1), where x0 and x1 are the + contrastive representations. 
Used for evaluating the classifier. + + Returns: + The AUROC of a linear classifier fit on the pseudo-labels. + """ + _x0, _x1 = train_pair + _val_x0, _val_x1 = val_pair + + x0, x1 = self.neg_norm(_x0), self.pos_norm(_x1) + val_x0, val_x1 = self.neg_norm(_val_x0), self.pos_norm(_val_x1) + + pseudo_clf = Classifier(x0.shape[-1], device=x0.device) # type: ignore + pseudo_train_labels = torch.cat( + [ + x0.new_zeros(x0.shape[0]), + x0.new_ones(x0.shape[0]), + ] + ).repeat_interleave( + x0.shape[1] + ) # make num_variants copies of each pseudo-label + pseudo_val_labels = torch.cat( + [ + val_x0.new_zeros(val_x0.shape[0]), + val_x0.new_ones(val_x0.shape[0]), + ] + ).repeat_interleave(val_x0.shape[1]) + + pseudo_clf.fit( + # b v d -> (b v) d + torch.cat([x0, x1]).flatten(0, 1), + pseudo_train_labels, + ) + with torch.no_grad(): + pseudo_preds = pseudo_clf( + # b v d -> (b v) d + torch.cat([val_x0, val_x1]).flatten(0, 1) + ) + return float(roc_auc_score(pseudo_val_labels.cpu(), pseudo_preds.cpu())) + def unsupervised_loss(self, logit0: Tensor, logit1: Tensor) -> Tensor: loss = sum( LOSSES[name](logit0, logit1, coef) @@ -164,7 +222,9 @@ def forward(self, x: Tensor) -> Tensor: """Return the raw score output of the probe on `x`.""" return self.probe(x).squeeze(-1) - def predict(self, x_pos: Tensor, x_neg: Tensor) -> Tensor: + def predict(self, x_neg: Tensor, x_pos: Tensor) -> Tensor: + x_neg = self.neg_norm(x_neg) + x_pos = self.pos_norm(x_pos) return 0.5 * (self(x_pos).sigmoid() + (1 - self(x_neg).sigmoid())) def loss( @@ -213,14 +273,14 @@ def loss( def fit( self, - x_pos: Tensor, x_neg: Tensor, + x_pos: Tensor, labels: Optional[Tensor] = None, ) -> float: - """Fit the probe to the contrast pair (x0, x1). + """Fit the probe to the contrast pair (neg, pos). Args: - contrast_pair: A tuple of tensors, (x0, x1), where x0 and x1 are the + contrast_pair: A tuple of tensors, (neg, pos), where x0 and x1 are the contrastive representations. labels: The labels of the contrast pair. Defaults to None. @@ -231,8 +291,9 @@ def fit( ValueError: If `optimizer` is not "adam" or "lbfgs". RuntimeError: If the best loss is not finite. """ - # TODO: Implement normalization here to fix issue #96 - # self.update(x_pos, x_neg) + # Fit normalizers + self.pos_norm.fit(x_pos) + self.neg_norm.fit(x_neg) # Record the best acc, loss, and params found so far best_loss = torch.inf @@ -266,8 +327,8 @@ def fit( def train_loop_adam( self, - x_pos: Tensor, x_neg: Tensor, + x_pos: Tensor, labels: Optional[Tensor] = None, ) -> float: """Adam train loop, returning the final loss. Modifies params in-place.""" @@ -288,8 +349,8 @@ def train_loop_adam( def train_loop_lbfgs( self, - x_pos: Tensor, x_neg: Tensor, + x_pos: Tensor, labels: Optional[Tensor] = None, ) -> float: """LBFGS train loop, returning the final loss. 
Modifies params in-place.""" diff --git a/elk/training/eigen_reporter.py b/elk/training/eigen_reporter.py index 821891cc..5449dd06 100644 --- a/elk/training/eigen_reporter.py +++ b/elk/training/eigen_reporter.py @@ -2,13 +2,12 @@ from dataclasses import dataclass from typing import Optional -from warnings import warn import torch from torch import Tensor, nn, optim -from ..math_util import cov_mean_fused -from ..truncated_eigh import ConvergenceError, truncated_eigh +from ..truncated_eigh import truncated_eigh +from ..utils.math_util import cov_mean_fused from .reporter import Reporter, ReporterConfig @@ -59,6 +58,8 @@ class EigenReporter(Reporter): intracluster_cov: Tensor # invariance contrastive_xcov_M2: Tensor # negative covariance n: Tensor + neg_mean: Tensor + pos_mean: Tensor weight: Tensor def __init__( @@ -68,12 +69,22 @@ def __init__( device: Optional[str] = None, dtype: Optional[torch.dtype] = None, ): - super().__init__(in_features, cfg, device=device, dtype=dtype) + super().__init__() + self.config = cfg # Learnable Platt scaling parameters self.bias = nn.Parameter(torch.zeros(cfg.num_heads, device=device, dtype=dtype)) self.scale = nn.Parameter(torch.ones(cfg.num_heads, device=device, dtype=dtype)) + # Running statistics + self.register_buffer("n", torch.zeros((), device=device, dtype=torch.long)) + self.register_buffer( + "neg_mean", torch.zeros(in_features, device=device, dtype=dtype) + ) + self.register_buffer( + "pos_mean", torch.zeros(in_features, device=device, dtype=dtype) + ) + self.register_buffer( "contrastive_xcov_M2", torch.zeros(in_features, in_features, device=device, dtype=dtype), @@ -86,6 +97,8 @@ def __init__( "intracluster_cov", torch.zeros(in_features, in_features, device=device, dtype=dtype), ) + + # Reporter weights self.register_buffer( "weight", torch.zeros(cfg.num_heads, in_features, device=device, dtype=dtype), @@ -170,7 +183,7 @@ def update(self, x_pos: Tensor, x_neg: Tensor) -> None: self.contrastive_xcov_M2.addmm_(neg_delta.mT, pos_delta2) self.contrastive_xcov_M2.addmm_(pos_delta.mT, neg_delta2) - def fit_streaming(self) -> float: + def fit_streaming(self, truncated: bool = False) -> float: """Fit the probe using the current streaming statistics.""" A = ( self.config.var_weight * self.intercluster_cov @@ -178,14 +191,9 @@ def fit_streaming(self) -> float: - self.config.neg_cov_weight * self.contrastive_xcov ) - try: + if truncated: L, Q = truncated_eigh(A, k=self.config.num_heads) - except (ConvergenceError, RuntimeError): - warn( - "Truncated eigendecomposition failed to converge. Falling back on " - "PyTorch's dense eigensolver." 
- ) - + else: L, Q = torch.linalg.eigh(A) L, Q = L[-self.config.num_heads :], Q[:, -self.config.num_heads :] diff --git a/elk/training/normalizer.py b/elk/training/normalizer.py new file mode 100644 index 00000000..7cb04b97 --- /dev/null +++ b/elk/training/normalizer.py @@ -0,0 +1,63 @@ +from typing import Literal + +import torch +from torch import Tensor, nn + + +class Normalizer(nn.Module): + """Basically `BatchNorm` with a less annoying default axis ordering.""" + + mean: Tensor + std: Tensor | None + + def __init__( + self, + normalized_shape: tuple[int, ...], + *, + device: str | torch.device | None = None, + dtype: torch.dtype | None = None, + eps: float = 1e-5, + mode: Literal["none", "meanonly", "full"] = "full", + ): + super().__init__() + + self.eps = eps + self.mode = mode + self.normalized_shape = normalized_shape + + self.register_buffer( + "mean", torch.zeros(*normalized_shape, device=device, dtype=dtype) + ) + self.register_buffer( + "std", torch.ones_like(self.mean) if mode == "full" else None + ) + + def forward(self, x: Tensor) -> Tensor: + """Normalize `x` using the stored mean and standard deviation.""" + if self.mode == "none": + return x + elif self.std is None: + return x - self.mean + else: + return (x - self.mean) / self.std + + def fit(self, x: Tensor) -> None: + """Update the stored mean and standard deviation.""" + + # Check shape + num_dims = len(self.normalized_shape) + if x.shape[-num_dims:] != self.normalized_shape: + raise ValueError( + f"Expected trailing sizes {self.normalized_shape} but got " + f"{x.shape[-num_dims:]}" + ) + + if self.mode == "none": + return + + dims = [i for i in range(x.ndim - num_dims)] + if self.std is None: + torch.mean(x, dim=dims, out=self.mean) + else: + variance, self.mean = torch.var_mean(x, dim=dims) + torch.sqrt(variance + self.eps, out=self.std) diff --git a/elk/training/preprocessing.py b/elk/training/preprocessing.py deleted file mode 100644 index 6081dcbb..00000000 --- a/elk/training/preprocessing.py +++ /dev/null @@ -1,55 +0,0 @@ -"""Preprocessing functions for training.""" - -from typing import Literal - -import torch - - -def normalize( - train_hiddens: torch.Tensor, - val_hiddens: torch.Tensor, - method: Literal["legacy", "none", "elementwise", "meanonly"] = "legacy", -) -> tuple[torch.Tensor, torch.Tensor]: - """Normalize the hidden states. - - Normalize the hidden states with the specified method. - The "legacy" method is the same as the original ELK implementation. - The "elementwise" method normalizes each element. - The "meanonly" method normalizes the mean. - - Args: - train_hiddens: The hidden states for the training set. - val_hiddens: The hidden states for the validation set. - method: The normalization method to use. - - Returns: - tuple containing the training and validation hidden states. 
- """ - if method == "none": - return train_hiddens, val_hiddens - elif method == "legacy": - master = torch.cat([train_hiddens, val_hiddens], dim=0).float() - means = master.mean(dim=0) - - train_hiddens -= means - val_hiddens -= means - - scale = master.shape[-1] ** 0.5 / master.norm(dim=-1).mean() - train_hiddens *= scale - val_hiddens *= scale - else: - means = train_hiddens.float().mean(dim=0) - train_hiddens -= means - val_hiddens -= means - - if method == "elementwise": - scale = 1 / train_hiddens.norm(dim=0, keepdim=True) - elif method == "meanonly": - scale = 1 - else: - raise NotImplementedError(f"Scale method '{method}' is not supported.") - - train_hiddens *= scale - val_hiddens *= scale - - return train_hiddens, val_hiddens diff --git a/elk/training/reporter.py b/elk/training/reporter.py index 9cdfb145..ea607650 100644 --- a/elk/training/reporter.py +++ b/elk/training/reporter.py @@ -12,7 +12,6 @@ from torch import Tensor from ..calibration import CalibrationError -from .classifier import Classifier class EvalResult(NamedTuple): @@ -66,89 +65,11 @@ class Reporter(nn.Module, ABC): """ n: Tensor - neg_mean: Tensor - pos_mean: Tensor - - def __init__( - self, - in_features: int, - cfg: ReporterConfig, - device: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - ): - super().__init__() - - self.config = cfg - self.register_buffer("n", torch.zeros((), device=device, dtype=torch.long)) - self.register_buffer( - "neg_mean", torch.zeros(in_features, device=device, dtype=dtype) - ) - self.register_buffer( - "pos_mean", torch.zeros(in_features, device=device, dtype=dtype) - ) - - @classmethod - def check_separability( - cls, - train_pair: tuple[Tensor, Tensor], - val_pair: tuple[Tensor, Tensor], - ) -> float: - """Measure how linearly separable the pseudo-labels are for a contrast pair. - - Args: - train_pair: A tuple of tensors, (x0, x1), where x0 and x1 are the - contrastive representations. Used for training the classifier. - val_pair: A tuple of tensors, (x0, x1), where x0 and x1 are the - contrastive representations. Used for evaluating the classifier. - - Returns: - The AUROC of a linear classifier fit on the pseudo-labels. 
- """ - x0, x1 = train_pair - val_x0, val_x1 = val_pair - - pseudo_clf = Classifier(x0.shape[-1], device=x0.device) # type: ignore - pseudo_train_labels = torch.cat( - [ - x0.new_zeros(x0.shape[0]), - x0.new_ones(x0.shape[0]), - ] - ).repeat_interleave( - x0.shape[1] - ) # make num_variants copies of each pseudo-label - pseudo_val_labels = torch.cat( - [ - val_x0.new_zeros(val_x0.shape[0]), - val_x0.new_ones(val_x0.shape[0]), - ] - ).repeat_interleave(val_x0.shape[1]) - - pseudo_clf.fit( - # b v d -> (b v) d - torch.cat([x0, x1]).flatten(0, 1), - pseudo_train_labels, - ) - with torch.no_grad(): - pseudo_preds = pseudo_clf( - # b v d -> (b v) d - torch.cat([val_x0, val_x1]).flatten(0, 1) - ) - return float(roc_auc_score(pseudo_val_labels.cpu(), pseudo_preds.cpu())) + config: ReporterConfig def reset_parameters(self): """Reset the parameters of the probe.""" - @torch.no_grad() - def update(self, x_pos: Tensor, x_neg: Tensor) -> None: - """Update the running mean of the positive and negative examples.""" - - x_pos, x_neg = x_pos.flatten(0, -2), x_neg.flatten(0, -2) - self.n += x_pos.shape[0] - - # Update the running means - self.neg_mean += (x_neg.sum(dim=0) - self.neg_mean) / self.n - self.pos_mean += (x_pos.sum(dim=0) - self.pos_mean) / self.n - # TODO: These methods will do something fancier in the future @classmethod def load(cls, path: Path | str): diff --git a/elk/training/supervised.py b/elk/training/supervised.py new file mode 100644 index 00000000..fac7152c --- /dev/null +++ b/elk/training/supervised.py @@ -0,0 +1,50 @@ +import torch +from einops import rearrange, repeat +from sklearn.metrics import accuracy_score, roc_auc_score +from torch import Tensor + +from ..utils import assert_type +from .classifier import Classifier + + +def evaluate_supervised( + lr_model: Classifier, val_x0: Tensor, val_x1: Tensor, val_labels: Tensor +) -> tuple[float, float]: + X = torch.cat([val_x0, val_x1]) + d = X.shape[-1] + X_val = X.view(-1, d) + with torch.no_grad(): + lr_preds = lr_model(X_val).sigmoid().cpu() + + val_labels_aug = ( + torch.cat([val_labels, 1 - val_labels]).repeat_interleave(val_x0.shape[1]) + ).cpu() + + lr_acc = accuracy_score(val_labels_aug, lr_preds > 0.5) + lr_auroc = roc_auc_score(val_labels_aug, lr_preds) + + return assert_type(float, lr_auroc), assert_type(float, lr_acc) + + +def train_supervised(data: dict[str, tuple], device: str, cv: bool) -> Classifier: + Xs, train_labels = [], [] + + for x0, x1, labels, _ in data.values(): + (_, v, _) = x0.shape + x0 = rearrange(x0, "n v d -> (n v) d") + x1 = rearrange(x1, "n v d -> (n v) d") + + labels = repeat(labels, "n -> (n v)", v=v) + labels = torch.cat([labels, 1 - labels]) + + Xs.append(torch.cat([x0, x1]).squeeze()) + train_labels.append(labels) + + X, train_labels = torch.cat(Xs), torch.cat(train_labels) + lr_model = Classifier(X.shape[-1], device=device) + if cv: + lr_model.fit_cv(X, train_labels) + else: + lr_model.fit(X, train_labels) + + return lr_model diff --git a/elk/training/train.py b/elk/training/train.py index 9be8c589..403f8ce9 100644 --- a/elk/training/train.py +++ b/elk/training/train.py @@ -1,25 +1,23 @@ """Main training loop.""" -import warnings from dataclasses import dataclass from functools import partial from pathlib import Path -from typing import Callable, Literal, Optional +from typing import Callable, Literal import pandas as pd import torch from simple_parsing import Serializable, field, subgroups from sklearn.metrics import accuracy_score, roc_auc_score -from torch import Tensor from 
..extraction.extraction import Extract from ..run import Run -from ..training.baseline import evaluate_baseline, save_baseline, train_baseline +from ..training.supervised import evaluate_supervised, train_supervised from ..utils import select_usable_devices from ..utils.typing import assert_type from .ccs_reporter import CcsReporter, CcsReporterConfig from .eigen_reporter import EigenReporter, EigenReporterConfig -from .reporter import OptimConfig, Reporter, ReporterConfig +from .reporter import OptimConfig, ReporterConfig @dataclass @@ -34,8 +32,9 @@ class Elicit(Serializable): "use all available GPUs". normalization: The normalization method to use. Defaults to "meanonly". See `elk.training.preprocessing.normalize()` for details. - skip_baseline: Whether to skip training the baseline classifier. Defaults to - False. + supervised: Whether to train a supervised classifier, and if so, whether to + use cross-validation. Defaults to "single", which means to train a single + classifier on the training data. "cv" means to use cross-validation. debug: When in debug mode, a useful log file is saved to the memorably-named output directory. Defaults to False. """ @@ -46,13 +45,12 @@ class Elicit(Serializable): ) optim: OptimConfig = field(default_factory=OptimConfig) - num_gpus: int = -1 - normalization: Literal["legacy", "none", "elementwise", "meanonly"] = "meanonly" - skip_baseline: bool = False concatenated_layer_offset: int = 0 - # if nonzero, appends the hidden states of layer concatenated_layer_offset before debug: bool = False - out_dir: Optional[Path] = None + min_gpu_mem: int | None = None + num_gpus: int = -1 + out_dir: Path | None = None + supervised: Literal["none", "single", "cv"] = "single" def execute(self): train_run = Train(cfg=self, out_dir=self.out_dir) @@ -78,88 +76,102 @@ def train_reporter( layer: int, devices: list[str], world_size: int = 1, - ) -> pd.Series: + ) -> pd.DataFrame: """Train a single reporter on a single layer.""" self.make_reproducible(seed=self.cfg.net.seed + layer) - device = self.get_device(devices, world_size) - x0, x1, val_x0, val_x1, train_gt, val_gt, val_lm_preds = self.prepare_data( - device, layer - ) - pseudo_auroc = self.get_pseudo_auroc(layer, x0, x1, val_x0, val_x1) - - if isinstance(self.cfg.net, CcsReporterConfig): - reporter = CcsReporter(x0.shape[-1], self.cfg.net, device=device) - elif isinstance(self.cfg.net, EigenReporterConfig): - reporter = EigenReporter(x0.shape[-1], self.cfg.net, device=device) - else: - raise ValueError(f"Unknown reporter config type: {type(self.cfg.net)}") - - train_loss = reporter.fit(x0, x1, train_gt) - val_result = reporter.score( - val_gt, - val_x0, - val_x1, - ) + train_dict = self.prepare_data(device, layer, "train") + val_dict = self.prepare_data(device, layer, "val") + # Can't figure out a way to make this line less ugly + hidden_size = next(iter(train_dict.values()))[0].shape[-1] reporter_dir, lr_dir = self.create_models_dir(assert_type(Path, self.out_dir)) - if val_lm_preds is not None: - val_gt_cpu = val_gt.repeat_interleave(val_lm_preds.shape[1]).float().cpu() - val_lm_auroc = float(roc_auc_score(val_gt_cpu, val_lm_preds.flatten())) - val_lm_acc = float(accuracy_score(val_gt_cpu, val_lm_preds.flatten() > 0.5)) - else: - val_lm_auroc = None - val_lm_acc = None - - row = pd.Series( - { - "layer": layer, - "pseudo_auroc": pseudo_auroc, - "train_loss": train_loss, - **val_result._asdict(), - "lm_auroc": val_lm_auroc, - "lm_acc": val_lm_acc, - } - ) - if not self.cfg.skip_baseline: - lr_model = 
train_baseline(x0, x1, train_gt, device=device) + if isinstance(self.cfg.net, CcsReporterConfig): + assert len(train_dict) == 1, "CCS only supports single-task training" - lr_auroc, lr_acc = evaluate_baseline(lr_model, val_x0, val_x1, val_gt) + reporter = CcsReporter(hidden_size, self.cfg.net, device=device) + (x0, x1, labels, _) = next(iter(train_dict.values())) + train_loss = reporter.fit(x0, x1, labels) - row["lr_auroc"] = lr_auroc - row["lr_acc"] = lr_acc - save_baseline(lr_dir, layer, lr_model) + (val_x0, val_x1, val_gt, _) = next(iter(val_dict.values())) + pseudo_auroc = reporter.check_separability( + train_pair=(x0, x1), val_pair=(val_x0, val_x1) + ) + elif isinstance(self.cfg.net, EigenReporterConfig): + # To enable training on multiple tasks with different numbers of variants, + # we update the statistics in a streaming fashion and then fit + reporter = EigenReporter(hidden_size, self.cfg.net, device=device) + for ds_name, (x0, x1, labels, _) in train_dict.items(): + reporter.update(x0, x1) + + pseudo_auroc = None + train_loss = reporter.fit_streaming() + else: + raise ValueError(f"Unknown reporter config type: {type(self.cfg.net)}") + + # Save reporter checkpoint to disk with open(reporter_dir / f"layer_{layer}.pt", "wb") as file: torch.save(reporter, file) - return row - - def get_pseudo_auroc( - self, layer: int, x0: Tensor, x1: Tensor, val_x0: Tensor, val_x1: Tensor - ): - """Check the separability of the pseudo-labels at a given layer.""" + # Fit supervised logistic regression model + if self.cfg.supervised != "none": + lr_model = train_supervised( + train_dict, device=device, cv=self.cfg.supervised == "cv" + ) + with open(lr_dir / f"layer_{layer}.pt", "wb") as file: + torch.save(lr_model, file) + else: + lr_model = None + + row_buf = [] + for ds_name, (val_x0, val_x1, val_gt, val_lm_preds) in val_dict.items(): + val_result = reporter.score( + val_gt, + val_x0, + val_x1, + ) - with torch.no_grad(): - pseudo_auroc = Reporter.check_separability( - train_pair=(x0, x1), val_pair=(val_x0, val_x1) + if val_lm_preds is not None: + val_gt_cpu = ( + val_gt.repeat_interleave(val_lm_preds.shape[1]).float().cpu() + ) + val_lm_auroc = float(roc_auc_score(val_gt_cpu, val_lm_preds.flatten())) + val_lm_acc = float( + accuracy_score(val_gt_cpu, val_lm_preds.flatten() > 0.5) + ) + else: + val_lm_auroc = None + val_lm_acc = None + + row = pd.Series( + { + "dataset": ds_name, + "layer": layer, + "pseudo_auroc": pseudo_auroc, + "train_loss": train_loss, + **val_result._asdict(), + "lm_auroc": val_lm_auroc, + "lm_acc": val_lm_acc, + } ) - if pseudo_auroc > 0.6: - warnings.warn( - f"The pseudo-labels at layer {layer} are linearly separable with " - f"an AUROC of {pseudo_auroc:.3f}. This may indicate that the " - f"algorithm will not converge to a good solution." 
+ + if lr_model is not None: + row["lr_auroc"], row["lr_acc"] = evaluate_supervised( + lr_model, val_x0, val_x1, val_gt ) - return pseudo_auroc + row_buf.append(row) + + return pd.DataFrame(row_buf) def train(self): """Train a reporter on each layer of the network.""" devices = select_usable_devices(self.cfg.num_gpus) num_devices = len(devices) - func: Callable[[int], pd.Series] = partial( + func: Callable[[int], pd.DataFrame] = partial( self.train_reporter, devices=devices, world_size=num_devices ) self.apply_to_layers(func=func, num_devices=num_devices) diff --git a/elk/utils/__init__.py b/elk/utils/__init__.py index 1400a98d..13656933 100644 --- a/elk/utils/__init__.py +++ b/elk/utils/__init__.py @@ -2,27 +2,37 @@ binarize, convert_span, get_columns_all_equal, + get_dataset_name, + get_layers, + has_multiple_configs, infer_label_column, infer_num_classes, select_train_val_splits, ) from .gpu_utils import select_usable_devices from .hf_utils import instantiate_model, is_autoregressive +from .math_util import batch_cov, cov_mean_fused, stochastic_round_constrained from .tree_utils import pytree_map from .typing import assert_type, float32_to_int16, int16_to_float32 __all__ = [ + "assert_type", + "batch_cov", "binarize", "convert_span", + "cov_mean_fused", + "float32_to_int16", "get_columns_all_equal", + "get_dataset_name", + "get_layers", + "has_multiple_configs", "infer_label_column", "infer_num_classes", "instantiate_model", - "is_autoregressive", - "float32_to_int16", "int16_to_float32", + "is_autoregressive", + "pytree_map", "select_train_val_splits", "select_usable_devices", - "pytree_map", - "assert_type", + "stochastic_round_constrained", ] diff --git a/elk/utils/data_utils.py b/elk/utils/data_utils.py index a98a7aae..0fbd7353 100644 --- a/elk/utils/data_utils.py +++ b/elk/utils/data_utils.py @@ -1,8 +1,9 @@ import copy from bisect import bisect_left, bisect_right +from functools import cache from operator import itemgetter from random import Random -from typing import Any, Iterable, List +from typing import Any, Iterable from datasets import ( ClassLabel, @@ -10,6 +11,7 @@ Features, Split, Value, + get_dataset_config_names, ) from ..promptsource.templates import Template @@ -44,6 +46,30 @@ def get_columns_all_equal(dataset: DatasetDict) -> list[str]: return pivot +def get_dataset_name(dataset: DatasetDict) -> str: + """Get the name of a `DatasetDict`.""" + builder_name, *rest = [ds.builder_name for ds in dataset.values()] + if not all(name == builder_name for name in rest): + raise ValueError( + f"All splits must have the same name; got {[builder_name, *rest]}" + ) + + config_name, *rest = [ds.config_name for ds in dataset.values()] + if not all(name == config_name for name in rest): + raise ValueError( + f"All splits must have the same config name; got {[config_name, *rest]}" + ) + + include_config = config_name and has_multiple_configs(builder_name) + return builder_name + " " + config_name if include_config else builder_name + + +@cache +def has_multiple_configs(ds_name: str) -> bool: + """Return whether a dataset has multiple configs.""" + return len(get_dataset_config_names(ds_name)) > 1 + + def select_train_val_splits(raw_splits: Iterable[str]) -> tuple[str, str]: """Return splits to use for train and validation, given an Iterable of splits.""" @@ -101,7 +127,7 @@ def infer_num_classes(label_feature: Any) -> int: ) -def get_layers(ds: DatasetDict) -> List[int]: +def get_layers(ds: DatasetDict) -> list[int]: """Get a list of indices of hidden layers given a `DatasetDict`.""" 
layers = [ int(feat[len("hidden_") :]) diff --git a/elk/utils/hf_utils.py b/elk/utils/hf_utils.py index 4c3ab331..4e97b6ee 100644 --- a/elk/utils/hf_utils.py +++ b/elk/utils/hf_utils.py @@ -1,8 +1,6 @@ import transformers from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel -from .typing import assert_type - # Ordered by preference _AUTOREGRESSIVE_SUFFIXES = [ # Encoder-decoder models @@ -16,7 +14,9 @@ def instantiate_model(model_str: str, **kwargs) -> PreTrainedModel: """Instantiate a model string with the appropriate `Auto` class.""" model_cfg = AutoConfig.from_pretrained(model_str) - archs = assert_type(list, model_cfg.architectures) + archs = model_cfg.architectures + if not isinstance(archs, list): + return AutoModel.from_pretrained(model_str, **kwargs) for suffix in _AUTOREGRESSIVE_SUFFIXES: # Check if any of the architectures in the config end with the suffix. @@ -31,7 +31,10 @@ def instantiate_model(model_str: str, **kwargs) -> PreTrainedModel: def is_autoregressive(model_cfg: PretrainedConfig) -> bool: """Check if a model config is autoregressive.""" - archs = assert_type(list, model_cfg.architectures) + archs = model_cfg.architectures + if not isinstance(archs, list): + return False + return any( arch_str.endswith(suffix) for arch_str in archs diff --git a/elk/math_util.py b/elk/utils/math_util.py similarity index 100% rename from elk/math_util.py rename to elk/utils/math_util.py diff --git a/pyproject.toml b/pyproject.toml index 16edd58e..f688416d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,6 +12,7 @@ license = {text = "MIT License"} dependencies = [ # Added distributed.split_dataset_by_node for IterableDatasets "datasets>=2.9.0", + "einops", # Introduced numpy.typing module "numpy>=1.20.0", # This version is old, but it's needed for certain HF tokenizers to work. 
diff --git a/tests/test_eigen_reporter.py b/tests/test_eigen_reporter.py index 58dd6c13..bf00377d 100644 --- a/tests/test_eigen_reporter.py +++ b/tests/test_eigen_reporter.py @@ -1,7 +1,7 @@ import torch -from elk.math_util import batch_cov, cov_mean_fused from elk.training import EigenReporter, EigenReporterConfig +from elk.utils import batch_cov, cov_mean_fused def test_eigen_reporter(): diff --git a/tests/test_load_prompts.py b/tests/test_load_prompts.py index a5d238fd..2e03c379 100644 --- a/tests/test_load_prompts.py +++ b/tests/test_load_prompts.py @@ -1,4 +1,4 @@ -from itertools import cycle, islice +from itertools import islice from typing import Literal import pytest @@ -10,26 +10,22 @@ @pytest.mark.filterwarnings("ignore:Unable to find a decoding function") def test_load_prompts(): def test_single_split(cfg: PromptConfig, split_type: Literal["train", "val"]): - prompt_ds = load_prompts( - *cfg.datasets, - split_type=split_type, - ) - prompters = [] - - for ds in cfg.datasets: - ds_name, _, config_name = ds.partition(" ") + for cfg in cfg.explode(): + ds_string = cfg.datasets[0] + prompt_ds = load_prompts(ds_string, split_type=split_type) + + ds_name, _, config_name = ds_string.partition(" ") prompter = DatasetTemplates(ds_name, config_name or None) - prompters.append(prompter) - limit = cfg.max_examples[0 if split_type == "train" else 1] - for prompter, record in zip(cycle(prompters), islice(prompt_ds, limit)): - true_template_names = prompter.all_template_names - returned_template_names = record["template_names"] + limit = cfg.max_examples[0 if split_type == "train" else 1] + for record in islice(prompt_ds, limit): + true_template_names = prompter.all_template_names + returned_template_names = record["template_names"] - # check for using the same templates - assert set(true_template_names) == set(returned_template_names) - # check for them being in the same order - assert true_template_names == true_template_names + # check for using the same templates + assert set(true_template_names) == set(returned_template_names) + # check for them being in the same order + assert true_template_names == true_template_names # the case where the dataset has 2 classes # this dataset is small diff --git a/tests/test_math.py b/tests/test_math.py index ee81914e..34984d8f 100644 --- a/tests/test_math.py +++ b/tests/test_math.py @@ -6,7 +6,7 @@ from hypothesis import given from hypothesis import strategies as st -from elk.math_util import batch_cov, cov_mean_fused, stochastic_round_constrained +from elk.utils import batch_cov, cov_mean_fused, stochastic_round_constrained def test_cov_mean_fused(): diff --git a/tests/test_smoke_elicit.py b/tests/test_smoke_elicit.py index 5bbcfef3..e61568ba 100644 --- a/tests/test_smoke_elicit.py +++ b/tests/test_smoke_elicit.py @@ -14,10 +14,10 @@ def test_smoke_elicit_run_tiny_gpt2_ccs(tmp_path: Path): data=Extract( model=model_path, prompts=PromptConfig(datasets=[dataset_name], max_examples=[10]), - min_gpu_mem=min_mem, # run on all layers, tiny-gpt only has 2 layers ), num_gpus=2, + min_gpu_mem=min_mem, net=CcsReporterConfig(), out_dir=tmp_path, ) @@ -25,7 +25,13 @@ def test_smoke_elicit_run_tiny_gpt2_ccs(tmp_path: Path): # get the files in the tmp_path files: list[Path] = list(tmp_path.iterdir()) created_file_names = {file.name for file in files} - expected_files = ["cfg.yaml", "metadata.yaml", "lr_models", "reporters", "eval.csv"] + expected_files = [ + "cfg.yaml", + "fingerprints.yaml", + "lr_models", + "reporters", + "eval.csv", + ] for file in expected_files: 
assert file in created_file_names @@ -38,10 +44,10 @@ def test_smoke_elicit_run_tiny_gpt2_eigen(tmp_path: Path): data=Extract( model=model_path, prompts=PromptConfig(datasets=[dataset_name], max_examples=[10]), - min_gpu_mem=min_mem, # run on all layers, tiny-gpt only has 2 layers ), num_gpus=2, + min_gpu_mem=min_mem, net=EigenReporterConfig(), out_dir=tmp_path, ) @@ -49,6 +55,12 @@ def test_smoke_elicit_run_tiny_gpt2_eigen(tmp_path: Path): # get the files in the tmp_path files: list[Path] = list(tmp_path.iterdir()) created_file_names = {file.name for file in files} - expected_files = ["cfg.yaml", "metadata.yaml", "lr_models", "reporters", "eval.csv"] + expected_files = [ + "cfg.yaml", + "fingerprints.yaml", + "lr_models", + "reporters", + "eval.csv", + ] for file in expected_files: assert file in created_file_names
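
Note: the helpers added in elk/utils/data_utils.py above decide how a DatasetDict is named in logs and result tables. The following is a minimal standalone sketch of that naming rule, not part of the patch; it assumes only the `datasets` library, and the dataset names in the comments are illustrative.

    from datasets import get_dataset_config_names

    def dataset_display_name(builder_name: str, config_name: str) -> str:
        # Append the config name only when the dataset has more than one config,
        # mirroring get_dataset_name / has_multiple_configs in elk/utils/data_utils.py.
        if config_name and len(get_dataset_config_names(builder_name)) > 1:
            return f"{builder_name} {config_name}"
        return builder_name

    # Illustrative usage (actual behavior depends on the configs published on the Hub):
    # dataset_display_name("super_glue", "boolq")  -> "super_glue boolq"
    # dataset_display_name("imdb", "plain_text")   -> "imdb"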
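Similarly, the per-dataset scoring loop added to train_reporter computes an AUROC and accuracy for the language model's own predictions whenever val_lm_preds is present. The snippet below is a self-contained sketch of that calculation with synthetic tensors (shapes and values are illustrative only), not part of the patch.

    import torch
    from sklearn.metrics import accuracy_score, roc_auc_score

    # Toy data: 8 validation examples, 3 prompt variants each.
    val_gt = torch.tensor([0, 1, 0, 1, 1, 0, 1, 0])
    val_lm_preds = torch.rand(8, 3)

    # Repeat each label once per variant, then score the flattened predictions,
    # as in the val_dict loop of train_reporter above.
    val_gt_cpu = val_gt.repeat_interleave(val_lm_preds.shape[1]).float().cpu()
    lm_auroc = float(roc_auc_score(val_gt_cpu, val_lm_preds.flatten()))
    lm_acc = float(accuracy_score(val_gt_cpu, val_lm_preds.flatten() > 0.5))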