From d292c7c96080e7960a4b95151e5206a8284e29b1 Mon Sep 17 00:00:00 2001 From: Nora Belrose Date: Tue, 4 Apr 2023 11:28:03 +0000 Subject: [PATCH 01/43] LM output evaluation for autoregressive models --- elk/evaluation/evaluate.py | 20 ++------- elk/extraction/balanced_sampler.py | 2 - elk/extraction/extraction.py | 68 +++++++++++++++++++++--------- elk/run.py | 5 ++- elk/training/train.py | 16 ++++--- elk/training/train_log.py | 13 ++++-- elk/utils/__init__.py | 2 + elk/utils/data_utils.py | 24 +++++++++-- elk/utils/hf_utils.py | 32 ++++++++++++++ tests/test_convert_span.py | 42 ++++++++++++++++++ 10 files changed, 173 insertions(+), 51 deletions(-) create mode 100644 elk/utils/hf_utils.py create mode 100644 tests/test_convert_span.py diff --git a/elk/evaluation/evaluate.py b/elk/evaluation/evaluate.py index c26aabf8..ec728644 100644 --- a/elk/evaluation/evaluate.py +++ b/elk/evaluation/evaluate.py @@ -1,30 +1,18 @@ -import csv -import os from dataclasses import dataclass from functools import partial from pathlib import Path -from typing import Callable, Literal, Optional, cast +from typing import Callable, Literal, Optional import torch -import torch.multiprocessing as mp from simple_parsing.helpers import Serializable, field -from torch import Tensor -from tqdm.auto import tqdm -from datasets import DatasetDict from elk.evaluation.evaluate_log import EvalLog from elk.extraction.extraction import Extract from elk.run import Run from elk.training import Reporter -from ..files import elk_reporter_dir, memorably_named_dir -from ..training.preprocessing import normalize -from ..utils import ( - assert_type, - int16_to_float32, - select_train_val_splits, - select_usable_devices, -) +from ..files import elk_reporter_dir +from ..utils import select_usable_devices @dataclass @@ -71,7 +59,7 @@ def evaluate_reporter( """Evaluate a single reporter on a single layer.""" device = self.get_device(devices, world_size) - _, _, test_x0, test_x1, _, test_labels = self.prepare_data( + _, _, test_x0, test_x1, _, test_labels, _ = self.prepare_data( device, layer, ) diff --git a/elk/extraction/balanced_sampler.py b/elk/extraction/balanced_sampler.py index 7828eab9..e472684b 100644 --- a/elk/extraction/balanced_sampler.py +++ b/elk/extraction/balanced_sampler.py @@ -2,7 +2,6 @@ from ..utils import infer_label_column from ..utils.typing import assert_type from collections import deque -from dataclasses import dataclass from datasets import IterableDataset, Features from itertools import cycle from random import Random @@ -13,7 +12,6 @@ class BalancedSampler(TorchIterableDataset): """ Approximately balances a binary classification dataset in a streaming fashion. - Written mostly by GPT-4. Args: dataset (IterableDataset): The HuggingFace IterableDataset to balance. 
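For orientation, the sampler above is consumed as a plain stream downstream; a minimal usage sketch (illustrative only, not part of the patch, assuming any streaming dataset with a binary "label" column, e.g. IMDB):

    from itertools import islice

    from datasets import load_dataset
    from elk.extraction import BalancedSampler

    # Stream a binary-label dataset and draw an approximately class-balanced
    # sample, mirroring the `islice(BalancedSampler(prompt_ds), max_examples)`
    # loop used in extraction.py below.
    ds = load_dataset("imdb", split="train", streaming=True)
    for example in islice(BalancedSampler(ds), 100):
        assert example["label"] in (0, 1)
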
diff --git a/elk/extraction/extraction.py b/elk/extraction/extraction.py index 4cff0de4..ac3c4a4a 100644 --- a/elk/extraction/extraction.py +++ b/elk/extraction/extraction.py @@ -1,15 +1,7 @@ """Functions for extracting the hidden states of a model.""" -import logging -import os from dataclasses import InitVar, dataclass -from itertools import islice -from typing import Iterable, Literal, Optional, Union - -import torch -from simple_parsing import Serializable, field -from transformers import AutoConfig, AutoModel, AutoTokenizer, PreTrainedModel - from datasets import ( + Array2D, Array3D, ClassLabel, DatasetDict, @@ -20,11 +12,19 @@ Value, get_dataset_config_info, ) -from elk.utils.typing import float32_to_int16 +from itertools import islice +from simple_parsing import Serializable, field +from transformers import AutoConfig, AutoTokenizer, PreTrainedModel +from typing import Iterable, Literal, Optional, Union +import logging +import os +import torch from ..utils import ( assert_type, - infer_label_column, + convert_span, + float32_to_int16, + get_model_class, select_train_val_splits, select_usable_devices, ) @@ -101,10 +101,12 @@ def extract_hiddens( world_size=world_size, ) # this dataset is already sharded, but hasn't been truncated to max_examples - # AutoModel should do the right thing here in nearly all cases. We don't actually - # care what head the model has, since we are just extracting hidden states. - model = AutoModel.from_pretrained( - cfg.model, torch_dtype="auto" if device != "cpu" else torch.float32 + model_cls = get_model_class(cfg.model) + model = assert_type( + PreTrainedModel, + model_cls.from_pretrained( + cfg.model, torch_dtype="auto" if device != "cpu" else torch.float32 + ), ).to(device) # TODO: Maybe also make this configurable? # We want to make sure the answer is never truncated @@ -126,7 +128,6 @@ def extract_hiddens( # Iterating over questions layer_indices = cfg.layers or tuple(range(model.config.num_hidden_layers)) - # print(f"Using {prompt_ds} variants for each dataset") global_max_examples = cfg.prompts.max_examples[0 if split_type == "train" else 1] # break `max_examples` among the processes roughly equally @@ -135,8 +136,6 @@ def extract_hiddens( if rank == world_size - 1: max_examples += global_max_examples % world_size - print(f"Extracting {max_examples} examples from {prompt_ds} on {device}") - for example in islice(BalancedSampler(prompt_ds), max_examples): num_variants = len(example["prompts"]) hidden_dict = { @@ -149,6 +148,12 @@ def extract_hiddens( ) for layer_idx in layer_indices } + model_preds = torch.empty( + num_variants, + 2, # contrast pair + device=device, + dtype=torch.float32, + ) text_inputs = [] # Iterate over variants @@ -156,17 +161,34 @@ def extract_hiddens( variant_inputs = [] # Iterate over answers - for j in range(2): - text = record[j]["text"] + for j, choice in enumerate(record): + text = choice["text"] variant_inputs.append(text) inputs = tokenizer( text, + return_offsets_mapping=True, return_tensors="pt", truncation=True, ).to(device) + + # The offset_mapping is a sorted list of (start, end) tuples. We locate + # the start of the answer in the tokenized sequence with binary search. + offsets = inputs.pop("offset_mapping").squeeze().tolist() + outputs = model(**inputs, output_hidden_states=True) + # TODO: Do something smarter than "rindex" here. Really we'd like to + # get the span of the answer directly from Jinja, but that doesn't seem + # to be supported. The current approach may fail for complex templates. 
+ answer_start = text.rindex(choice["answer"]) + start, end = convert_span( + offsets, (answer_start, answer_start + len(choice["answer"])) + ) + log_p = outputs.logits[..., start - 1 : end - 1, :].log_softmax(dim=-1) + tokens = inputs.input_ids[..., start:end, None] + model_preds[i, j] = log_p.gather(-1, tokens).sum() + hiddens = ( outputs.get("decoder_hidden_states") or outputs["hidden_states"] ) @@ -193,6 +215,8 @@ def extract_hiddens( yield dict( label=example["label"], + # We only need the probability of the positive example since this is binary + model_preds=model_preds.softmax(dim=-1)[..., 1], variant_ids=example["template_names"], text_inputs=text_inputs, **hidden_dict, @@ -245,6 +269,10 @@ def get_splits() -> SplitDict: length=num_variants, ), "label": ClassLabel(names=["neg", "pos"]), + "model_preds": Sequence( + Value(dtype="float32"), + length=num_variants, + ), "text_inputs": Sequence( Sequence( Value(dtype="string"), diff --git a/elk/run.py b/elk/run.py index 7b25ccd7..173750fb 100644 --- a/elk/run.py +++ b/elk/run.py @@ -88,7 +88,10 @@ def prepare_data( x0, x1 = train_h.unbind(dim=-2) val_x0, val_x1 = val_h.unbind(dim=-2) - return x0, x1, val_x0, val_x1, train_labels, val_labels + with self.dataset.formatted_as("numpy"): + val_lm_preds = assert_type(np.ndarray, val["model_preds"]) + + return x0, x1, val_x0, val_x1, train_labels, val_labels, val_lm_preds def concatenate(self, layers): """Concatenate hidden states from a previous layer.""" diff --git a/elk/training/train.py b/elk/training/train.py index 81a9bde6..ad033757 100644 --- a/elk/training/train.py +++ b/elk/training/train.py @@ -124,7 +124,7 @@ def train_reporter( device = self.get_device(devices, world_size) - x0, x1, val_x0, val_x1, train_labels, val_labels = self.prepare_data( + x0, x1, val_x0, val_x1, train_gt, val_gt, val_lm_preds = self.prepare_data( device, layer ) pseudo_auroc = self.get_pseudo_auroc(layer, x0, x1, val_x0, val_x1) @@ -136,19 +136,23 @@ def train_reporter( else: raise ValueError(f"Unknown reporter config type: {type(self.cfg.net)}") - train_loss = reporter.fit(x0, x1, train_labels) + train_loss = reporter.fit(x0, x1, train_gt) val_result = reporter.score( - val_labels, + val_gt, val_x0, val_x1, ) reporter_dir, lr_dir = self.create_models_dir(assert_type(Path, self.out_dir)) - stats: ElicitLog = ElicitLog( + val_gt_cpu = val_gt.repeat_interleave(val_lm_preds.shape[1]).float().cpu() + + stats = ElicitLog( layer=layer, pseudo_auroc=pseudo_auroc, train_loss=train_loss, eval_result=val_result, + lm_auroc=float(roc_auc_score(val_gt_cpu, val_lm_preds.flatten())), + lm_acc=float(accuracy_score(val_gt_cpu, val_lm_preds.flatten() > 0.5)), ) if not self.cfg.skip_baseline: @@ -157,8 +161,8 @@ def train_reporter( x1, val_x0, val_x1, - train_labels, - val_labels, + train_gt, + val_gt, device, ) stats.lr_auroc = lr_auroc diff --git a/elk/training/train_log.py b/elk/training/train_log.py index 37e5a392..33ecd72b 100644 --- a/elk/training/train_log.py +++ b/elk/training/train_log.py @@ -1,17 +1,20 @@ +from .reporter import EvalResult from dataclasses import dataclass from typing import Optional -from elk.training.reporter import EvalResult - @dataclass class ElicitLog: """The result of running elicit on a layer of a dataset""" layer: int + pseudo_auroc: float train_loss: float eval_result: EvalResult - pseudo_auroc: float + + lm_auroc: float + lm_acc: float + # Only available if reporting baseline lr_auroc: Optional[float] = None # Only available if reporting baseline @@ -28,6 +31,8 @@ def 
csv_columns(skip_baseline: bool) -> list[str]: "cal_acc", "auroc", "ece", + "lm_auroc", + "lm_acc", ] if not skip_baseline: cols += ["lr_auroc", "lr_acc"] @@ -43,6 +48,8 @@ def to_csv_line(self, skip_baseline: bool) -> list[str]: self.eval_result.cal_acc, self.eval_result.auroc, self.eval_result.ece, + self.lm_auroc, + self.lm_acc, ] if not skip_baseline: items += [self.lr_auroc, self.lr_acc] diff --git a/elk/utils/__init__.py b/elk/utils/__init__.py index b2afab9e..5dfda1aa 100644 --- a/elk/utils/__init__.py +++ b/elk/utils/__init__.py @@ -1,5 +1,6 @@ from .data_utils import ( binarize, + convert_span, get_columns_all_equal, infer_label_column, infer_num_classes, @@ -7,5 +8,6 @@ ) from .gpu_utils import select_usable_devices +from .hf_utils import get_model_class from .tree_utils import pytree_map from .typing import assert_type, float32_to_int16, int16_to_float32 diff --git a/elk/utils/data_utils.py b/elk/utils/data_utils.py index 198aaf82..848f0d4e 100644 --- a/elk/utils/data_utils.py +++ b/elk/utils/data_utils.py @@ -1,5 +1,6 @@ from .typing import assert_type from ..promptsource.templates import Template +from bisect import bisect_left, bisect_right from datasets import ( ClassLabel, DatasetDict, @@ -7,13 +8,30 @@ Split, Value, ) +from operator import itemgetter from random import Random -import torch -from typing import Iterable, Optional, List, Any -import numpy as np +from typing import Iterable, List, Any import copy +def convert_span( + offsets: list[tuple[int, int]], span: tuple[int, int] +) -> tuple[int, int]: + """Convert `span` from string coordinates to token coordinates. + + Args: + offsets: The offset mapping of the target tokenization. + span: The span to convert. + + Returns: + (start, end): The converted span. + """ + start, end = span + start = bisect_right(offsets, start, key=itemgetter(1)) + end = bisect_left(offsets, end, lo=start, key=itemgetter(0)) + return start, end + + def get_columns_all_equal(dataset: DatasetDict) -> list[str]: """Get columns of a `DatasetDict`, asserting all splits have the same columns.""" pivot, *rest = dataset.column_names.values() diff --git a/elk/utils/hf_utils.py b/elk/utils/hf_utils.py new file mode 100644 index 00000000..fc3ef60b --- /dev/null +++ b/elk/utils/hf_utils.py @@ -0,0 +1,32 @@ +from .typing import assert_type +from transformers import AutoConfig, PreTrainedModel +from typing import Type +import transformers + + +def get_model_class(model_str: str) -> Type[PreTrainedModel]: + """Get the appropriate model class for a model string.""" + model_cfg = AutoConfig.from_pretrained(model_str) + archs = assert_type(list, model_cfg.architectures) + + # Ordered by preference + suffixes = [ + # Fine-tuned for classification + "SequenceClassification", + # Encoder-decoder models + "ConditionalGeneration", + # Autoregressive models + "CausalLM", + "LMHeadModel", + ] + + for suffix in suffixes: + # Check if any of the architectures in the config end with the suffix. + # If so, return the corresponding model class. 
+ for arch_str in archs: + if arch_str.endswith(suffix): + return getattr(transformers, arch_str) + + raise ValueError( + f"'{model_str}' does not have any supported architectures: {archs}" + ) diff --git a/tests/test_convert_span.py b/tests/test_convert_span.py new file mode 100644 index 00000000..10e8c583 --- /dev/null +++ b/tests/test_convert_span.py @@ -0,0 +1,42 @@ +from elk.utils import convert_span +from hypothesis import given, strategies as st +from transformers import AutoTokenizer +import pytest + + +# Define a fixture with session scope that initializes the tokenizer +@pytest.fixture(scope="session") +def tokenizer(): + yield AutoTokenizer.from_pretrained("gpt2") + + +# Hypothesis will generate really bizarre Unicode strings for us +@st.composite +def string_and_span(draw) -> tuple[str, tuple[int, int]]: + """Generate a non-empty string and a non-empty span within that string.""" + text = draw(st.text(min_size=1)) + start = draw(st.integers(min_value=0, max_value=len(text) - 1)) + end = draw(st.integers(min_value=start + 1, max_value=len(text))) + return text, (start, end) + + +@given(string_and_span()) +def test_convert_span(tokenizer, text_and_span: tuple[str, tuple[int, int]]): + text, span = text_and_span + + tokenizer_output = tokenizer(text, return_offsets_mapping=True) + tokens = tokenizer_output["input_ids"] + + # Convert the span in string coordinates to a span in token coordinates + token_start, token_end = convert_span(tokenizer_output["offset_mapping"], span) + assert token_start < token_end + + string_start, string_end = span + substring = text[string_start:string_end] + token_subsequence = tokens[token_start:token_end] + + # Decode the subsequence of tokens back to a string + decoded_string = tokenizer.decode(token_subsequence) + + # Assert that the original substring is fully contained within the decoded string + assert substring in decoded_string From 7ed5ccd36d1ae31fc317eb1eb5b28d482ef308d0 Mon Sep 17 00:00:00 2001 From: Walter Laurito Date: Tue, 4 Apr 2023 23:11:28 +0000 Subject: [PATCH 02/43] move to own baseline file --- elk/training/baseline.py | 56 ++++++++++++++++++++++++++++++++++++ elk/training/train.py | 62 ++++++---------------------------------- 2 files changed, 65 insertions(+), 53 deletions(-) create mode 100644 elk/training/baseline.py diff --git a/elk/training/baseline.py b/elk/training/baseline.py new file mode 100644 index 00000000..23c0b642 --- /dev/null +++ b/elk/training/baseline.py @@ -0,0 +1,56 @@ +import pickle +from pathlib import Path +from typing import Tuple + +import torch +from sklearn.metrics import accuracy_score, roc_auc_score +from torch import Tensor + +from ..utils.typing import assert_type +from .classifier import Classifier + +# TODO: Create class for baseline? 
+ +def evaluate_baseline(lr_model: Classifier, val_x0: Tensor, val_x1: Tensor, val_labels: Tensor) -> Tuple[float, float]: + X = torch.cat([val_x0, val_x1]) + d = X.shape[-1] + X_val = X.view(-1, d) + with torch.no_grad(): + lr_preds = lr_model(X_val).sigmoid().cpu() + + val_labels_aug = ( + torch.cat([val_labels, 1 - val_labels]).repeat_interleave(val_x0.shape[1]) + ).cpu() + + lr_acc = accuracy_score(val_labels_aug, lr_preds > 0.5) + lr_auroc = roc_auc_score(val_labels_aug, lr_preds) + + return assert_type(float, lr_auroc), assert_type(float, lr_acc) + +def train_baseline( + x0: Tensor, + x1: Tensor, + train_labels: Tensor, + device: str, +) -> Classifier: + # repeat_interleave makes `num_variants` copies of each label, all within a + # single dimension of size `num_variants * 2 * n`, such that the labels align + # with X.view(-1, X.shape[-1]) + train_labels_aug = torch.cat( + [train_labels, 1 - train_labels] + ).repeat_interleave(x0.shape[1]) + + X = torch.cat([x0, x1]).squeeze() + d = X.shape[-1] + lr_model = Classifier(d, device=device) + lr_model.fit_cv(X.view(-1, d), train_labels_aug) + + return lr_model + +def save_baseline(lr_dir: Path, layer: int, lr_model: Classifier): + with open(lr_dir / f"layer_{layer}.pt", "wb") as file: + pickle.dump(lr_model, file) + +def load_baseline(lr_dir: Path, layer: int) -> Classifier: + with open(lr_dir / f"layer_{layer}.pt", "rb") as file: + return pickle.load(file) \ No newline at end of file diff --git a/elk/training/train.py b/elk/training/train.py index 81a9bde6..6d58eb8e 100644 --- a/elk/training/train.py +++ b/elk/training/train.py @@ -1,27 +1,26 @@ """Main training loop.""" -import pickle import warnings from dataclasses import dataclass from functools import partial from pathlib import Path -from typing import Literal, Optional, Callable +from typing import Callable, Literal, Optional import torch from simple_parsing import Serializable, field, subgroups -from sklearn.metrics import accuracy_score, roc_auc_score from torch import Tensor from elk.extraction.extraction import Extract from elk.run import Run +from elk.training.baseline import (evaluate_baseline, save_baseline, + train_baseline) from elk.utils.typing import assert_type +from ..utils import select_usable_devices from .ccs_reporter import CcsReporter, CcsReporterConfig -from .classifier import Classifier from .eigen_reporter import EigenReporter, EigenReporterConfig from .reporter import OptimConfig, Reporter, ReporterConfig from .train_log import ElicitLog -from ..utils import select_usable_devices @dataclass @@ -65,40 +64,6 @@ def execute(self): class Train(Run): cfg: Elicit - def train_baseline( - self, - x0: Tensor, - x1: Tensor, - val_x0: Tensor, - val_x1: Tensor, - train_labels: Tensor, - val_labels: Tensor, - device: str, - ): - # repeat_interleave makes `num_variants` copies of each label, all within a - # single dimension of size `num_variants * 2 * n`, such that the labels align - # with X.view(-1, X.shape[-1]) - train_labels_aug = torch.cat( - [train_labels, 1 - train_labels] - ).repeat_interleave(x0.shape[1]) - val_labels_aug = ( - torch.cat([val_labels, 1 - val_labels]).repeat_interleave(x0.shape[1]) - ).cpu() - - X = torch.cat([x0, x1]).squeeze() - d = X.shape[-1] - lr_model = Classifier(d, device=device) - lr_model.fit_cv(X.view(-1, d), train_labels_aug) - - X_val = torch.cat([val_x0, val_x1]).view(-1, d) - with torch.no_grad(): - lr_preds = lr_model(X_val).sigmoid().cpu() - - lr_acc = accuracy_score(val_labels_aug, lr_preds > 0.5) - lr_auroc = 
roc_auc_score(val_labels_aug, lr_preds) - - return lr_model, lr_auroc, lr_acc - def create_models_dir(self, out_dir: Path): lr_dir = None lr_dir = out_dir / "lr_models" @@ -109,10 +74,6 @@ def create_models_dir(self, out_dir: Path): return reporter_dir, lr_dir - def save_baseline(self, lr_dir: Path, layer: int, lr_model: Classifier): - with open(lr_dir / f"layer_{layer}.pt", "wb") as file: - pickle.dump(lr_model, file) - def train_reporter( self, layer: int, @@ -152,18 +113,13 @@ def train_reporter( ) if not self.cfg.skip_baseline: - lr_model, lr_auroc, lr_acc = self.train_baseline( - x0, - x1, - val_x0, - val_x1, - train_labels, - val_labels, - device, - ) + lr_model = train_baseline(x0, x1, train_labels, device=device) + + lr_auroc, lr_acc = evaluate_baseline(lr_model, val_x0, val_x1, val_labels) + stats.lr_auroc = lr_auroc stats.lr_acc = lr_acc - self.save_baseline(lr_dir, layer, lr_model) + save_baseline(lr_dir, layer, lr_model) with open(reporter_dir / f"layer_{layer}.pt", "wb") as file: torch.save(reporter, file) From ba1d3b2942cd9b05143f29353e6233d0b6cec754 Mon Sep 17 00:00:00 2001 From: Walter Laurito Date: Tue, 4 Apr 2023 23:13:58 +0000 Subject: [PATCH 03/43] cleanup --- elk/run.py | 12 +++--------- elk/training/baseline.py | 16 +++++++++++----- elk/training/train.py | 10 ++++------ 3 files changed, 18 insertions(+), 20 deletions(-) diff --git a/elk/run.py b/elk/run.py index 7b25ccd7..ea6606ec 100644 --- a/elk/run.py +++ b/elk/run.py @@ -3,25 +3,19 @@ from abc import ABC from dataclasses import dataclass, field from pathlib import Path -from typing import ( - TYPE_CHECKING, - Optional, - Union, - Callable, - Iterator, -) +from typing import TYPE_CHECKING, Callable, Iterator, Optional, Union import numpy as np import torch import torch.multiprocessing as mp -from datasets import DatasetDict from torch import Tensor from tqdm import tqdm +from datasets import DatasetDict from elk.extraction.extraction import extract from elk.files import create_output_directory, save_config, save_meta from elk.training.preprocessing import normalize -from elk.utils.csv import write_iterator_to_file, Log +from elk.utils.csv import Log, write_iterator_to_file from elk.utils.data_utils import get_layers, select_train_val_splits from elk.utils.typing import assert_type, int16_to_float32 diff --git a/elk/training/baseline.py b/elk/training/baseline.py index 23c0b642..2c9542a6 100644 --- a/elk/training/baseline.py +++ b/elk/training/baseline.py @@ -11,7 +11,10 @@ # TODO: Create class for baseline? 
-def evaluate_baseline(lr_model: Classifier, val_x0: Tensor, val_x1: Tensor, val_labels: Tensor) -> Tuple[float, float]: + +def evaluate_baseline( + lr_model: Classifier, val_x0: Tensor, val_x1: Tensor, val_labels: Tensor +) -> Tuple[float, float]: X = torch.cat([val_x0, val_x1]) d = X.shape[-1] X_val = X.view(-1, d) @@ -27,6 +30,7 @@ def evaluate_baseline(lr_model: Classifier, val_x0: Tensor, val_x1: Tensor, val_ return assert_type(float, lr_auroc), assert_type(float, lr_acc) + def train_baseline( x0: Tensor, x1: Tensor, @@ -36,9 +40,9 @@ def train_baseline( # repeat_interleave makes `num_variants` copies of each label, all within a # single dimension of size `num_variants * 2 * n`, such that the labels align # with X.view(-1, X.shape[-1]) - train_labels_aug = torch.cat( - [train_labels, 1 - train_labels] - ).repeat_interleave(x0.shape[1]) + train_labels_aug = torch.cat([train_labels, 1 - train_labels]).repeat_interleave( + x0.shape[1] + ) X = torch.cat([x0, x1]).squeeze() d = X.shape[-1] @@ -47,10 +51,12 @@ def train_baseline( return lr_model + def save_baseline(lr_dir: Path, layer: int, lr_model: Classifier): with open(lr_dir / f"layer_{layer}.pt", "wb") as file: pickle.dump(lr_model, file) + def load_baseline(lr_dir: Path, layer: int) -> Classifier: with open(lr_dir / f"layer_{layer}.pt", "rb") as file: - return pickle.load(file) \ No newline at end of file + return pickle.load(file) diff --git a/elk/training/train.py b/elk/training/train.py index 6d58eb8e..a6ecafdc 100644 --- a/elk/training/train.py +++ b/elk/training/train.py @@ -10,13 +10,11 @@ from simple_parsing import Serializable, field, subgroups from torch import Tensor -from elk.extraction.extraction import Extract -from elk.run import Run -from elk.training.baseline import (evaluate_baseline, save_baseline, - train_baseline) -from elk.utils.typing import assert_type - +from ..extraction.extraction import Extract +from ..run import Run +from ..training.baseline import evaluate_baseline, save_baseline, train_baseline from ..utils import select_usable_devices +from ..utils.typing import assert_type from .ccs_reporter import CcsReporter, CcsReporterConfig from .eigen_reporter import EigenReporter, EigenReporterConfig from .reporter import OptimConfig, Reporter, ReporterConfig From a20d4ca4504f78dec4f6c5f7c52c729eb075c8aa Mon Sep 17 00:00:00 2001 From: Nora Belrose Date: Wed, 5 Apr 2023 00:26:54 +0000 Subject: [PATCH 04/43] Support encoder-decoder model LM output --- elk/evaluation/evaluate.py | 2 +- elk/extraction/extraction.py | 92 ++++++++++++++++++++---------------- elk/training/train.py | 16 +++++-- elk/training/train_log.py | 9 ++-- elk/utils/__init__.py | 2 +- elk/utils/hf_utils.py | 45 ++++++++++-------- 6 files changed, 96 insertions(+), 70 deletions(-) diff --git a/elk/evaluation/evaluate.py b/elk/evaluation/evaluate.py index ec728644..27467ec2 100644 --- a/elk/evaluation/evaluate.py +++ b/elk/evaluation/evaluate.py @@ -7,7 +7,7 @@ from simple_parsing.helpers import Serializable, field from elk.evaluation.evaluate_log import EvalLog -from elk.extraction.extraction import Extract +from elk.extraction import Extract from elk.run import Run from elk.training import Reporter diff --git a/elk/extraction/extraction.py b/elk/extraction/extraction.py index ac3c4a4a..0546b22b 100644 --- a/elk/extraction/extraction.py +++ b/elk/extraction/extraction.py @@ -1,7 +1,6 @@ """Functions for extracting the hidden states of a model.""" from dataclasses import InitVar, dataclass from datasets import ( - Array2D, Array3D, ClassLabel, 
DatasetDict, @@ -14,17 +13,22 @@ ) from itertools import islice from simple_parsing import Serializable, field -from transformers import AutoConfig, AutoTokenizer, PreTrainedModel +from torch import Tensor +from transformers import AutoConfig, AutoTokenizer +from transformers.modeling_outputs import Seq2SeqLMOutput from typing import Iterable, Literal, Optional, Union import logging import os import torch +# import torch.nn.functional as F + from ..utils import ( assert_type, convert_span, float32_to_int16, - get_model_class, + instantiate_model, + is_autoregressive, select_train_val_splits, select_usable_devices, ) @@ -101,30 +105,12 @@ def extract_hiddens( world_size=world_size, ) # this dataset is already sharded, but hasn't been truncated to max_examples - model_cls = get_model_class(cfg.model) - model = assert_type( - PreTrainedModel, - model_cls.from_pretrained( - cfg.model, torch_dtype="auto" if device != "cpu" else torch.float32 - ), + model = instantiate_model( + cfg.model, torch_dtype="auto" if device != "cpu" else torch.float32 ).to(device) - # TODO: Maybe also make this configurable? - # We want to make sure the answer is never truncated tokenizer = AutoTokenizer.from_pretrained( cfg.model, truncation_side="left", verbose=False ) - is_enc_dec = model.config.is_encoder_decoder - - # If this is an encoder-decoder model we don't need to run the decoder at all. - # Just strip it off, making the problem equivalent to a regular encoder-only model. - if is_enc_dec: - # This isn't actually *guaranteed* by HF, but it's true for all existing models - if not hasattr(model, "get_encoder") or not callable(model.get_encoder): - raise ValueError( - "Encoder-decoder model doesn't have expected get_encoder() method" - ) - - model = assert_type(PreTrainedModel, model.get_encoder()) # Iterating over questions layer_indices = cfg.layers or tuple(range(model.config.num_hidden_layers)) @@ -148,7 +134,7 @@ def extract_hiddens( ) for layer_idx in layer_indices } - model_preds = torch.empty( + lm_preds = torch.empty( num_variants, 2, # contrast pair device=device, @@ -165,29 +151,51 @@ def extract_hiddens( text = choice["text"] variant_inputs.append(text) + # TODO: Do something smarter than "rindex" here. Really we want to + # get the span of the answer directly from Jinja, but that doesn't + # seem possible. This approach may fail for complex templates. + answer_start = text.rindex(choice["answer"]) + + # Only feed question, not the answer, to the encoder for enc-dec models + if model.config.is_encoder_decoder: + # TODO: Maybe make this more generic for complex templates? + text = text[:answer_start].rstrip() + target = choice["answer"] + else: + target = None + inputs = tokenizer( text, return_offsets_mapping=True, return_tensors="pt", + text_target=target, # type: ignore[arg-type] truncation=True, - ).to(device) + ) # The offset_mapping is a sorted list of (start, end) tuples. We locate # the start of the answer in the tokenized sequence with binary search. offsets = inputs.pop("offset_mapping").squeeze().tolist() + inputs = inputs.to(device) + # Run the forward pass outputs = model(**inputs, output_hidden_states=True) - # TODO: Do something smarter than "rindex" here. Really we'd like to - # get the span of the answer directly from Jinja, but that doesn't seem - # to be supported. The current approach may fail for complex templates. 
- answer_start = text.rindex(choice["answer"]) - start, end = convert_span( - offsets, (answer_start, answer_start + len(choice["answer"])) - ) - log_p = outputs.logits[..., start - 1 : end - 1, :].log_softmax(dim=-1) - tokens = inputs.input_ids[..., start:end, None] - model_preds[i, j] = log_p.gather(-1, tokens).sum() + # Compute the log probability of the answer tokens if available + if type(outputs).__name__.startswith("CausalLMOutput"): + start, end = convert_span( + offsets, (answer_start, answer_start + len(choice["answer"])) + ) + log_p = outputs.logits[..., start - 1 : end - 1, :].log_softmax( + dim=-1 + ) + tokens = inputs.input_ids[..., start:end, None] + lm_preds[i, j] = log_p.gather(-1, tokens).sum() + + elif isinstance(outputs, Seq2SeqLMOutput): + # The cross entropy loss is averaged over tokens, so we need to + # multiply by the length to get the total log probability. + length = inputs.labels.shape[-1] + lm_preds[i, j] = -assert_type(Tensor, outputs.loss) * length hiddens = ( outputs.get("decoder_hidden_states") or outputs["hidden_states"] @@ -216,7 +224,7 @@ def extract_hiddens( yield dict( label=example["label"], # We only need the probability of the positive example since this is binary - model_preds=model_preds.softmax(dim=-1)[..., 1], + model_preds=lm_preds.softmax(dim=-1)[..., 1], variant_ids=example["template_names"], text_inputs=text_inputs, **hidden_dict, @@ -269,10 +277,6 @@ def get_splits() -> SplitDict: length=num_variants, ), "label": ClassLabel(names=["neg", "pos"]), - "model_preds": Sequence( - Value(dtype="float32"), - length=num_variants, - ), "text_inputs": Sequence( Sequence( Value(dtype="string"), @@ -281,6 +285,14 @@ def get_splits() -> SplitDict: length=num_variants, ), } + + # Only add model_preds if the model is an autoregressive model + if is_autoregressive(model_cfg): + other_cols["model_preds"] = Sequence( + Value(dtype="float32"), + length=num_variants, + ) + devices = select_usable_devices(num_gpus, min_memory=cfg.min_gpu_mem) builders = { split_name: _GeneratorBuilder( diff --git a/elk/training/train.py b/elk/training/train.py index ad033757..dc350611 100644 --- a/elk/training/train.py +++ b/elk/training/train.py @@ -144,15 +144,21 @@ def train_reporter( ) reporter_dir, lr_dir = self.create_models_dir(assert_type(Path, self.out_dir)) - val_gt_cpu = val_gt.repeat_interleave(val_lm_preds.shape[1]).float().cpu() + if val_lm_preds is not None: + val_gt_cpu = val_gt.repeat_interleave(val_lm_preds.shape[1]).float().cpu() + val_lm_auroc = float(roc_auc_score(val_gt_cpu, val_lm_preds.flatten())) + val_lm_acc = float(accuracy_score(val_gt_cpu, val_lm_preds.flatten() > 0.5)) + else: + val_lm_auroc = None + val_lm_acc = None stats = ElicitLog( layer=layer, pseudo_auroc=pseudo_auroc, train_loss=train_loss, eval_result=val_result, - lm_auroc=float(roc_auc_score(val_gt_cpu, val_lm_preds.flatten())), - lm_acc=float(accuracy_score(val_gt_cpu, val_lm_preds.flatten() > 0.5)), + lm_auroc=val_lm_auroc, + lm_acc=val_lm_acc, ) if not self.cfg.skip_baseline: @@ -165,8 +171,8 @@ def train_reporter( val_gt, device, ) - stats.lr_auroc = lr_auroc - stats.lr_acc = lr_acc + stats.lr_auroc = float(lr_auroc) + stats.lr_acc = float(lr_acc) self.save_baseline(lr_dir, layer, lr_model) with open(reporter_dir / f"layer_{layer}.pt", "wb") as file: diff --git a/elk/training/train_log.py b/elk/training/train_log.py index 33ecd72b..34c0d250 100644 --- a/elk/training/train_log.py +++ b/elk/training/train_log.py @@ -12,8 +12,9 @@ class ElicitLog: train_loss: float eval_result: 
EvalResult - lm_auroc: float - lm_acc: float + # Only available when the LM is autoregressive + lm_auroc: Optional[float] = None + lm_acc: Optional[float] = None # Only available if reporting baseline lr_auroc: Optional[float] = None @@ -48,9 +49,9 @@ def to_csv_line(self, skip_baseline: bool) -> list[str]: self.eval_result.cal_acc, self.eval_result.auroc, self.eval_result.ece, - self.lm_auroc, - self.lm_acc, ] + if self.lm_auroc is not None and self.lm_acc is not None: + items += [self.lm_auroc, self.lm_acc] if not skip_baseline: items += [self.lr_auroc, self.lr_acc] diff --git a/elk/utils/__init__.py b/elk/utils/__init__.py index 5dfda1aa..6d371243 100644 --- a/elk/utils/__init__.py +++ b/elk/utils/__init__.py @@ -8,6 +8,6 @@ ) from .gpu_utils import select_usable_devices -from .hf_utils import get_model_class +from .hf_utils import instantiate_model, is_autoregressive from .tree_utils import pytree_map from .typing import assert_type, float32_to_int16, int16_to_float32 diff --git a/elk/utils/hf_utils.py b/elk/utils/hf_utils.py index fc3ef60b..5ff82902 100644 --- a/elk/utils/hf_utils.py +++ b/elk/utils/hf_utils.py @@ -1,32 +1,39 @@ from .typing import assert_type -from transformers import AutoConfig, PreTrainedModel -from typing import Type +from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel import transformers -def get_model_class(model_str: str) -> Type[PreTrainedModel]: - """Get the appropriate model class for a model string.""" +# Ordered by preference +_AUTOREGRESSIVE_SUFFIXES = [ + # Encoder-decoder models + "ConditionalGeneration", + # Autoregressive models + "CausalLM", + "LMHeadModel", +] + + +def instantiate_model(model_str: str, **kwargs) -> PreTrainedModel: + """Instantiate a model string with the appropriate `Auto` class.""" model_cfg = AutoConfig.from_pretrained(model_str) archs = assert_type(list, model_cfg.architectures) - # Ordered by preference - suffixes = [ - # Fine-tuned for classification - "SequenceClassification", - # Encoder-decoder models - "ConditionalGeneration", - # Autoregressive models - "CausalLM", - "LMHeadModel", - ] - - for suffix in suffixes: + for suffix in _AUTOREGRESSIVE_SUFFIXES: # Check if any of the architectures in the config end with the suffix. # If so, return the corresponding model class. 
for arch_str in archs: if arch_str.endswith(suffix): - return getattr(transformers, arch_str) + model_cls = getattr(transformers, arch_str) + return model_cls.from_pretrained(model_str, **kwargs) + + return AutoModel.from_pretrained(model_str, **kwargs) - raise ValueError( - f"'{model_str}' does not have any supported architectures: {archs}" + +def is_autoregressive(model_cfg: PretrainedConfig) -> bool: + """Check if a model config is autoregressive.""" + archs = assert_type(list, model_cfg.architectures) + return any( + arch_str.endswith(suffix) + for arch_str in archs + for suffix in _AUTOREGRESSIVE_SUFFIXES ) From 77d74185df510e329ccafff789a733d37934dd0c Mon Sep 17 00:00:00 2001 From: Nora Belrose Date: Wed, 5 Apr 2023 08:39:24 +0000 Subject: [PATCH 05/43] isort --- elk/__init__.py | 2 +- elk/calibration.py | 5 +- elk/eigsh.py | 3 +- elk/evaluation/evaluate.py | 30 +++-- elk/evaluation/evaluate_log.py | 27 ---- elk/extraction/__init__.py | 4 +- elk/extraction/balanced_sampler.py | 12 +- elk/extraction/extraction.py | 12 +- elk/extraction/generator.py | 2 +- elk/extraction/prompt_loading.py | 22 ++-- elk/files.py | 4 +- elk/logging.py | 1 + elk/math_util.py | 3 +- elk/parsing.py | 1 + elk/promptsource/templates.py | 12 +- elk/run.py | 52 ++++---- elk/training/__init__.py | 1 - elk/training/ccs_reporter.py | 18 +-- elk/training/classifier.py | 9 +- elk/training/eigen_reporter.py | 10 +- elk/training/losses.py | 5 +- elk/training/preprocessing.py | 1 + elk/training/reporter.py | 12 +- elk/training/train.py | 41 +++--- elk/training/train_log.py | 60 --------- elk/utils/__init__.py | 1 - elk/utils/csv.py | 37 ------ elk/utils/data_utils.py | 14 +- elk/utils/gpu_utils.py | 10 +- elk/utils/hf_utils.py | 4 +- elk/utils/tree_utils.py | 1 - elk/utils/typing.py | 3 +- pyproject.toml | 4 +- tests/test_classifier.py | 5 +- tests/test_convert_span.py | 8 +- tests/test_eigen_reporter.py | 3 +- tests/test_eigsh.py | 5 +- tests/test_load_prompts.py | 6 +- tests/test_math.py | 9 +- tests/test_samplers.py | 8 +- tests/test_write_iterator_to_file.py | 186 --------------------------- 41 files changed, 183 insertions(+), 470 deletions(-) delete mode 100644 elk/evaluation/evaluate_log.py delete mode 100644 elk/training/train_log.py delete mode 100644 elk/utils/csv.py delete mode 100644 tests/test_write_iterator_to_file.py diff --git a/elk/__init__.py b/elk/__init__.py index bb9f9b17..9d95a485 100644 --- a/elk/__init__.py +++ b/elk/__init__.py @@ -1,3 +1,3 @@ -from .extraction import extract_hiddens, Extract +from .extraction import Extract, extract_hiddens __all__ = ["extract_hiddens", "Extract"] diff --git a/elk/calibration.py b/elk/calibration.py index 3d494872..db56fa02 100644 --- a/elk/calibration.py +++ b/elk/calibration.py @@ -1,8 +1,9 @@ +import warnings from dataclasses import dataclass, field -from torch import Tensor from typing import NamedTuple + import torch -import warnings +from torch import Tensor class CalibrationEstimate(NamedTuple): diff --git a/elk/eigsh.py b/elk/eigsh.py index 10c1de60..c7277e32 100644 --- a/elk/eigsh.py +++ b/elk/eigsh.py @@ -1,7 +1,8 @@ -from torch import Tensor from typing import Literal, Optional + import torch import torch.nn.functional as F +from torch import Tensor def lanczos_eigsh( diff --git a/elk/evaluation/evaluate.py b/elk/evaluation/evaluate.py index 92f18d22..59375e75 100644 --- a/elk/evaluation/evaluate.py +++ b/elk/evaluation/evaluate.py @@ -1,15 +1,17 @@ -from ..extraction import Extract -from ..files import elk_reporter_dir -from ..run import Run -from 
..training import Reporter -from ..utils import select_usable_devices -from .evaluate_log import EvalLog from dataclasses import dataclass from functools import partial from pathlib import Path -from simple_parsing.helpers import Serializable, field from typing import Callable, Literal, Optional + +import pandas as pd import torch +from simple_parsing.helpers import Serializable, field + +from ..extraction import Extract +from ..files import elk_reporter_dir +from ..run import Run +from ..training import Reporter +from ..utils import select_usable_devices @dataclass @@ -52,7 +54,7 @@ class Evaluate(Run): def evaluate_reporter( self, layer: int, devices: list[str], world_size: int = 1 - ) -> EvalLog: + ) -> pd.Series: """Evaluate a single reporter on a single layer.""" device = self.get_device(devices, world_size) @@ -73,9 +75,11 @@ def evaluate_reporter( test_x1, ) - return EvalLog( - layer=layer, - eval_result=test_result, + return pd.Series( + { + "layer": layer, + **test_result._asdict(), + } ) def evaluate(self): @@ -85,12 +89,10 @@ def evaluate(self): ) num_devices = len(devices) - func: Callable[[int], EvalLog] = partial( + func: Callable[[int], pd.Series] = partial( self.evaluate_reporter, devices=devices, world_size=num_devices ) self.apply_to_layers( func=func, num_devices=num_devices, - to_csv_line=lambda item: item.to_csv_line(), - csv_columns=EvalLog.csv_columns(), ) diff --git a/elk/evaluation/evaluate_log.py b/elk/evaluation/evaluate_log.py deleted file mode 100644 index 2cb9b485..00000000 --- a/elk/evaluation/evaluate_log.py +++ /dev/null @@ -1,27 +0,0 @@ -from dataclasses import dataclass - -from elk.training.reporter import EvalResult - - -@dataclass -class EvalLog: - """The result of running eval on a layer of a dataset""" - - layer: int - eval_result: EvalResult - - @staticmethod - def csv_columns() -> list[str]: - return ["layer", "acc", "cal_acc", "auroc", "ece"] - - def to_csv_line(self) -> list[str]: - items = [ - self.layer, - self.eval_result.acc, - self.eval_result.cal_acc, - self.eval_result.auroc, - self.eval_result.ece, - ] - return [ - f"{item:.4f}" if isinstance(item, float) else str(item) for item in items - ] diff --git a/elk/extraction/__init__.py b/elk/extraction/__init__.py index 8b00ea3d..fac876fa 100644 --- a/elk/extraction/__init__.py +++ b/elk/extraction/__init__.py @@ -1,6 +1,6 @@ from .balanced_sampler import BalancedSampler, FewShotSampler -from .extraction import Extract, extract_hiddens, extract -from .generator import _GeneratorConfig, _GeneratorBuilder +from .extraction import Extract, extract, extract_hiddens +from .generator import _GeneratorBuilder, _GeneratorConfig from .prompt_loading import PromptConfig, load_prompts __all__ = [ diff --git a/elk/extraction/balanced_sampler.py b/elk/extraction/balanced_sampler.py index e472684b..2ea4815e 100644 --- a/elk/extraction/balanced_sampler.py +++ b/elk/extraction/balanced_sampler.py @@ -1,12 +1,14 @@ -from ..math_util import stochastic_round_constrained -from ..utils import infer_label_column -from ..utils.typing import assert_type from collections import deque -from datasets import IterableDataset, Features from itertools import cycle from random import Random +from typing import Iterable, Iterator, Optional + +from datasets import Features, IterableDataset from torch.utils.data import IterableDataset as TorchIterableDataset -from typing import Iterator, Optional, Iterable + +from ..math_util import stochastic_round_constrained +from ..utils import infer_label_column +from ..utils.typing import 
assert_type class BalancedSampler(TorchIterableDataset): diff --git a/elk/extraction/extraction.py b/elk/extraction/extraction.py index 0546b22b..bc922242 100644 --- a/elk/extraction/extraction.py +++ b/elk/extraction/extraction.py @@ -1,5 +1,11 @@ """Functions for extracting the hidden states of a model.""" +import logging +import os from dataclasses import InitVar, dataclass +from itertools import islice +from typing import Iterable, Literal, Optional, Union + +import torch from datasets import ( Array3D, ClassLabel, @@ -11,18 +17,12 @@ Value, get_dataset_config_info, ) -from itertools import islice from simple_parsing import Serializable, field from torch import Tensor from transformers import AutoConfig, AutoTokenizer from transformers.modeling_outputs import Seq2SeqLMOutput -from typing import Iterable, Literal, Optional, Union -import logging -import os -import torch # import torch.nn.functional as F - from ..utils import ( assert_type, convert_span, diff --git a/elk/extraction/generator.py b/elk/extraction/generator.py index fbf10848..fb4d03bc 100644 --- a/elk/extraction/generator.py +++ b/elk/extraction/generator.py @@ -1,5 +1,5 @@ from dataclasses import dataclass, field -from typing import Callable, Optional, Any, Dict +from typing import Any, Callable, Dict, Optional import datasets from datasets.splits import NamedSplit diff --git a/elk/extraction/prompt_loading.py b/elk/extraction/prompt_loading.py index 02471a9d..a494be5a 100644 --- a/elk/extraction/prompt_loading.py +++ b/elk/extraction/prompt_loading.py @@ -1,3 +1,15 @@ +from dataclasses import dataclass +from random import Random +from typing import Any, Iterator, Literal, Optional + +from datasets import ( + Dataset, + Features, + load_dataset, +) +from datasets.distributed import split_dataset_by_node +from simple_parsing.helpers import Serializable, field + from ..promptsource import DatasetTemplates from ..utils import ( assert_type, @@ -7,16 +19,6 @@ select_train_val_splits, ) from .balanced_sampler import FewShotSampler -from dataclasses import dataclass -from datasets import ( - load_dataset, - Dataset, - Features, -) -from datasets.distributed import split_dataset_by_node -from random import Random -from simple_parsing.helpers import field, Serializable -from typing import Any, Iterator, Literal, Optional @dataclass diff --git a/elk/files.py b/elk/files.py index 47225876..84226706 100644 --- a/elk/files.py +++ b/elk/files.py @@ -1,13 +1,13 @@ """Helper functions for dealing with files.""" -from pathlib import Path import json import os import random +from pathlib import Path from typing import Optional -from simple_parsing import Serializable import yaml +from simple_parsing import Serializable def elk_reporter_dir() -> Path: diff --git a/elk/logging.py b/elk/logging.py index 19bb12c3..706055bd 100644 --- a/elk/logging.py +++ b/elk/logging.py @@ -1,4 +1,5 @@ import logging + from .utils import select_train_val_splits diff --git a/elk/math_util.py b/elk/math_util.py index 7b5cd38c..4ae9daee 100644 --- a/elk/math_util.py +++ b/elk/math_util.py @@ -1,7 +1,8 @@ -from torch import Tensor import math import random + import torch +from torch import Tensor @torch.jit.script diff --git a/elk/parsing.py b/elk/parsing.py index c40a8473..1daded78 100644 --- a/elk/parsing.py +++ b/elk/parsing.py @@ -1,4 +1,5 @@ import re + from .training.losses import LOSSES diff --git a/elk/promptsource/templates.py b/elk/promptsource/templates.py index a19526e1..8f1828f8 100644 --- a/elk/promptsource/templates.py +++ 
b/elk/promptsource/templates.py @@ -1,14 +1,14 @@ -from collections import Counter, defaultdict -from jinja2 import BaseLoader, Environment, meta -from pathlib import Path -from shutil import rmtree -from typing import Optional import logging import os import random import uuid -import yaml +from collections import Counter, defaultdict +from pathlib import Path +from shutil import rmtree +from typing import Optional +import yaml +from jinja2 import BaseLoader, Environment, meta # Truncation of jinja template variables # 1710 = 300 words x 4.7 avg characters per word + 300 spaces diff --git a/elk/run.py b/elk/run.py index 173750fb..ddcfac80 100644 --- a/elk/run.py +++ b/elk/run.py @@ -5,29 +5,29 @@ from pathlib import Path from typing import ( TYPE_CHECKING, + Callable, Optional, Union, - Callable, - Iterator, ) import numpy as np +import pandas as pd import torch import torch.multiprocessing as mp from datasets import DatasetDict from torch import Tensor from tqdm import tqdm -from elk.extraction.extraction import extract -from elk.files import create_output_directory, save_config, save_meta -from elk.training.preprocessing import normalize -from elk.utils.csv import write_iterator_to_file, Log -from elk.utils.data_utils import get_layers, select_train_val_splits -from elk.utils.typing import assert_type, int16_to_float32 +from .extraction import extract +from .files import create_output_directory, save_config, save_meta +from .logging import save_debug_log +from .training.preprocessing import normalize +from .utils import assert_type, int16_to_float32 +from .utils.data_utils import get_layers, select_train_val_splits if TYPE_CHECKING: - from elk.evaluation.evaluate import Eval - from elk.training.train import Elicit + from .evaluation.evaluate import Eval + from .training.train import Elicit @dataclass @@ -103,10 +103,8 @@ def concatenate(self, layers): def apply_to_layers( self, - func: Callable[[int], Log], + func: Callable[[int], pd.Series], num_devices: int, - to_csv_line: Callable[[Log], list[str]], - csv_columns: list[str], ): """Apply a function to each layer of the dataset in parallel and writes the results to a CSV file. @@ -115,10 +113,7 @@ def apply_to_layers( func: The function to apply to each layer. The int is the index of the layer. num_devices: The number of devices to use. - to_csv_line: A function that converts a Log to a list of strings. - This has to be injected in because the Run class does not know - the extra options e.g. skip_baseline to apply to function. - csv_columns: The columns of the CSV file.""" + """ self.out_dir = assert_type(Path, self.out_dir) layers: list[int] = get_layers(self.dataset) @@ -129,15 +124,14 @@ def apply_to_layers( # Should we write to different CSV files for elicit vs eval? 
with mp.Pool(num_devices) as pool, open(self.out_dir / "eval.csv", "w") as f: mapper = pool.imap_unordered if num_devices > 1 else map - iterator: Iterator[Log] = tqdm( # type: ignore - mapper(func, layers), total=len(layers) - ) - write_iterator_to_file( - iterator=iterator, - file=f, - debug=self.cfg.debug, - dataset=self.dataset, - out_dir=self.out_dir, - csv_columns=csv_columns, - to_csv_line=to_csv_line, - ) + row_buf = [] + + try: + for row in tqdm(mapper(func, layers), total=len(layers)): + row_buf.append(row) + finally: + # Make sure the CSV is written even if we crash or get interrupted + df = pd.DataFrame(row_buf).sort_values(by="layer") + df.to_csv(f, index=False) + if self.cfg.debug: + save_debug_log(self.dataset, self.out_dir) diff --git a/elk/training/__init__.py b/elk/training/__init__.py index a9d76f05..41264179 100644 --- a/elk/training/__init__.py +++ b/elk/training/__init__.py @@ -2,7 +2,6 @@ from .eigen_reporter import EigenReporter, EigenReporterConfig from .reporter import OptimConfig, Reporter, ReporterConfig - __all__ = [ "Reporter", "ReporterConfig", diff --git a/elk/training/ccs_reporter.py b/elk/training/ccs_reporter.py index 35ee67ec..258f2bc4 100644 --- a/elk/training/ccs_reporter.py +++ b/elk/training/ccs_reporter.py @@ -1,17 +1,19 @@ """An ELK reporter network.""" -from ..parsing import parse_loss -from ..utils.typing import assert_type -from .losses import LOSSES -from .reporter import Reporter, ReporterConfig +import math from copy import deepcopy from dataclasses import dataclass, field -from torch import Tensor -from torch.nn.functional import binary_cross_entropy as bce -from typing import cast, Literal, Optional -import math +from typing import Literal, Optional, cast + import torch import torch.nn as nn +from torch import Tensor +from torch.nn.functional import binary_cross_entropy as bce + +from ..parsing import parse_loss +from ..utils.typing import assert_type +from .losses import LOSSES +from .reporter import Reporter, ReporterConfig @dataclass diff --git a/elk/training/classifier.py b/elk/training/classifier.py index cab88dd9..726cae7a 100644 --- a/elk/training/classifier.py +++ b/elk/training/classifier.py @@ -1,11 +1,14 @@ from dataclasses import dataclass +from typing import Optional + +import torch +from torch import Tensor from torch.nn.functional import ( binary_cross_entropy_with_logits as bce_with_logits, +) +from torch.nn.functional import ( cross_entropy, ) -from torch import Tensor -from typing import Optional -import torch @dataclass diff --git a/elk/training/eigen_reporter.py b/elk/training/eigen_reporter.py index 77582f9f..62e0a543 100644 --- a/elk/training/eigen_reporter.py +++ b/elk/training/eigen_reporter.py @@ -1,12 +1,14 @@ """An ELK reporter network.""" -from ..math_util import cov_mean_fused -from ..eigsh import lanczos_eigsh -from .reporter import Reporter, ReporterConfig from dataclasses import dataclass -from torch import nn, optim, Tensor from typing import Optional + import torch +from torch import Tensor, nn, optim + +from ..eigsh import lanczos_eigsh +from ..math_util import cov_mean_fused +from .reporter import Reporter, ReporterConfig @dataclass diff --git a/elk/training/losses.py b/elk/training/losses.py index d91c1e79..8d7e287b 100644 --- a/elk/training/losses.py +++ b/elk/training/losses.py @@ -1,10 +1,11 @@ """Loss functions for training reporters.""" -from torch import Tensor -import torch import warnings from inspect import signature +import torch +from torch import Tensor + LOSSES = dict() # Registry of loss 
functions diff --git a/elk/training/preprocessing.py b/elk/training/preprocessing.py index 2802bb9e..6081dcbb 100644 --- a/elk/training/preprocessing.py +++ b/elk/training/preprocessing.py @@ -1,6 +1,7 @@ """Preprocessing functions for training.""" from typing import Literal + import torch diff --git a/elk/training/reporter.py b/elk/training/reporter.py index ea8a4406..cf3e21e2 100644 --- a/elk/training/reporter.py +++ b/elk/training/reporter.py @@ -1,16 +1,18 @@ """An ELK reporter network.""" -from ..calibration import CalibrationError -from .classifier import Classifier from abc import ABC, abstractmethod from dataclasses import dataclass from pathlib import Path -from simple_parsing.helpers import Serializable -from sklearn.metrics import roc_auc_score -from torch import Tensor from typing import Literal, NamedTuple, Optional, Union + import torch import torch.nn as nn +from simple_parsing.helpers import Serializable +from sklearn.metrics import roc_auc_score +from torch import Tensor + +from ..calibration import CalibrationError +from .classifier import Classifier class EvalResult(NamedTuple): diff --git a/elk/training/train.py b/elk/training/train.py index dc350611..03217d98 100644 --- a/elk/training/train.py +++ b/elk/training/train.py @@ -5,8 +5,9 @@ from dataclasses import dataclass from functools import partial from pathlib import Path -from typing import Literal, Optional, Callable +from typing import Callable, Literal, Optional +import pandas as pd import torch from simple_parsing import Serializable, field, subgroups from sklearn.metrics import accuracy_score, roc_auc_score @@ -16,12 +17,11 @@ from elk.run import Run from elk.utils.typing import assert_type +from ..utils import select_usable_devices from .ccs_reporter import CcsReporter, CcsReporterConfig from .classifier import Classifier from .eigen_reporter import EigenReporter, EigenReporterConfig from .reporter import OptimConfig, Reporter, ReporterConfig -from .train_log import ElicitLog -from ..utils import select_usable_devices @dataclass @@ -118,7 +118,7 @@ def train_reporter( layer: int, devices: list[str], world_size: int = 1, - ) -> ElicitLog: + ) -> pd.Series: """Train a single reporter on a single layer.""" self.make_reproducible(seed=self.cfg.net.seed + layer) @@ -152,13 +152,15 @@ def train_reporter( val_lm_auroc = None val_lm_acc = None - stats = ElicitLog( - layer=layer, - pseudo_auroc=pseudo_auroc, - train_loss=train_loss, - eval_result=val_result, - lm_auroc=val_lm_auroc, - lm_acc=val_lm_acc, + row = pd.Series( + { + "layer": layer, + "pseudo_auroc": pseudo_auroc, + "train_loss": train_loss, + "eval_result": val_result, + "lm_auroc": val_lm_auroc, + "lm_acc": val_lm_acc, + } ) if not self.cfg.skip_baseline: @@ -171,14 +173,14 @@ def train_reporter( val_gt, device, ) - stats.lr_auroc = float(lr_auroc) - stats.lr_acc = float(lr_acc) + row["lr_auroc"] = lr_auroc + row["lr_acc"] = lr_acc self.save_baseline(lr_dir, layer, lr_model) with open(reporter_dir / f"layer_{layer}.pt", "wb") as file: torch.save(reporter, file) - return stats + return row def get_pseudo_auroc( self, layer: int, x0: Tensor, x1: Tensor, val_x0: Tensor, val_x1: Tensor @@ -202,14 +204,7 @@ def train(self): """Train a reporter on each layer of the network.""" devices = select_usable_devices(self.cfg.num_gpus) num_devices = len(devices) - func: Callable[[int], ElicitLog] = partial( + func: Callable[[int], pd.Series] = partial( self.train_reporter, devices=devices, world_size=num_devices ) - self.apply_to_layers( - func=func, - 
num_devices=num_devices, - to_csv_line=lambda item: item.to_csv_line( - skip_baseline=self.cfg.skip_baseline - ), - csv_columns=ElicitLog.csv_columns(self.cfg.skip_baseline), - ) + self.apply_to_layers(func=func, num_devices=num_devices) diff --git a/elk/training/train_log.py b/elk/training/train_log.py deleted file mode 100644 index 34c0d250..00000000 --- a/elk/training/train_log.py +++ /dev/null @@ -1,60 +0,0 @@ -from .reporter import EvalResult -from dataclasses import dataclass -from typing import Optional - - -@dataclass -class ElicitLog: - """The result of running elicit on a layer of a dataset""" - - layer: int - pseudo_auroc: float - train_loss: float - eval_result: EvalResult - - # Only available when the LM is autoregressive - lm_auroc: Optional[float] = None - lm_acc: Optional[float] = None - - # Only available if reporting baseline - lr_auroc: Optional[float] = None - # Only available if reporting baseline - lr_acc: Optional[float] = None - - @staticmethod - def csv_columns(skip_baseline: bool) -> list[str]: - """Return a CSV header with the column names.""" - cols = [ - "layer", - "pseudo_auroc", - "train_loss", - "acc", - "cal_acc", - "auroc", - "ece", - "lm_auroc", - "lm_acc", - ] - if not skip_baseline: - cols += ["lr_auroc", "lr_acc"] - return cols - - def to_csv_line(self, skip_baseline: bool) -> list[str]: - """Return a CSV line with the evaluation results.""" - items = [ - self.layer, - self.pseudo_auroc, - self.train_loss, - self.eval_result.acc, - self.eval_result.cal_acc, - self.eval_result.auroc, - self.eval_result.ece, - ] - if self.lm_auroc is not None and self.lm_acc is not None: - items += [self.lm_auroc, self.lm_acc] - if not skip_baseline: - items += [self.lr_auroc, self.lr_acc] - - return [ - f"{item:.4f}" if isinstance(item, float) else str(item) for item in items - ] diff --git a/elk/utils/__init__.py b/elk/utils/__init__.py index 9a5f9082..1400a98d 100644 --- a/elk/utils/__init__.py +++ b/elk/utils/__init__.py @@ -6,7 +6,6 @@ infer_num_classes, select_train_val_splits, ) - from .gpu_utils import select_usable_devices from .hf_utils import instantiate_model, is_autoregressive from .tree_utils import pytree_map diff --git a/elk/utils/csv.py b/elk/utils/csv.py deleted file mode 100644 index 9828e269..00000000 --- a/elk/utils/csv.py +++ /dev/null @@ -1,37 +0,0 @@ -from ..evaluation.evaluate_log import EvalLog -from ..logging import save_debug_log -from ..training.train_log import ElicitLog -from datasets import DatasetDict -from pathlib import Path -from typing import Iterator, Callable, TextIO, TypeVar -import csv - -"""A generic log type that contains a layer field -The layer field is used to sort the logs by layer.""" -Log = TypeVar("Log", EvalLog, ElicitLog) - - -def write_iterator_to_file( - iterator: Iterator[Log], - csv_columns: list[str], - to_csv_line: Callable[[Log], list[str]], - file: TextIO, - debug: bool, - dataset: DatasetDict, - out_dir: Path, -) -> None: - row_buf = [] - writer = csv.writer(file) - # write a single line - writer.writerow(csv_columns) - try: - for row in iterator: - row_buf.append(row) - finally: - # Make sure the CSV is written even if we crash or get interrupted - sorted_by_layer = sorted(row_buf, key=lambda x: x.layer) - for row in sorted_by_layer: - row = to_csv_line(row) - writer.writerow(row) - if debug: - save_debug_log(dataset, out_dir) diff --git a/elk/utils/data_utils.py b/elk/utils/data_utils.py index 848f0d4e..f8129145 100644 --- a/elk/utils/data_utils.py +++ b/elk/utils/data_utils.py @@ -1,6 +1,9 @@ -from .typing 
import assert_type -from ..promptsource.templates import Template +import copy from bisect import bisect_left, bisect_right +from operator import itemgetter +from random import Random +from typing import Any, Iterable, List + from datasets import ( ClassLabel, DatasetDict, @@ -8,10 +11,9 @@ Split, Value, ) -from operator import itemgetter -from random import Random -from typing import Iterable, List, Any -import copy + +from ..promptsource.templates import Template +from .typing import assert_type def convert_span( diff --git a/elk/utils/gpu_utils.py b/elk/utils/gpu_utils.py index 9074bb8a..f305c6c1 100644 --- a/elk/utils/gpu_utils.py +++ b/elk/utils/gpu_utils.py @@ -1,12 +1,14 @@ """Utilities that use PyNVML to get GPU usage info, and select GPUs accordingly.""" -from .typing import assert_type -from typing import Optional import os +import time +import warnings +from typing import Optional + import pynvml import torch -import warnings -import time + +from .typing import assert_type def select_usable_devices( diff --git a/elk/utils/hf_utils.py b/elk/utils/hf_utils.py index 5ff82902..4c3ab331 100644 --- a/elk/utils/hf_utils.py +++ b/elk/utils/hf_utils.py @@ -1,7 +1,7 @@ -from .typing import assert_type -from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel import transformers +from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel +from .typing import assert_type # Ordered by preference _AUTOREGRESSIVE_SUFFIXES = [ diff --git a/elk/utils/tree_utils.py b/elk/utils/tree_utils.py index 9a5b0fb1..f084874f 100644 --- a/elk/utils/tree_utils.py +++ b/elk/utils/tree_utils.py @@ -6,7 +6,6 @@ from typing import Callable, Mapping, TypeVar - TreeType = TypeVar("TreeType") diff --git a/elk/utils/typing.py b/elk/utils/typing.py index ea552c83..1d38040e 100644 --- a/elk/utils/typing.py +++ b/elk/utils/typing.py @@ -1,8 +1,7 @@ -from typing import cast, Any, Type, TypeVar +from typing import Any, Type, TypeVar, cast import torch - T = TypeVar("T") diff --git a/pyproject.toml b/pyproject.toml index e22a3513..0ae833b1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -57,9 +57,9 @@ testpaths = ["tests"] include = ["elk*"] [tool.ruff] -# Enable pycodestyle (`E`) and Pyflakes (`F`) codes +# Enable pycodestyle (`E`), Pyflakes (`F`), isort (`I`) codes # See https://beta.ruff.rs/docs/rules/ for more possible rules -select = ["E", "F"] +select = ["E", "F", "I"] # Same as Black. line-length = 88 # Avoid automatically removing unused imports in __init__.py files. 
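For reference, the import reordering applied across these files is the grouping that ruff's isort (`I`) rules enforce: standard-library imports first, then third-party packages, then first-party modules, with each group separated by a blank line and sorted alphabetically. A minimal illustrative sketch of that layout (the module names below are examples, not taken from any single file in this diff):

import os
from pathlib import Path

import torch
from transformers import AutoTokenizer

from ..utils import assert_type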
diff --git a/tests/test_classifier.py b/tests/test_classifier.py index 85c3c57d..bdc9023d 100644 --- a/tests/test_classifier.py +++ b/tests/test_classifier.py @@ -1,7 +1,8 @@ -from elk.training.classifier import Classifier +import torch from sklearn.datasets import make_classification from sklearn.linear_model import LogisticRegression -import torch + +from elk.training.classifier import Classifier @torch.no_grad() diff --git a/tests/test_convert_span.py b/tests/test_convert_span.py index 10e8c583..2a652f0a 100644 --- a/tests/test_convert_span.py +++ b/tests/test_convert_span.py @@ -1,7 +1,9 @@ -from elk.utils import convert_span -from hypothesis import given, strategies as st -from transformers import AutoTokenizer import pytest +from hypothesis import given +from hypothesis import strategies as st +from transformers import AutoTokenizer + +from elk.utils import convert_span # Define a fixture with session scope that initializes the tokenizer diff --git a/tests/test_eigen_reporter.py b/tests/test_eigen_reporter.py index 9d548f67..58dd6c13 100644 --- a/tests/test_eigen_reporter.py +++ b/tests/test_eigen_reporter.py @@ -1,6 +1,7 @@ +import torch + from elk.math_util import batch_cov, cov_mean_fused from elk.training import EigenReporter, EigenReporterConfig -import torch def test_eigen_reporter(): diff --git a/tests/test_eigsh.py b/tests/test_eigsh.py index b208dc23..c34ce2b3 100644 --- a/tests/test_eigsh.py +++ b/tests/test_eigsh.py @@ -1,8 +1,9 @@ -from elk.eigsh import lanczos_eigsh -from scipy.sparse.linalg import eigsh import numpy as np import pytest import torch +from scipy.sparse.linalg import eigsh + +from elk.eigsh import lanczos_eigsh @pytest.mark.parametrize("n", [20, 40]) diff --git a/tests/test_load_prompts.py b/tests/test_load_prompts.py index c9a45f03..a5d238fd 100644 --- a/tests/test_load_prompts.py +++ b/tests/test_load_prompts.py @@ -1,9 +1,11 @@ -from elk.extraction import load_prompts, PromptConfig -from elk.promptsource.templates import DatasetTemplates from itertools import cycle, islice from typing import Literal + import pytest +from elk.extraction import PromptConfig, load_prompts +from elk.promptsource.templates import DatasetTemplates + @pytest.mark.filterwarnings("ignore:Unable to find a decoding function") def test_load_prompts(): diff --git a/tests/test_math.py b/tests/test_math.py index c6db4e93..ee81914e 100644 --- a/tests/test_math.py +++ b/tests/test_math.py @@ -1,9 +1,12 @@ -from elk.math_util import batch_cov, cov_mean_fused, stochastic_round_constrained -from hypothesis import given, strategies as st -from random import Random import math +from random import Random + import numpy as np import torch +from hypothesis import given +from hypothesis import strategies as st + +from elk.math_util import batch_cov, cov_mean_fused, stochastic_round_constrained def test_cov_mean_fused(): diff --git a/tests/test_samplers.py b/tests/test_samplers.py index cb5e1225..87c1ac0c 100644 --- a/tests/test_samplers.py +++ b/tests/test_samplers.py @@ -1,10 +1,12 @@ from collections import Counter -from datasets import load_dataset, IterableDataset -from elk.extraction import FewShotSampler, BalancedSampler -from elk.utils import assert_type, infer_label_column from itertools import islice from random import Random +from datasets import IterableDataset, load_dataset + +from elk.extraction import BalancedSampler, FewShotSampler +from elk.utils import assert_type, infer_label_column + def test_output_batches_are_balanced(): # Load an example dataset for testing diff --git 
a/tests/test_write_iterator_to_file.py b/tests/test_write_iterator_to_file.py deleted file mode 100644 index a40d8e2f..00000000 --- a/tests/test_write_iterator_to_file.py +++ /dev/null @@ -1,186 +0,0 @@ -import csv -import time -from pathlib import Path -from typing import Iterator -import multiprocessing as mp - - -from datasets import DatasetDict - -from elk.utils.csv import write_iterator_to_file -from elk.training.reporter import EvalResult -from elk.training.train_log import ElicitLog - - -def test_write_iterator_to_file(tmp_path: Path): - items: list[ElicitLog] = [ - ElicitLog( - layer=1, - train_loss=1.0, - eval_result=EvalResult( - acc=0.0, - ece=0.0, - cal_acc=0.0, - auroc=0.0, - ), - pseudo_auroc=0.0, - ) - ] - iterator = iter(items) - csv_columns = ElicitLog.csv_columns(skip_baseline=True) - - def to_csv_line(x): - return x.to_csv_line(skip_baseline=True) - - # Write the CSV file - with open(tmp_path / "test.csv", "w") as f: - write_iterator_to_file( - iterator=iterator, - file=f, - debug=False, - dataset=DatasetDict(), - out_dir=tmp_path, - csv_columns=csv_columns, - to_csv_line=to_csv_line, - ) - # Read the CSV file - with open(tmp_path / "test.csv", "r") as f: - reader = csv.reader(f) - # read all to lines - lines = list(reader) - # assert that the first line is the header - assert lines[0] == csv_columns - # assert that the second line is the data - assert lines[1] == to_csv_line(items[0]) - - -def test_write_iterator_to_file_crash(tmp_path: Path): - first_layer_log = ElicitLog( - layer=1, - train_loss=1.0, - eval_result=EvalResult( - acc=0.0, - ece=0.0, - cal_acc=0.0, - auroc=0.0, - ), - pseudo_auroc=0.0, - ) - - second_layer_log = ElicitLog( - layer=2, - train_loss=1.0, - eval_result=EvalResult( - acc=0.0, - ece=0.0, - cal_acc=0.0, - auroc=0.0, - ), - pseudo_auroc=0.0, - ) - - def iterator() -> Iterator[ElicitLog]: - for i in range(3): - if i == 0: - yield first_layer_log - elif i == 1: - yield second_layer_log - # on the third iteration, raise an ValueError - # We should still be able to write the first two layers - if i == 2: - raise ValueError() - - csv_columns = ElicitLog.csv_columns(skip_baseline=True) - - def to_csv_line(x): - return x.to_csv_line(skip_baseline=True) - - # Write the CSV file - try: - with open(tmp_path / "test.csv", "w") as f: - write_iterator_to_file( - iterator=iterator(), - file=f, - debug=False, - dataset=DatasetDict(), - out_dir=tmp_path, - csv_columns=csv_columns, - to_csv_line=to_csv_line, - ) - except ValueError: - # We expect a ValueError to be raised, - # and don't want to fail the test - pass - - # Read the CSV file - with open(tmp_path / "test.csv", "r") as f: - reader = csv.reader(f) - # read all to lines - lines = list(reader) - # assert that the first line is the header - assert lines[0] == csv_columns - # assert that the second line has the first layer - assert lines[1] == to_csv_line(first_layer_log) - # assert that the third line has the second layer - assert lines[2] == to_csv_line(second_layer_log) - - -def log_function(layer: int) -> ElicitLog: - """ - raise an error on the second layer - This is a top-level function so that it can be pickled - for multiprocessing - """ - if layer == 2: - # let the other processes finish first - time.sleep(3) - # crash the process - raise ValueError() - return ElicitLog( - layer=layer, - train_loss=1.0, - eval_result=EvalResult( - acc=0.0, - ece=0.0, - cal_acc=0.0, - auroc=0.0, - ), - pseudo_auroc=0.0, - ) - - -def test_write_iterator_crash_multiprocessing(tmp_path: Path): - processes = 3 - 
csv_columns = ElicitLog.csv_columns(skip_baseline=True) - - def to_csv_line(x): - return x.to_csv_line(skip_baseline=True) - - try: - with mp.Pool(processes) as pool, open(tmp_path / "eval.csv", "w") as f: - layers = [1, 2, 3] - iterator = pool.imap_unordered(log_function, layers) - write_iterator_to_file( - iterator=iterator, - file=f, - debug=False, - dataset=DatasetDict(), - out_dir=tmp_path, - csv_columns=csv_columns, - to_csv_line=to_csv_line, - ) - except ValueError: - # We expect a ValueError to be raised, - # and don't want to fail the test - pass - # We should still have results for layer 1, 3, even though layer 2 failed - with open(tmp_path / "eval.csv", "r") as f: - reader = csv.reader(f) - # read all to lines - lines = list(reader) - # assert that the first line is the header - assert lines[0] == csv_columns - # assert that the second line has the first layer - assert lines[1] == to_csv_line(log_function(1)) - # assert that the third line has the third layer - assert lines[2] == to_csv_line(log_function(3)) From 5bf63f4f2a374f1745ff451d8081dccc5aceda5e Mon Sep 17 00:00:00 2001 From: Nora Belrose Date: Wed, 5 Apr 2023 09:10:22 +0000 Subject: [PATCH 06/43] Bug fixes --- elk/extraction/extraction.py | 12 ++++++++---- elk/run.py | 2 +- elk/training/train.py | 2 +- 3 files changed, 10 insertions(+), 6 deletions(-) diff --git a/elk/extraction/extraction.py b/elk/extraction/extraction.py index bc922242..f3a69170 100644 --- a/elk/extraction/extraction.py +++ b/elk/extraction/extraction.py @@ -3,7 +3,7 @@ import os from dataclasses import InitVar, dataclass from itertools import islice -from typing import Iterable, Literal, Optional, Union +from typing import Any, Iterable, Literal, Optional, Union import torch from datasets import ( @@ -111,6 +111,7 @@ def extract_hiddens( tokenizer = AutoTokenizer.from_pretrained( cfg.model, truncation_side="left", verbose=False ) + has_lm_preds = is_autoregressive(model.config) # Iterating over questions layer_indices = cfg.layers or tuple(range(model.config.num_hidden_layers)) @@ -221,14 +222,17 @@ def extract_hiddens( text_inputs.append(variant_inputs) - yield dict( + out_record: dict[str, Any] = dict( label=example["label"], - # We only need the probability of the positive example since this is binary - model_preds=lm_preds.softmax(dim=-1)[..., 1], variant_ids=example["template_names"], text_inputs=text_inputs, **hidden_dict, ) + if has_lm_preds: + # We only need the probability of the positive example since this is binary + out_record["model_preds"] = lm_preds.softmax(dim=-1)[..., 1] + + yield out_record # Dataset.from_generator wraps all the arguments in lists, so we unpack them here diff --git a/elk/run.py b/elk/run.py index ddcfac80..533304eb 100644 --- a/elk/run.py +++ b/elk/run.py @@ -89,7 +89,7 @@ def prepare_data( val_x0, val_x1 = val_h.unbind(dim=-2) with self.dataset.formatted_as("numpy"): - val_lm_preds = assert_type(np.ndarray, val["model_preds"]) + val_lm_preds = val["model_preds"] if "model_preds" in val else None return x0, x1, val_x0, val_x1, train_labels, val_labels, val_lm_preds diff --git a/elk/training/train.py b/elk/training/train.py index 03217d98..5b43867f 100644 --- a/elk/training/train.py +++ b/elk/training/train.py @@ -157,7 +157,7 @@ def train_reporter( "layer": layer, "pseudo_auroc": pseudo_auroc, "train_loss": train_loss, - "eval_result": val_result, + **val_result._asdict(), "lm_auroc": val_lm_auroc, "lm_acc": val_lm_acc, } From b89e23c29bd112edacaf2d058c944d330e477845 Mon Sep 17 00:00:00 2001 From: Nora Belrose Date: 
Wed, 5 Apr 2023 09:25:56 +0000 Subject: [PATCH 07/43] Remove test_log_csv_elements --- tests/test_log_csv_elements.py | 44 ---------------------------------- 1 file changed, 44 deletions(-) delete mode 100644 tests/test_log_csv_elements.py diff --git a/tests/test_log_csv_elements.py b/tests/test_log_csv_elements.py deleted file mode 100644 index 59cb4352..00000000 --- a/tests/test_log_csv_elements.py +++ /dev/null @@ -1,44 +0,0 @@ -from elk.evaluation.evaluate_log import EvalLog -from elk.training.reporter import EvalResult -from elk.training.train_log import ElicitLog - - -def test_eval_log_csv_number_elements(): - log = EvalLog( - layer=1, - eval_result=EvalResult( - acc=1.0, - cal_acc=1.0, - auroc=1.0, - ece=1.0, - ), - ) - csv_columns = EvalLog.csv_columns() - csv_values = log.to_csv_line() - assert len(csv_columns) == len( - csv_values - ), "Number of columns and values should be the same" - - -def test_elicit_log_csv_number_elements(): - log = ElicitLog( - layer=1, - train_loss=1.0, - eval_result=EvalResult( - acc=0.0, - ece=0.0, - cal_acc=0.0, - auroc=0.0, - ), - pseudo_auroc=0.0, - ) - csv_columns = ElicitLog.csv_columns(skip_baseline=True) - csv_values = log.to_csv_line(skip_baseline=True) - assert len(csv_columns) == len( - csv_values - ), "Number of columns and values should be the same" - csv_columns_not_skipped = ElicitLog.csv_columns(skip_baseline=False) - csv_values_not_skipped = log.to_csv_line(skip_baseline=False) - assert len(csv_columns_not_skipped) == len( - csv_values_not_skipped - ), "Number of columns and values should be the same" From 9aef842bc6159813a298a7c01be0ad85f18225f3 Mon Sep 17 00:00:00 2001 From: Nora Belrose Date: Wed, 5 Apr 2023 09:36:02 +0000 Subject: [PATCH 08/43] Remove Python 3.9 support --- .github/workflows/cpu_ci.yml | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/cpu_ci.yml b/.github/workflows/cpu_ci.yml index 4cac3fb9..b9a21522 100644 --- a/.github/workflows/cpu_ci.yml +++ b/.github/workflows/cpu_ci.yml @@ -6,7 +6,7 @@ jobs: run-tests: strategy: matrix: - python-versions: [ 3.9, "3.10", "3.11" ] + python-versions: [ "3.10", "3.11" ] os: [ ubuntu-latest, macos-latest ] runs-on: ${{ matrix.os }} steps: diff --git a/pyproject.toml b/pyproject.toml index 49223931..f1de9148 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,7 +6,7 @@ build-backend = "setuptools.build_meta" name = "eleuther-elk" description = "Keeping language models honest by directly eliciting knowledge encoded in their activations" readme = "README.md" -requires-python = ">=3.9" +requires-python = ">=3.10" keywords = ["nlp", "interpretability", "language-models", "explainable-ai"] license = {text = "MIT License"} dependencies = [ From 0851d4f5a9da72e1774086c2126f6b84adb18130 Mon Sep 17 00:00:00 2001 From: Nora Belrose Date: Wed, 5 Apr 2023 09:50:06 +0000 Subject: [PATCH 09/43] Add Pandas to pyproject.toml --- elk/__main__.py | 3 +-- elk/extraction/extraction.py | 4 ++-- elk/training/reporter.py | 6 +++--- pyproject.toml | 2 ++ 4 files changed, 8 insertions(+), 7 deletions(-) diff --git a/elk/__main__.py b/elk/__main__.py index 4d84906d..5304f5aa 100644 --- a/elk/__main__.py +++ b/elk/__main__.py @@ -1,7 +1,6 @@ """Main entry point for `elk`.""" from dataclasses import dataclass -from typing import Union from simple_parsing import ArgumentParser @@ -14,7 +13,7 @@ class Command: """Some top-level command""" - command: Union[Elicit, Eval, Extract] + command: Elicit | Eval | Extract def execute(self): return 
self.command.execute() diff --git a/elk/extraction/extraction.py b/elk/extraction/extraction.py index f3a69170..35cd5600 100644 --- a/elk/extraction/extraction.py +++ b/elk/extraction/extraction.py @@ -3,7 +3,7 @@ import os from dataclasses import InitVar, dataclass from itertools import islice -from typing import Any, Iterable, Literal, Optional, Union +from typing import Any, Iterable, Literal, Optional import torch from datasets import ( @@ -82,7 +82,7 @@ def execute(self): def extract_hiddens( cfg: "Extract", *, - device: Union[str, torch.device] = "cpu", + device: str | torch.device = "cpu", split_type: Literal["train", "val"] = "train", rank: int = 0, world_size: int = 1, diff --git a/elk/training/reporter.py b/elk/training/reporter.py index cf3e21e2..9cdfb145 100644 --- a/elk/training/reporter.py +++ b/elk/training/reporter.py @@ -3,7 +3,7 @@ from abc import ABC, abstractmethod from dataclasses import dataclass from pathlib import Path -from typing import Literal, NamedTuple, Optional, Union +from typing import Literal, NamedTuple, Optional import torch import torch.nn as nn @@ -151,11 +151,11 @@ def update(self, x_pos: Tensor, x_neg: Tensor) -> None: # TODO: These methods will do something fancier in the future @classmethod - def load(cls, path: Union[Path, str]): + def load(cls, path: Path | str): """Load a reporter from a file.""" return torch.load(path) - def save(self, path: Union[Path, str]): + def save(self, path: Path | str): # TODO: Save separate JSON and PT files for the reporter. torch.save(self, path) diff --git a/pyproject.toml b/pyproject.toml index f1de9148..16edd58e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,6 +16,8 @@ dependencies = [ "numpy>=1.20.0", # This version is old, but it's needed for certain HF tokenizers to work. "protobuf==3.20.*", + # For logging. Indirectly required by datasets, but just to be safe we specify it here. 
+ "pandas", # Basically any version should work as long as it supports the user's CUDA version "pynvml", # Doesn't really matter but before 1.0.0 there might be weird breaking changes From 207a3757d577b6c60fd0aea3fd58204309ca9b0c Mon Sep 17 00:00:00 2001 From: Walter Laurito Date: Wed, 5 Apr 2023 23:54:23 +0000 Subject: [PATCH 10/43] add code (contains still same device cuda error) --- elk/evaluation/evaluate.py | 155 ++++++++++++++++++++++++++++++++----- 1 file changed, 134 insertions(+), 21 deletions(-) diff --git a/elk/evaluation/evaluate.py b/elk/evaluation/evaluate.py index c26aabf8..2b432ebc 100644 --- a/elk/evaluation/evaluate.py +++ b/elk/evaluation/evaluate.py @@ -1,30 +1,22 @@ -import csv -import os from dataclasses import dataclass from functools import partial from pathlib import Path -from typing import Callable, Literal, Optional, cast +from typing import Callable, Literal, Optional +from dataclasses import dataclass +from functools import partial +from pathlib import Path +from typing import Callable, Literal, Optional import torch -import torch.multiprocessing as mp from simple_parsing.helpers import Serializable, field -from torch import Tensor -from tqdm.auto import tqdm - -from datasets import DatasetDict -from elk.evaluation.evaluate_log import EvalLog -from elk.extraction.extraction import Extract -from elk.run import Run -from elk.training import Reporter -from ..files import elk_reporter_dir, memorably_named_dir -from ..training.preprocessing import normalize -from ..utils import ( - assert_type, - int16_to_float32, - select_train_val_splits, - select_usable_devices, -) +from evaluation.evaluate_log import EvalLog +from extraction.extraction import Extract +from files import elk_reporter_dir +from run import Run +from training import Reporter +from training.baseline import evaluate_baseline, load_baseline, train_baseline +from utils import select_usable_devices @dataclass @@ -51,7 +43,115 @@ class Eval(Serializable): debug: bool = False out_dir: Optional[Path] = None num_gpus: int = -1 + skip_baseline: bool = False + concatenated_layer_offset: int = 0 + def execute(self): + transfer_eval = elk_reporter_dir() / self.source / "transfer_eval" + + run = Evaluate(cfg=self, out_dir=transfer_eval) + run.evaluate() + + +@dataclass +class Evaluate(Run): + cfg: Eval + + def evaluate_reporter( + self, layer: int, devices: list[str], world_size: int = 1 + ) -> EvalLog: + """Evaluate a single reporter on a single layer.""" + device = self.get_device(devices, world_size) + + _, _, test_x0, test_x1, _, test_labels = self.prepare_data( + device, + layer, + ) + + experiment_dir = elk_reporter_dir() / self.cfg.source + + reporter_path = ( + experiment_dir / "reporters" / f"layer_{layer}.pt" + ) + reporter: Reporter = torch.load(reporter_path, map_location=device) + reporter.eval() + + test_result = reporter.score( + test_labels, + test_x0, + test_x1, + ) + + lr_dir = experiment_dir / "lr_models" + if not self.cfg.skip_baseline and lr_dir.exists(): + lr_model = load_baseline(lr_dir, layer) + lr_auroc, lr_acc = evaluate_baseline(lr_model, test_x0, test_x1, test_labels) + + print("transfer_eval", lr_auroc, lr_acc) + + # stats.lr_auroc = lr_auroc + # stats.lr_acc = lr_acc + # save_baseline(lr_dir, layer, lr_model) + + return EvalLog( + layer=layer, + eval_result=test_result, + ) + + def evaluate(self): + """Evaluate the reporter on all layers.""" + devices = select_usable_devices( + self.cfg.num_gpus, min_memory=self.cfg.data.min_gpu_mem + ) + + num_devices = len(devices) + func: 
Callable[[int], EvalLog] = partial( + self.evaluate_reporter, devices=devices, world_size=num_devices + ) + self.apply_to_layers( + func=func, + num_devices=num_devices, + to_csv_line=lambda item: item.to_csv_line(), + csv_columns=EvalLog.csv_columns(), + ) + +import torch +from simple_parsing.helpers import Serializable, field + +from evaluation.evaluate_log import EvalLog +from extraction.extraction import Extract +from files import elk_reporter_dir +from run import Run +from training import Reporter +from training.baseline import evaluate_baseline, load_baseline, train_baseline +from utils import select_usable_devices + + +@dataclass +class Eval(Serializable): + """ + Full specification of a reporter evaluation run. + + Args: + data: Config specifying hidden states on which the reporter will be evaluated. + source: The name of the source run directory + which contains the reporters directory. + normalization: The normalization method to use. Defaults to "meanonly". See + `elk.training.preprocessing.normalize()` for details. + num_gpus: The number of GPUs to use. Defaults to -1, which means + "use all available GPUs". + debug: When in debug mode, a useful log file is saved to the memorably-named + output directory. Defaults to False. + """ + + data: Extract + source: str = field(positional=True) + normalization: Literal["legacy", "none", "elementwise", "meanonly"] = "meanonly" + + debug: bool = False + out_dir: Optional[Path] = None + num_gpus: int = -1 + skip_baseline: bool = False concatenated_layer_offset: int = 0 def execute(self): @@ -76,8 +176,10 @@ def evaluate_reporter( layer, ) + experiment_dir = elk_reporter_dir() / self.cfg.source + reporter_path = ( - elk_reporter_dir() / self.cfg.source / "reporters" / f"layer_{layer}.pt" + experiment_dir / "reporters" / f"layer_{layer}.pt" ) reporter: Reporter = torch.load(reporter_path, map_location=device) reporter.eval() @@ -88,6 +190,17 @@ def evaluate_reporter( test_x1, ) + lr_dir = experiment_dir / "lr_models" + if not self.cfg.skip_baseline and lr_dir.exists(): + lr_model = load_baseline(lr_dir, layer) + lr_auroc, lr_acc = evaluate_baseline(lr_model, test_x0, test_x1, test_labels) + + print("transfer_eval", lr_auroc, lr_acc) + + # stats.lr_auroc = lr_auroc + # stats.lr_acc = lr_acc + # save_baseline(lr_dir, layer, lr_model) + return EvalLog( layer=layer, eval_result=test_result, From e7efcce7a26cc364f494dba8215fde117074fac5 Mon Sep 17 00:00:00 2001 From: Walter Laurito Date: Fri, 7 Apr 2023 06:15:23 +0200 Subject: [PATCH 11/43] fix multiple cuda error, save evals to right folder + cleanup --- elk/evaluation/evaluate.py | 164 ++++------------------- elk/evaluation/evaluate_log.py | 27 ---- elk/run.py | 32 +++-- elk/training/baseline.py | 2 +- elk/training/reporter.py | 12 +- elk/training/train.py | 34 +++-- elk/training/train_log.py | 52 -------- elk/utils/csv.py | 39 ------ tests/test_log_csv_elements.py | 44 ------- tests/test_write_iterator_to_file.py | 186 --------------------------- 10 files changed, 63 insertions(+), 529 deletions(-) delete mode 100644 elk/evaluation/evaluate_log.py delete mode 100644 elk/training/train_log.py delete mode 100644 elk/utils/csv.py delete mode 100644 tests/test_log_csv_elements.py delete mode 100644 tests/test_write_iterator_to_file.py diff --git a/elk/evaluation/evaluate.py b/elk/evaluation/evaluate.py index 2b432ebc..23e60dcd 100644 --- a/elk/evaluation/evaluate.py +++ b/elk/evaluation/evaluate.py @@ -2,21 +2,17 @@ from functools import partial from pathlib import Path from typing import 
Callable, Literal, Optional -from dataclasses import dataclass -from functools import partial -from pathlib import Path -from typing import Callable, Literal, Optional +import pandas as pd import torch from simple_parsing.helpers import Serializable, field -from evaluation.evaluate_log import EvalLog -from extraction.extraction import Extract -from files import elk_reporter_dir -from run import Run -from training import Reporter -from training.baseline import evaluate_baseline, load_baseline, train_baseline -from utils import select_usable_devices +from ..extraction.extraction import Extract +from ..files import create_output_directory, elk_reporter_dir, memorably_named_dir +from ..run import Run +from ..training import Reporter +from ..training.baseline import evaluate_baseline, load_baseline +from ..utils import select_usable_devices @dataclass @@ -48,8 +44,9 @@ class Eval(Serializable): def execute(self): transfer_eval = elk_reporter_dir() / self.source / "transfer_eval" + out_dir = memorably_named_dir(transfer_eval) - run = Evaluate(cfg=self, out_dir=transfer_eval) + run = Evaluate(cfg=self, out_dir=out_dir) run.evaluate() @@ -59,7 +56,7 @@ class Evaluate(Run): def evaluate_reporter( self, layer: int, devices: list[str], world_size: int = 1 - ) -> EvalLog: + ) -> pd.Series: """Evaluate a single reporter on a single layer.""" device = self.get_device(devices, world_size) @@ -70,9 +67,7 @@ def evaluate_reporter( experiment_dir = elk_reporter_dir() / self.cfg.source - reporter_path = ( - experiment_dir / "reporters" / f"layer_{layer}.pt" - ) + reporter_path = experiment_dir / "reporters" / f"layer_{layer}.pt" reporter: Reporter = torch.load(reporter_path, map_location=device) reporter.eval() @@ -82,129 +77,25 @@ def evaluate_reporter( test_x1, ) - lr_dir = experiment_dir / "lr_models" - if not self.cfg.skip_baseline and lr_dir.exists(): - lr_model = load_baseline(lr_dir, layer) - lr_auroc, lr_acc = evaluate_baseline(lr_model, test_x0, test_x1, test_labels) - - print("transfer_eval", lr_auroc, lr_acc) - - # stats.lr_auroc = lr_auroc - # stats.lr_acc = lr_acc - # save_baseline(lr_dir, layer, lr_model) - - return EvalLog( - layer=layer, - eval_result=test_result, - ) - - def evaluate(self): - """Evaluate the reporter on all layers.""" - devices = select_usable_devices( - self.cfg.num_gpus, min_memory=self.cfg.data.min_gpu_mem - ) - - num_devices = len(devices) - func: Callable[[int], EvalLog] = partial( - self.evaluate_reporter, devices=devices, world_size=num_devices - ) - self.apply_to_layers( - func=func, - num_devices=num_devices, - to_csv_line=lambda item: item.to_csv_line(), - csv_columns=EvalLog.csv_columns(), - ) - -import torch -from simple_parsing.helpers import Serializable, field - -from evaluation.evaluate_log import EvalLog -from extraction.extraction import Extract -from files import elk_reporter_dir -from run import Run -from training import Reporter -from training.baseline import evaluate_baseline, load_baseline, train_baseline -from utils import select_usable_devices - - -@dataclass -class Eval(Serializable): - """ - Full specification of a reporter evaluation run. - - Args: - data: Config specifying hidden states on which the reporter will be evaluated. - source: The name of the source run directory - which contains the reporters directory. - normalization: The normalization method to use. Defaults to "meanonly". See - `elk.training.preprocessing.normalize()` for details. - num_gpus: The number of GPUs to use. Defaults to -1, which means - "use all available GPUs". 
- debug: When in debug mode, a useful log file is saved to the memorably-named - output directory. Defaults to False. - """ - - data: Extract - source: str = field(positional=True) - normalization: Literal["legacy", "none", "elementwise", "meanonly"] = "meanonly" - - debug: bool = False - out_dir: Optional[Path] = None - num_gpus: int = -1 - skip_baseline: bool = False - concatenated_layer_offset: int = 0 - - def execute(self): - transfer_eval = elk_reporter_dir() / self.source / "transfer_eval" - - run = Evaluate(cfg=self, out_dir=transfer_eval) - run.evaluate() - - -@dataclass -class Evaluate(Run): - cfg: Eval - - def evaluate_reporter( - self, layer: int, devices: list[str], world_size: int = 1 - ) -> EvalLog: - """Evaluate a single reporter on a single layer.""" - device = self.get_device(devices, world_size) - - _, _, test_x0, test_x1, _, test_labels = self.prepare_data( - device, - layer, - ) - - experiment_dir = elk_reporter_dir() / self.cfg.source - - reporter_path = ( - experiment_dir / "reporters" / f"layer_{layer}.pt" - ) - reporter: Reporter = torch.load(reporter_path, map_location=device) - reporter.eval() - - test_result = reporter.score( - test_labels, - test_x0, - test_x1, + stats_row = pd.Series( + { + "layer": layer, + **test_result._asdict(), + } ) lr_dir = experiment_dir / "lr_models" if not self.cfg.skip_baseline and lr_dir.exists(): lr_model = load_baseline(lr_dir, layer) - lr_auroc, lr_acc = evaluate_baseline(lr_model, test_x0, test_x1, test_labels) + lr_model.eval() + lr_auroc, lr_acc = evaluate_baseline( + lr_model.cuda(), test_x0.cuda(), test_x1.cuda(), test_labels + ) - print("transfer_eval", lr_auroc, lr_acc) + stats_row["lr_auroc"] = lr_auroc + stats_row["lr_acc"] = lr_acc - # stats.lr_auroc = lr_auroc - # stats.lr_acc = lr_acc - # save_baseline(lr_dir, layer, lr_model) - - return EvalLog( - layer=layer, - eval_result=test_result, - ) + return stats_row def evaluate(self): """Evaluate the reporter on all layers.""" @@ -213,12 +104,7 @@ def evaluate(self): ) num_devices = len(devices) - func: Callable[[int], EvalLog] = partial( + func: Callable[[int], pd.Series] = partial( self.evaluate_reporter, devices=devices, world_size=num_devices ) - self.apply_to_layers( - func=func, - num_devices=num_devices, - to_csv_line=lambda item: item.to_csv_line(), - csv_columns=EvalLog.csv_columns(), - ) + self.apply_to_layers(func=func, num_devices=num_devices) diff --git a/elk/evaluation/evaluate_log.py b/elk/evaluation/evaluate_log.py deleted file mode 100644 index 2cb9b485..00000000 --- a/elk/evaluation/evaluate_log.py +++ /dev/null @@ -1,27 +0,0 @@ -from dataclasses import dataclass - -from elk.training.reporter import EvalResult - - -@dataclass -class EvalLog: - """The result of running eval on a layer of a dataset""" - - layer: int - eval_result: EvalResult - - @staticmethod - def csv_columns() -> list[str]: - return ["layer", "acc", "cal_acc", "auroc", "ece"] - - def to_csv_line(self) -> list[str]: - items = [ - self.layer, - self.eval_result.acc, - self.eval_result.cal_acc, - self.eval_result.auroc, - self.eval_result.ece, - ] - return [ - f"{item:.4f}" if isinstance(item, float) else str(item) for item in items - ] diff --git a/elk/run.py b/elk/run.py index ea6606ec..39dd4288 100644 --- a/elk/run.py +++ b/elk/run.py @@ -6,16 +6,17 @@ from typing import TYPE_CHECKING, Callable, Iterator, Optional, Union import numpy as np +import pandas as pd import torch import torch.multiprocessing as mp +from datasets import DatasetDict from torch import Tensor from tqdm import tqdm 
-from datasets import DatasetDict from elk.extraction.extraction import extract from elk.files import create_output_directory, save_config, save_meta +from elk.logging import save_debug_log from elk.training.preprocessing import normalize -from elk.utils.csv import Log, write_iterator_to_file from elk.utils.data_utils import get_layers, select_train_val_splits from elk.utils.typing import assert_type, int16_to_float32 @@ -94,10 +95,8 @@ def concatenate(self, layers): def apply_to_layers( self, - func: Callable[[int], Log], + func: Callable[[int], pd.Series], num_devices: int, - to_csv_line: Callable[[Log], list[str]], - csv_columns: list[str], ): """Apply a function to each layer of the dataset in parallel and writes the results to a CSV file. @@ -120,15 +119,14 @@ def apply_to_layers( # Should we write to different CSV files for elicit vs eval? with mp.Pool(num_devices) as pool, open(self.out_dir / "eval.csv", "w") as f: mapper = pool.imap_unordered if num_devices > 1 else map - iterator: Iterator[Log] = tqdm( # type: ignore - mapper(func, layers), total=len(layers) - ) - write_iterator_to_file( - iterator=iterator, - file=f, - debug=self.cfg.debug, - dataset=self.dataset, - out_dir=self.out_dir, - csv_columns=csv_columns, - to_csv_line=to_csv_line, - ) + row_buf = [] + + try: + for row in tqdm(mapper(func, layers), total=len(layers)): + row_buf.append(row) + finally: + # Make sure the CSV is written even if we crash or get interrupted + df = pd.DataFrame(row_buf).sort_values(by="layer") + df.to_csv(f, index=False) + if self.cfg.debug: + save_debug_log(self.dataset, self.out_dir) diff --git a/elk/training/baseline.py b/elk/training/baseline.py index 2c9542a6..b964bc63 100644 --- a/elk/training/baseline.py +++ b/elk/training/baseline.py @@ -1,6 +1,6 @@ import pickle from pathlib import Path -from typing import Tuple +from typing import NamedTuple, Tuple import torch from sklearn.metrics import accuracy_score, roc_auc_score diff --git a/elk/training/reporter.py b/elk/training/reporter.py index ea8a4406..cf3e21e2 100644 --- a/elk/training/reporter.py +++ b/elk/training/reporter.py @@ -1,16 +1,18 @@ """An ELK reporter network.""" -from ..calibration import CalibrationError -from .classifier import Classifier from abc import ABC, abstractmethod from dataclasses import dataclass from pathlib import Path -from simple_parsing.helpers import Serializable -from sklearn.metrics import roc_auc_score -from torch import Tensor from typing import Literal, NamedTuple, Optional, Union + import torch import torch.nn as nn +from simple_parsing.helpers import Serializable +from sklearn.metrics import roc_auc_score +from torch import Tensor + +from ..calibration import CalibrationError +from .classifier import Classifier class EvalResult(NamedTuple): diff --git a/elk/training/train.py b/elk/training/train.py index a6ecafdc..04dd1614 100644 --- a/elk/training/train.py +++ b/elk/training/train.py @@ -6,6 +6,7 @@ from pathlib import Path from typing import Callable, Literal, Optional +import pandas as pd import torch from simple_parsing import Serializable, field, subgroups from torch import Tensor @@ -18,7 +19,6 @@ from .ccs_reporter import CcsReporter, CcsReporterConfig from .eigen_reporter import EigenReporter, EigenReporterConfig from .reporter import OptimConfig, Reporter, ReporterConfig -from .train_log import ElicitLog @dataclass @@ -77,7 +77,7 @@ def train_reporter( layer: int, devices: list[str], world_size: int = 1, - ) -> ElicitLog: + ) -> pd.Series: """Train a single reporter on a single layer.""" 
self.make_reproducible(seed=self.cfg.net.seed + layer) @@ -103,11 +103,14 @@ def train_reporter( ) reporter_dir, lr_dir = self.create_models_dir(assert_type(Path, self.out_dir)) - stats: ElicitLog = ElicitLog( - layer=layer, - pseudo_auroc=pseudo_auroc, - train_loss=train_loss, - eval_result=val_result, + + stats_row = pd.Series( + { + "layer": layer, + "pseudo_auroc": pseudo_auroc, + "train_loss": train_loss, + **val_result._asdict(), + } ) if not self.cfg.skip_baseline: @@ -115,14 +118,14 @@ def train_reporter( lr_auroc, lr_acc = evaluate_baseline(lr_model, val_x0, val_x1, val_labels) - stats.lr_auroc = lr_auroc - stats.lr_acc = lr_acc + stats_row["lr_auroc"] = lr_auroc + stats_row["lr_acc"] = lr_acc save_baseline(lr_dir, layer, lr_model) with open(reporter_dir / f"layer_{layer}.pt", "wb") as file: torch.save(reporter, file) - return stats + return stats_row def get_pseudo_auroc( self, layer: int, x0: Tensor, x1: Tensor, val_x0: Tensor, val_x1: Tensor @@ -146,14 +149,7 @@ def train(self): """Train a reporter on each layer of the network.""" devices = select_usable_devices(self.cfg.num_gpus) num_devices = len(devices) - func: Callable[[int], ElicitLog] = partial( + func: Callable[[int], pd.Series] = partial( self.train_reporter, devices=devices, world_size=num_devices ) - self.apply_to_layers( - func=func, - num_devices=num_devices, - to_csv_line=lambda item: item.to_csv_line( - skip_baseline=self.cfg.skip_baseline - ), - csv_columns=ElicitLog.csv_columns(self.cfg.skip_baseline), - ) + self.apply_to_layers(func=func, num_devices=num_devices) diff --git a/elk/training/train_log.py b/elk/training/train_log.py deleted file mode 100644 index 37e5a392..00000000 --- a/elk/training/train_log.py +++ /dev/null @@ -1,52 +0,0 @@ -from dataclasses import dataclass -from typing import Optional - -from elk.training.reporter import EvalResult - - -@dataclass -class ElicitLog: - """The result of running elicit on a layer of a dataset""" - - layer: int - train_loss: float - eval_result: EvalResult - pseudo_auroc: float - # Only available if reporting baseline - lr_auroc: Optional[float] = None - # Only available if reporting baseline - lr_acc: Optional[float] = None - - @staticmethod - def csv_columns(skip_baseline: bool) -> list[str]: - """Return a CSV header with the column names.""" - cols = [ - "layer", - "pseudo_auroc", - "train_loss", - "acc", - "cal_acc", - "auroc", - "ece", - ] - if not skip_baseline: - cols += ["lr_auroc", "lr_acc"] - return cols - - def to_csv_line(self, skip_baseline: bool) -> list[str]: - """Return a CSV line with the evaluation results.""" - items = [ - self.layer, - self.pseudo_auroc, - self.train_loss, - self.eval_result.acc, - self.eval_result.cal_acc, - self.eval_result.auroc, - self.eval_result.ece, - ] - if not skip_baseline: - items += [self.lr_auroc, self.lr_acc] - - return [ - f"{item:.4f}" if isinstance(item, float) else str(item) for item in items - ] diff --git a/elk/utils/csv.py b/elk/utils/csv.py deleted file mode 100644 index c13cf1d6..00000000 --- a/elk/utils/csv.py +++ /dev/null @@ -1,39 +0,0 @@ -import csv -from pathlib import Path -from typing import Iterator, Callable, TextIO, TypeVar - -from datasets import DatasetDict - -from elk.evaluation.evaluate_log import EvalLog -from elk.logging import save_debug_log -from elk.training.train_log import ElicitLog - -"""A generic log type that contains a layer field -The layer field is used to sort the logs by layer.""" -Log = TypeVar("Log", EvalLog, ElicitLog) - - -def write_iterator_to_file( - iterator: 
Iterator[Log], - csv_columns: list[str], - to_csv_line: Callable[[Log], list[str]], - file: TextIO, - debug: bool, - dataset: DatasetDict, - out_dir: Path, -) -> None: - row_buf = [] - writer = csv.writer(file) - # write a single line - writer.writerow(csv_columns) - try: - for row in iterator: - row_buf.append(row) - finally: - # Make sure the CSV is written even if we crash or get interrupted - sorted_by_layer = sorted(row_buf, key=lambda x: x.layer) - for row in sorted_by_layer: - row = to_csv_line(row) - writer.writerow(row) - if debug: - save_debug_log(dataset, out_dir) diff --git a/tests/test_log_csv_elements.py b/tests/test_log_csv_elements.py deleted file mode 100644 index 59cb4352..00000000 --- a/tests/test_log_csv_elements.py +++ /dev/null @@ -1,44 +0,0 @@ -from elk.evaluation.evaluate_log import EvalLog -from elk.training.reporter import EvalResult -from elk.training.train_log import ElicitLog - - -def test_eval_log_csv_number_elements(): - log = EvalLog( - layer=1, - eval_result=EvalResult( - acc=1.0, - cal_acc=1.0, - auroc=1.0, - ece=1.0, - ), - ) - csv_columns = EvalLog.csv_columns() - csv_values = log.to_csv_line() - assert len(csv_columns) == len( - csv_values - ), "Number of columns and values should be the same" - - -def test_elicit_log_csv_number_elements(): - log = ElicitLog( - layer=1, - train_loss=1.0, - eval_result=EvalResult( - acc=0.0, - ece=0.0, - cal_acc=0.0, - auroc=0.0, - ), - pseudo_auroc=0.0, - ) - csv_columns = ElicitLog.csv_columns(skip_baseline=True) - csv_values = log.to_csv_line(skip_baseline=True) - assert len(csv_columns) == len( - csv_values - ), "Number of columns and values should be the same" - csv_columns_not_skipped = ElicitLog.csv_columns(skip_baseline=False) - csv_values_not_skipped = log.to_csv_line(skip_baseline=False) - assert len(csv_columns_not_skipped) == len( - csv_values_not_skipped - ), "Number of columns and values should be the same" diff --git a/tests/test_write_iterator_to_file.py b/tests/test_write_iterator_to_file.py deleted file mode 100644 index a40d8e2f..00000000 --- a/tests/test_write_iterator_to_file.py +++ /dev/null @@ -1,186 +0,0 @@ -import csv -import time -from pathlib import Path -from typing import Iterator -import multiprocessing as mp - - -from datasets import DatasetDict - -from elk.utils.csv import write_iterator_to_file -from elk.training.reporter import EvalResult -from elk.training.train_log import ElicitLog - - -def test_write_iterator_to_file(tmp_path: Path): - items: list[ElicitLog] = [ - ElicitLog( - layer=1, - train_loss=1.0, - eval_result=EvalResult( - acc=0.0, - ece=0.0, - cal_acc=0.0, - auroc=0.0, - ), - pseudo_auroc=0.0, - ) - ] - iterator = iter(items) - csv_columns = ElicitLog.csv_columns(skip_baseline=True) - - def to_csv_line(x): - return x.to_csv_line(skip_baseline=True) - - # Write the CSV file - with open(tmp_path / "test.csv", "w") as f: - write_iterator_to_file( - iterator=iterator, - file=f, - debug=False, - dataset=DatasetDict(), - out_dir=tmp_path, - csv_columns=csv_columns, - to_csv_line=to_csv_line, - ) - # Read the CSV file - with open(tmp_path / "test.csv", "r") as f: - reader = csv.reader(f) - # read all to lines - lines = list(reader) - # assert that the first line is the header - assert lines[0] == csv_columns - # assert that the second line is the data - assert lines[1] == to_csv_line(items[0]) - - -def test_write_iterator_to_file_crash(tmp_path: Path): - first_layer_log = ElicitLog( - layer=1, - train_loss=1.0, - eval_result=EvalResult( - acc=0.0, - ece=0.0, - cal_acc=0.0, - 
auroc=0.0, - ), - pseudo_auroc=0.0, - ) - - second_layer_log = ElicitLog( - layer=2, - train_loss=1.0, - eval_result=EvalResult( - acc=0.0, - ece=0.0, - cal_acc=0.0, - auroc=0.0, - ), - pseudo_auroc=0.0, - ) - - def iterator() -> Iterator[ElicitLog]: - for i in range(3): - if i == 0: - yield first_layer_log - elif i == 1: - yield second_layer_log - # on the third iteration, raise an ValueError - # We should still be able to write the first two layers - if i == 2: - raise ValueError() - - csv_columns = ElicitLog.csv_columns(skip_baseline=True) - - def to_csv_line(x): - return x.to_csv_line(skip_baseline=True) - - # Write the CSV file - try: - with open(tmp_path / "test.csv", "w") as f: - write_iterator_to_file( - iterator=iterator(), - file=f, - debug=False, - dataset=DatasetDict(), - out_dir=tmp_path, - csv_columns=csv_columns, - to_csv_line=to_csv_line, - ) - except ValueError: - # We expect a ValueError to be raised, - # and don't want to fail the test - pass - - # Read the CSV file - with open(tmp_path / "test.csv", "r") as f: - reader = csv.reader(f) - # read all to lines - lines = list(reader) - # assert that the first line is the header - assert lines[0] == csv_columns - # assert that the second line has the first layer - assert lines[1] == to_csv_line(first_layer_log) - # assert that the third line has the second layer - assert lines[2] == to_csv_line(second_layer_log) - - -def log_function(layer: int) -> ElicitLog: - """ - raise an error on the second layer - This is a top-level function so that it can be pickled - for multiprocessing - """ - if layer == 2: - # let the other processes finish first - time.sleep(3) - # crash the process - raise ValueError() - return ElicitLog( - layer=layer, - train_loss=1.0, - eval_result=EvalResult( - acc=0.0, - ece=0.0, - cal_acc=0.0, - auroc=0.0, - ), - pseudo_auroc=0.0, - ) - - -def test_write_iterator_crash_multiprocessing(tmp_path: Path): - processes = 3 - csv_columns = ElicitLog.csv_columns(skip_baseline=True) - - def to_csv_line(x): - return x.to_csv_line(skip_baseline=True) - - try: - with mp.Pool(processes) as pool, open(tmp_path / "eval.csv", "w") as f: - layers = [1, 2, 3] - iterator = pool.imap_unordered(log_function, layers) - write_iterator_to_file( - iterator=iterator, - file=f, - debug=False, - dataset=DatasetDict(), - out_dir=tmp_path, - csv_columns=csv_columns, - to_csv_line=to_csv_line, - ) - except ValueError: - # We expect a ValueError to be raised, - # and don't want to fail the test - pass - # We should still have results for layer 1, 3, even though layer 2 failed - with open(tmp_path / "eval.csv", "r") as f: - reader = csv.reader(f) - # read all to lines - lines = list(reader) - # assert that the first line is the header - assert lines[0] == csv_columns - # assert that the second line has the first layer - assert lines[1] == to_csv_line(log_function(1)) - # assert that the third line has the third layer - assert lines[2] == to_csv_line(log_function(3)) From 4f8bdc5e1e52f983205a78526bf42056b4314747 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 7 Apr 2023 04:20:31 +0000 Subject: [PATCH 12/43] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- elk/run.py | 1 - elk/training/baseline.py | 2 +- elk/training/train.py | 3 +-- 3 files changed, 2 insertions(+), 4 deletions(-) diff --git a/elk/run.py b/elk/run.py index caeba4ba..76719a53 100644 --- a/elk/run.py +++ b/elk/run.py @@ -17,7 +17,6 @@ from elk.files 
import create_output_directory, save_config, save_meta from elk.logging import save_debug_log from elk.training.preprocessing import normalize - from elk.utils.data_utils import get_layers, select_train_val_splits from elk.utils.typing import assert_type, int16_to_float32 diff --git a/elk/training/baseline.py b/elk/training/baseline.py index b964bc63..2c9542a6 100644 --- a/elk/training/baseline.py +++ b/elk/training/baseline.py @@ -1,6 +1,6 @@ import pickle from pathlib import Path -from typing import NamedTuple, Tuple +from typing import Tuple import torch from sklearn.metrics import accuracy_score, roc_auc_score diff --git a/elk/training/train.py b/elk/training/train.py index 17b61e80..04dd1614 100644 --- a/elk/training/train.py +++ b/elk/training/train.py @@ -13,8 +13,7 @@ from ..extraction.extraction import Extract from ..run import Run -from ..training.baseline import (evaluate_baseline, save_baseline, - train_baseline) +from ..training.baseline import evaluate_baseline, save_baseline, train_baseline from ..utils import select_usable_devices from ..utils.typing import assert_type from .ccs_reporter import CcsReporter, CcsReporterConfig From 9ca72babd22726c59fdc6c9faee818c87d018d21 Mon Sep 17 00:00:00 2001 From: Nora Belrose Date: Fri, 7 Apr 2023 04:48:46 +0000 Subject: [PATCH 13/43] Fix bug noticed by Waree --- elk/extraction/extraction.py | 4 +++- elk/run.py | 3 ++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/elk/extraction/extraction.py b/elk/extraction/extraction.py index 35cd5600..a5c44baa 100644 --- a/elk/extraction/extraction.py +++ b/elk/extraction/extraction.py @@ -112,6 +112,8 @@ def extract_hiddens( cfg.model, truncation_side="left", verbose=False ) has_lm_preds = is_autoregressive(model.config) + if has_lm_preds and rank == 0: + print("Model has language model head, will store predictions.") # Iterating over questions layer_indices = cfg.layers or tuple(range(model.config.num_hidden_layers)) @@ -182,7 +184,7 @@ def extract_hiddens( outputs = model(**inputs, output_hidden_states=True) # Compute the log probability of the answer tokens if available - if type(outputs).__name__.startswith("CausalLMOutput"): + if has_lm_preds: start, end = convert_span( offsets, (answer_start, answer_start + len(choice["answer"])) ) diff --git a/elk/run.py b/elk/run.py index 533304eb..1d0da05f 100644 --- a/elk/run.py +++ b/elk/run.py @@ -89,7 +89,8 @@ def prepare_data( val_x0, val_x1 = val_h.unbind(dim=-2) with self.dataset.formatted_as("numpy"): - val_lm_preds = val["model_preds"] if "model_preds" in val else None + has_preds = "model_preds" in val.features + val_lm_preds = val["model_preds"] if has_preds else None return x0, x1, val_x0, val_x1, train_labels, val_labels, val_lm_preds From 713a251f3601067f43f9e55d1d6d277e4875e1b4 Mon Sep 17 00:00:00 2001 From: Nora Belrose Date: Fri, 7 Apr 2023 22:51:05 +0000 Subject: [PATCH 14/43] Add sanity check to load_prompts and refactor binarize --- elk/extraction/extraction.py | 7 ++++++- elk/extraction/prompt_loading.py | 17 ++++++++++++----- elk/utils/data_utils.py | 7 ++++--- 3 files changed, 22 insertions(+), 9 deletions(-) diff --git a/elk/extraction/extraction.py b/elk/extraction/extraction.py index a5c44baa..5bea9fe8 100644 --- a/elk/extraction/extraction.py +++ b/elk/extraction/extraction.py @@ -152,7 +152,6 @@ def extract_hiddens( # Iterate over answers for j, choice in enumerate(record): text = choice["text"] - variant_inputs.append(text) # TODO: Do something smarter than "rindex" here. 
Really we want to # get the span of the answer directly from Jinja, but that doesn't @@ -167,6 +166,8 @@ def extract_hiddens( else: target = None + # Record the EXACT string we fed to the model + variant_inputs.append(text) inputs = tokenizer( text, return_offsets_mapping=True, @@ -318,6 +319,10 @@ def get_splits() -> SplitDict: for (split_name, split_info) in get_splits().items() } + import multiprocess as mp + + mp.set_start_method("spawn") # type: ignore[attr-defined] + ds = dict() for split, builder in builders.items(): builder.download_and_prepare(num_proc=len(devices)) diff --git a/elk/extraction/prompt_loading.py b/elk/extraction/prompt_loading.py index a494be5a..7375af59 100644 --- a/elk/extraction/prompt_loading.py +++ b/elk/extraction/prompt_loading.py @@ -1,3 +1,4 @@ +from collections import Counter from dataclasses import dataclass from random import Random from typing import Any, Iterator, Literal, Optional @@ -202,6 +203,7 @@ def _convert_to_prompts( fewshot_iter: Optional[Iterator[list[dict]]] = None, ) -> dict[str, Any]: """Prompt-generating function to pass to `IterableDataset.map`.""" + label = assert_type(int, example[label_column]) prompts = [] templates = list(prompter.templates.values()) if num_variants < len(templates): @@ -212,15 +214,14 @@ def qa_cat(q: str, a: str) -> str: sep = "" if not q or q[-1].isspace() or not a or a[0].isspace() else " " return f"{q}{sep}{a}" if a and not a.isspace() else q - new_label = rng.choice([0, 1]) if num_classes > 2 else example[label_column] + # For sanity checking that prompts are unique + prompt_counter = Counter() for template in templates: choices = [] if num_classes > 2: - template = binarize( - template, example[label_column], assert_type(int, new_label), rng - ) + template, label = binarize(template, label, rng) for answer_idx in range(2): fake_example = example.copy() @@ -228,6 +229,7 @@ def qa_cat(q: str, a: str) -> str: q, a = template.apply(fake_example) text = qa_cat(q, a) + prompt_counter[text] += 1 if fewshot_iter is not None: # Infinite iterator so we don't need to worry about StopIteration @@ -248,8 +250,13 @@ def qa_cat(q: str, a: str) -> str: prompts.append(choices) + # Sanity check: variants should be unique + ((maybe_dup, dup_count),) = prompt_counter.most_common(1) + if dup_count > 1: + raise ValueError(f'Prompt duplicated {dup_count} times! "{maybe_dup}"') + return dict( - label=new_label, + label=label, prompts=prompts, template_names=prompter.all_template_names, ) diff --git a/elk/utils/data_utils.py b/elk/utils/data_utils.py index f8129145..393ee6a6 100644 --- a/elk/utils/data_utils.py +++ b/elk/utils/data_utils.py @@ -111,7 +111,7 @@ def get_layers(ds: DatasetDict) -> List[int]: return layers -def binarize(template: Template, label: int, new_label: int, rng: Random) -> Template: +def binarize(template: Template, label: int, rng: Random) -> tuple[Template, int]: """Binarize a template with >2 answer choices, returning a new template and label. Returns: @@ -132,11 +132,12 @@ def binarize(template: Template, label: int, new_label: int, rng: Random) -> Tem true = answer_choices[label] false = rng.choice([c for c in answer_choices if c != true]) - assert new_label in (0, 1) + # What order are we going to present the answer choices in? 
+ new_label = rng.choice([0, 1]) new_template = copy.deepcopy(template) new_template.answer_choices = ( f"{false} ||| {true}" if new_label else f"{true} ||| {false}" ) - return new_template + return new_template, new_label From 0c35bc773b0cddf4cd3b8bc339f1da840d3e3c9c Mon Sep 17 00:00:00 2001 From: Nora Belrose Date: Sat, 8 Apr 2023 22:17:28 +0000 Subject: [PATCH 15/43] Changing a ton of stuff --- elk/evaluation/evaluate.py | 11 +- elk/extraction/balanced_sampler.py | 48 ++-- elk/extraction/extraction.py | 34 +-- elk/extraction/generator.py | 4 +- elk/extraction/prompt_loading.py | 6 +- elk/metrics.py | 79 ++++++ .../templates/sst2/templates.yaml | 224 ++++++++++++++++++ elk/run.py | 15 +- elk/training/baseline.py | 55 ++--- elk/training/ccs_reporter.py | 3 + elk/training/classifier.py | 17 +- elk/training/eigen_reporter.py | 132 ++++++----- elk/training/reporter.py | 96 ++++---- elk/training/train.py | 73 +++--- elk/utils/__init__.py | 8 +- elk/utils/data_utils.py | 7 +- pyproject.toml | 1 + tests/test_eigen_reporter.py | 12 +- tests/test_samplers.py | 2 +- 19 files changed, 582 insertions(+), 245 deletions(-) create mode 100644 elk/metrics.py create mode 100644 elk/promptsource/templates/sst2/templates.yaml diff --git a/elk/evaluation/evaluate.py b/elk/evaluation/evaluate.py index b9b8e676..a1c06440 100644 --- a/elk/evaluation/evaluate.py +++ b/elk/evaluation/evaluate.py @@ -60,7 +60,7 @@ def evaluate_reporter( """Evaluate a single reporter on a single layer.""" device = self.get_device(devices, world_size) - _, _, test_x0, test_x1, _, test_labels, _ = self.prepare_data( + _, test_h, _, test_labels, _ = self.prepare_data( device, layer, ) @@ -71,12 +71,7 @@ def evaluate_reporter( reporter: Reporter = torch.load(reporter_path, map_location=device) reporter.eval() - test_result = reporter.score( - test_labels, - test_x0, - test_x1, - ) - + test_result = reporter.score(test_labels, test_h) stats_row = pd.Series( { "layer": layer, @@ -89,7 +84,7 @@ def evaluate_reporter( lr_model = load_baseline(lr_dir, layer) lr_model.eval() lr_auroc, lr_acc = evaluate_baseline( - lr_model.cuda(), test_x0.cuda(), test_x1.cuda(), test_labels + lr_model.cuda(), test_h.cuda(), test_labels ) stats_row["lr_auroc"] = lr_auroc diff --git a/elk/extraction/balanced_sampler.py b/elk/extraction/balanced_sampler.py index 2ea4815e..ec011a47 100644 --- a/elk/extraction/balanced_sampler.py +++ b/elk/extraction/balanced_sampler.py @@ -1,4 +1,5 @@ from collections import deque +from dataclasses import dataclass, field from itertools import cycle from random import Random from typing import Iterable, Iterator, Optional @@ -11,39 +12,42 @@ from ..utils.typing import assert_type +@dataclass class BalancedSampler(TorchIterableDataset): """ - Approximately balances a binary classification dataset in a streaming fashion. - - Args: - dataset (IterableDataset): The HuggingFace IterableDataset to balance. - label_col (Optional[str], optional): The name of the column containing the - binary label. If not provided, the label column will be inferred from - the dataset features. Defaults to None. - buffer_size (int, optional): The total buffer size to use for balancing the - dataset. This value should be divisible by 2, as it will be equally - divided between the two binary label values (0 and 1). Defaults to 1000. + A sampler that approximately balances a multi-class classification dataset in a + streaming fashion. + + Attributes: + data: The input dataset to balance. + num_classes: The total number of classes expected in the data. 
+ buffer_size: The total buffer size to use for balancing the dataset. Each class + will have its own buffer with this size. """ - def __init__(self, data: Iterable[dict], buffer_size: int = 1000): - self.data = data + data: Iterable[dict] + num_classes: int + buffer_size: int = 1000 + buffers: dict[int, deque[dict]] = field(default_factory=dict, init=False) - self.neg_buffer = deque(maxlen=buffer_size) - self.pos_buffer = deque(maxlen=buffer_size) + def __post_init__(self): + # Initialize empty buffers + self.buffers = { + label: deque(maxlen=self.buffer_size) for label in range(self.num_classes) + } def __iter__(self): for sample in self.data: label = sample["label"] - # Add the sample to the appropriate buffer - if label == 0: - self.neg_buffer.append(sample) - else: - self.pos_buffer.append(sample) + # Add the sample to the buffer for its class label + self.buffers[label].append(sample) - while self.neg_buffer and self.pos_buffer: - yield self.neg_buffer.popleft() - yield self.pos_buffer.popleft() + # Check if all buffers have at least one sample + while all(len(buffer) > 0 for buffer in self.buffers.values()): + # Yield one sample from each buffer in a round-robin fashion + for buf in self.buffers.values(): + yield buf.popleft() class FewShotSampler: diff --git a/elk/extraction/extraction.py b/elk/extraction/extraction.py index 5bea9fe8..bc541aef 100644 --- a/elk/extraction/extraction.py +++ b/elk/extraction/extraction.py @@ -7,8 +7,8 @@ import torch from datasets import ( + Array2D, Array3D, - ClassLabel, DatasetDict, Features, Sequence, @@ -22,11 +22,12 @@ from transformers import AutoConfig, AutoTokenizer from transformers.modeling_outputs import Seq2SeqLMOutput -# import torch.nn.functional as F from ..utils import ( assert_type, convert_span, float32_to_int16, + infer_label_column, + infer_num_classes, instantiate_model, is_autoregressive, select_train_val_splits, @@ -103,7 +104,7 @@ def extract_hiddens( stream=cfg.prompts.stream, rank=rank, world_size=world_size, - ) # this dataset is already sharded, but hasn't been truncated to max_examples + ) # this dataset is already sharded, buqt hasn't been truncated to max_examples model = instantiate_model( cfg.model, torch_dtype="auto" if device != "cpu" else torch.float32 @@ -125,12 +126,14 @@ def extract_hiddens( if rank == world_size - 1: max_examples += global_max_examples % world_size - for example in islice(BalancedSampler(prompt_ds), max_examples): + for example in islice(BalancedSampler(prompt_ds, 3), max_examples): num_variants = len(example["prompts"]) + num_choices = len(example["prompts"][0]) + hidden_dict = { f"hidden_{layer_idx}": torch.empty( num_variants, - 2, # contrast pair + num_choices, model.config.hidden_size, device=device, dtype=torch.int16, @@ -139,7 +142,7 @@ def extract_hiddens( } lm_preds = torch.empty( num_variants, - 2, # contrast pair + num_choices, device=device, dtype=torch.float32, ) @@ -232,8 +235,7 @@ def extract_hiddens( **hidden_dict, ) if has_lm_preds: - # We only need the probability of the positive example since this is binary - out_record["model_preds"] = lm_preds.softmax(dim=-1)[..., 1] + out_record["model_preds"] = lm_preds.softmax(dim=-1) yield out_record @@ -271,10 +273,14 @@ def get_splits() -> SplitDict: ds_name, _, config_name = cfg.prompts.datasets[0].partition(" ") info = get_dataset_config_info(ds_name, config_name or None) + ds_features = assert_type(Features, info.features) + label_col = infer_label_column(ds_features) + num_classes = infer_num_classes(ds_features[label_col]) + 
layer_cols = { f"hidden_{layer}": Array3D( dtype="int16", - shape=(num_variants, 2, model_cfg.hidden_size), + shape=(num_variants, num_classes, model_cfg.hidden_size), ) for layer in cfg.layers or range(model_cfg.num_hidden_layers) } @@ -283,11 +289,10 @@ def get_splits() -> SplitDict: Value(dtype="string"), length=num_variants, ), - "label": ClassLabel(names=["neg", "pos"]), + "label": Value(dtype="int64"), "text_inputs": Sequence( Sequence( Value(dtype="string"), - length=2, ), length=num_variants, ), @@ -295,9 +300,9 @@ def get_splits() -> SplitDict: # Only add model_preds if the model is an autoregressive model if is_autoregressive(model_cfg): - other_cols["model_preds"] = Sequence( - Value(dtype="float32"), - length=num_variants, + other_cols["model_preds"] = Array2D( + shape=(num_variants, num_classes), + dtype="float32", ) devices = select_usable_devices(num_gpus, min_memory=cfg.min_gpu_mem) @@ -318,7 +323,6 @@ def get_splits() -> SplitDict: ) for (split_name, split_info) in get_splits().items() } - import multiprocess as mp mp.set_start_method("spawn") # type: ignore[attr-defined] diff --git a/elk/extraction/generator.py b/elk/extraction/generator.py index fb4d03bc..7e937703 100644 --- a/elk/extraction/generator.py +++ b/elk/extraction/generator.py @@ -1,5 +1,5 @@ from dataclasses import dataclass, field -from typing import Any, Callable, Dict, Optional +from typing import Any, Callable, Optional import datasets from datasets.splits import NamedSplit @@ -20,7 +20,7 @@ class _SplitGenerator: name: str split_info: datasets.SplitInfo - gen_kwargs: Dict = field(default_factory=dict) + gen_kwargs: dict = field(default_factory=dict) def __post_init__(self): self.name = str(self.name) # Make sure we convert NamedSplits in strings diff --git a/elk/extraction/prompt_loading.py b/elk/extraction/prompt_loading.py index 7375af59..14721436 100644 --- a/elk/extraction/prompt_loading.py +++ b/elk/extraction/prompt_loading.py @@ -14,7 +14,6 @@ from ..promptsource import DatasetTemplates from ..utils import ( assert_type, - binarize, infer_label_column, infer_num_classes, select_train_val_splits, @@ -220,10 +219,7 @@ def qa_cat(q: str, a: str) -> str: for template in templates: choices = [] - if num_classes > 2: - template, label = binarize(template, label, rng) - - for answer_idx in range(2): + for answer_idx in range(num_classes): fake_example = example.copy() fake_example[label_column] = answer_idx diff --git a/elk/metrics.py b/elk/metrics.py new file mode 100644 index 00000000..a57b37bd --- /dev/null +++ b/elk/metrics.py @@ -0,0 +1,79 @@ +from functools import partial +from typing import Literal + +import torch +from sklearn.metrics import average_precision_score, roc_auc_score +from torch import Tensor + + +def to_one_hot(labels: Tensor, n_classes: int) -> Tensor: + """ + Convert a tensor of class labels to a one-hot representation. + + Args: + labels (Tensor): A tensor of class labels of shape (N,). + n_classes (int): The total number of unique classes. + + Returns: + Tensor: A one-hot representation tensor of shape (N, n_classes). + """ + one_hot_labels = torch.zeros(labels.size(0), n_classes, dtype=torch.float32) + return one_hot_labels.scatter_(1, labels.unsqueeze(1).long(), 1) + + +def accuracy(y_true: Tensor, y_pred: Tensor) -> float: + """ + Compute the accuracy of a classification model. + + Args: + y_true: Ground truth tensor of shape (N,). + y_pred: Predicted class tensor of shape (N,) or (N, n_classes). + + Returns: + float: Accuracy of the model. 
+ """ + # Check if binary or multi-class classification + if len(y_pred.shape) == 1: + hard_preds = y_pred > 0.5 + else: + hard_preds = y_pred.argmax(-1) + + return hard_preds.eq(y_true).float().mean().item() + + +def mean_auc(y_true: Tensor, y_scores: Tensor, curve: Literal["roc", "pr"]) -> float: + """ + Compute the mean area under the receiver operating curve (AUROC) or + precision-recall curve (average precision or mAP) for binary or multi-class + classification problems. + + Args: + y_true: Ground truth tensor of shape (N,) or (N, n_classes). + y_scores: Predicted probability tensor of shape (N,) for binary + or (N, n_classes) for multi-class. + curve: Type of curve to compute the mean AUC. Either 'pr' for + precision-recall curve or 'roc' for receiver operating + characteristic curve. Defaults to 'pr'. + + Returns: + float: Either mean AUROC or mean average precision (mAP). + """ + score_fn = { + "pr": average_precision_score, + "roc": partial(roc_auc_score, multi_class="ovo"), + }.get(curve, None) + + if score_fn is None: + raise ValueError("Invalid curve type. Supported values are 'pr' and 'roc'.") + + if len(y_scores.shape) == 1 or y_scores.shape[1] == 1: + return float(score_fn(y_true, y_scores.squeeze(1))) + else: + n_classes = y_scores.shape[1] + y_true_one_hot = to_one_hot(y_true, n_classes) + + return score_fn(y_true_one_hot, y_scores) + # return np.array([ + # score_fn(y_true_one_hot[:, i], y_scores[:, i]) + # for i in range(n_classes) + # ]).mean() diff --git a/elk/promptsource/templates/sst2/templates.yaml b/elk/promptsource/templates/sst2/templates.yaml new file mode 100644 index 00000000..92d82916 --- /dev/null +++ b/elk/promptsource/templates/sst2/templates.yaml @@ -0,0 +1,224 @@ +# Taken from "Overthinking the Truth" by Halawi et al. (2023) +# Translated into Promptsource YAML by GPT-4 +dataset: sst-2 +subset: default +templates: + a1: !Template + answer_choices: Negative ||| Positive + id: a1 + jinja: 'Review: {{sentence}} + + Answer: |||{{answer_choices[label]}}' + metadata: !TemplateMetadata + choices_in_prompt: true + languages: + - en + metrics: + - Accuracy + original_task: true + name: review_a1 + reference: '' + a2: !Template + answer_choices: bad ||| good + id: a2 + jinja: 'Review: {{sentence}} + + Answer: |||{{answer_choices[label]}}' + metadata: !TemplateMetadata + choices_in_prompt: true + languages: + - en + metrics: + - Accuracy + original_task: true + name: review_a2 + reference: '' + a3: !Template + answer_choices: bad ||| good + id: a3 + jinja: 'My review for last night''s film: {{sentence}} The critics agreed that this movie was |||{{answer_choices[label]}}.' + metadata: !TemplateMetadata + choices_in_prompt: true + languages: + - en + metrics: + - Accuracy + original_task: true + name: last_night_film + reference: '' + a4: !Template + answer_choices: negative ||| positive + id: a4 + jinja: 'One of our critics wrote "{{sentence}}". Her sentiment towards the film was |||{{answer_choices[label]}}.' + metadata: !TemplateMetadata + choices_in_prompt: true + languages: + - en + metrics: + - Accuracy + original_task: true + name: critic_sentiment + reference: '' + a5: !Template + answer_choices: bad ||| good + id: a5 + jinja: 'In a contemporary review, Roger Ebert wrote: "{{sentence}}". Entertainment Weekly agreed, and the overall critical reception of the film was |||{{answer_choices[label]}}.' 
+ metadata: !TemplateMetadata + choices_in_prompt: true + languages: + - en + metrics: + - Accuracy + original_task: true + name: roger_ebert + reference: '' + a6: !Template + answer_choices: No ||| Yes + id: a6 + jinja: 'Review: {{sentence}} + + Positive Review? |||{{answer_choices[label]}}' + metadata: !TemplateMetadata + choices_in_prompt: true + languages: + - en + metrics: + - Accuracy + original_task: true + name: review_positive + reference: '' + a7: !Template + answer_choices: Negative ||| Positive + id: a7 + jinja: 'Review: {{sentence}} + + Question: Is the sentiment of the above review Positive or Negative? + + Answer: |||{{answer_choices[label]}}' + metadata: !TemplateMetadata + choices_in_prompt: true + languages: + - en + metrics: + - Accuracy + original_task: true + name: review_sentiment_question + reference: '' + a8: !Template + answer_choices: bad ||| good + id: a8 + jinja: 'Review: {{sentence}} + + Question: Did the author think that the movie was good or bad? + + Answer: |||{{answer_choices[label]}}' + metadata: !TemplateMetadata + choices_in_prompt: true + languages: + - en + metrics: + - Accuracy + original_task: true + name: author_opinion_question + reference: '' + a9: !Template + answer_choices: bad ||| good + id: a9 + jinja: 'Question: Did the author of the following tweet think that the movie was good or bad? + + Tweet: {{sentence}} + + Answer: |||{{answer_choices[label]}}' + metadata: !TemplateMetadata + choices_in_prompt: true + languages: + - en + metrics: + - Accuracy + original_task: true + name: tweet_opinion_question + reference: '' + a10: !Template + answer_choices: bad ||| good + id: a10 + jinja: '{{sentence}} My overall feeling was that the movie was |||{{answer_choices[label]}}.' + metadata: !TemplateMetadata + choices_in_prompt: true + languages: + - en + metrics: + - Accuracy + original_task: true + name: overall_feeling + reference: '' + a11: !Template + answer_choices: liked ||| hated + id: a11 + jinja: '{{sentence}} I |||{{answer_choices[label]}}.' + metadata: !TemplateMetadata + choices_in_prompt: true + languages: + - en + metrics: + - Accuracy + original_task: true + name: liked_or_hated + reference: '' + a12: !Template + answer_choices: 0 ||| 5 + id: a12 + jinja: '{{sentence}} My friend asked me if I would give the movie 0 or 5 stars, I said |||{{answer_choices[label]}}.' 
+ metadata: !TemplateMetadata + choices_in_prompt: true + languages: + - en + metrics: + - Accuracy + original_task: true + name: movie_stars + reference: '' + a13: !Template + answer_choices: Negative ||| Positive + id: a13 + jinja: 'Input: {{sentence}} + + Sentiment: |||{{answer_choices[label]}}' + metadata: !TemplateMetadata + choices_in_prompt: true + languages: + - en + metrics: + - Accuracy + original_task: true + name: input_sentiment + reference: '' + a14: !Template + answer_choices: False ||| True + id: a14 + jinja: 'Review: {{sentence}} + + Positive: |||{{answer_choices[label]}}' + metadata: !TemplateMetadata + choices_in_prompt: true + languages: + - en + metrics: + - Accuracy + original_task: true + name: review_positive_boolean + reference: '' + a15: !Template + answer_choices: 0 ||| 5 + id: a15 + jinja: 'Review: {{sentence}} + + Stars: |||{{answer_choices[label]}}' + metadata: !TemplateMetadata + choices_in_prompt: true + languages: + - en + metrics: + - Accuracy + original_task: true + name: review_stars + reference: '' diff --git a/elk/run.py b/elk/run.py index adb38aa3..bc71aa82 100644 --- a/elk/run.py +++ b/elk/run.py @@ -96,14 +96,11 @@ def prepare_data( method=self.cfg.normalization, ) - x0, x1 = train_h.unbind(dim=-2) - val_x0, val_x1 = val_h.unbind(dim=-2) - - with self.dataset.formatted_as("numpy"): + with self.dataset.formatted_as("torch"): has_preds = "model_preds" in val.features val_lm_preds = val["model_preds"] if has_preds else None - return x0, x1, val_x0, val_x1, train_labels, val_labels, val_lm_preds + return train_h, val_h, train_labels, val_labels, val_lm_preds def concatenate(self, layers): """Concatenate hidden states from a previous layer.""" @@ -134,7 +131,8 @@ def apply_to_layers( layers = self.concatenate(layers) # Should we write to different CSV files for elicit vs eval? 
- with mp.Pool(num_devices) as pool, open(self.out_dir / "eval.csv", "w") as f: + ctx = mp.get_context("spawn") + with ctx.Pool(num_devices) as pool, open(self.out_dir / "eval.csv", "w") as f: mapper = pool.imap_unordered if num_devices > 1 else map row_buf = [] @@ -143,7 +141,8 @@ def apply_to_layers( row_buf.append(row) finally: # Make sure the CSV is written even if we crash or get interrupted - df = pd.DataFrame(row_buf).sort_values(by="layer") - df.to_csv(f, index=False) + if row_buf: + df = pd.DataFrame(row_buf).sort_values(by="layer") + df.to_csv(f, index=False) if self.cfg.debug: save_debug_log(self.dataset, self.out_dir) diff --git a/elk/training/baseline.py b/elk/training/baseline.py index 2c9542a6..7f09a403 100644 --- a/elk/training/baseline.py +++ b/elk/training/baseline.py @@ -1,11 +1,11 @@ import pickle from pathlib import Path -from typing import Tuple import torch -from sklearn.metrics import accuracy_score, roc_auc_score +from einops import rearrange, repeat from torch import Tensor +from ..metrics import accuracy, mean_auc, to_one_hot from ..utils.typing import assert_type from .classifier import Classifier @@ -13,41 +13,36 @@ def evaluate_baseline( - lr_model: Classifier, val_x0: Tensor, val_x1: Tensor, val_labels: Tensor -) -> Tuple[float, float]: - X = torch.cat([val_x0, val_x1]) - d = X.shape[-1] - X_val = X.view(-1, d) + lr_model: Classifier, hiddens: Tensor, labels: Tensor +) -> tuple[float, float]: + # n = batch, v = variants, c = classes, d = hidden dim + (_, v, c, _) = hiddens.shape + + Y = repeat(labels, "n -> (n v)", v=v) + Y_one_hot = to_one_hot(Y, n_classes=c).long().flatten() + X = rearrange(hiddens, "n v c d -> (n v c) d") with torch.no_grad(): - lr_preds = lr_model(X_val).sigmoid().cpu() + lr_preds = lr_model(X) - val_labels_aug = ( - torch.cat([val_labels, 1 - val_labels]).repeat_interleave(val_x0.shape[1]) - ).cpu() - - lr_acc = accuracy_score(val_labels_aug, lr_preds > 0.5) - lr_auroc = roc_auc_score(val_labels_aug, lr_preds) + # Top-1 accuracy + lr_acc = accuracy( + Y.cpu(), rearrange(lr_preds.squeeze(-1), "(n v c) -> (n v) c", v=v, c=c).cpu() + ) + lr_auroc = mean_auc(Y_one_hot.cpu(), lr_preds.cpu(), curve="roc") return assert_type(float, lr_auroc), assert_type(float, lr_acc) -def train_baseline( - x0: Tensor, - x1: Tensor, - train_labels: Tensor, - device: str, -) -> Classifier: - # repeat_interleave makes `num_variants` copies of each label, all within a - # single dimension of size `num_variants * 2 * n`, such that the labels align - # with X.view(-1, X.shape[-1]) - train_labels_aug = torch.cat([train_labels, 1 - train_labels]).repeat_interleave( - x0.shape[1] - ) +def train_baseline(hiddens: Tensor, labels: Tensor) -> Classifier: + # n = batch, v = variants, c = classes, d = hidden dim + (_, v, c, d) = hiddens.shape + + Y = repeat(labels, "n -> (n v)", v=v) + Y = to_one_hot(Y, n_classes=c).long().flatten() + X = rearrange(hiddens, "n v c d -> (n v c) d") - X = torch.cat([x0, x1]).squeeze() - d = X.shape[-1] - lr_model = Classifier(d, device=device) - lr_model.fit_cv(X.view(-1, d), train_labels_aug) + lr_model = Classifier(d, device=X.device) + lr_model.fit(X, Y) return lr_model diff --git a/elk/training/ccs_reporter.py b/elk/training/ccs_reporter.py index 258f2bc4..5ec2534a 100644 --- a/elk/training/ccs_reporter.py +++ b/elk/training/ccs_reporter.py @@ -165,6 +165,9 @@ def forward(self, x: Tensor) -> Tensor: return self.probe(x).squeeze(-1) def predict(self, x_pos: Tensor, x_neg: Tensor) -> Tensor: + return self.predict_prob(x_pos, x_neg).logit() 
+ + def predict_prob(self, x_pos: Tensor, x_neg: Tensor) -> Tensor: return 0.5 * (self(x_pos).sigmoid() + (1 - self(x_neg).sigmoid())) def loss( diff --git a/elk/training/classifier.py b/elk/training/classifier.py index 726cae7a..a140af44 100644 --- a/elk/training/classifier.py +++ b/elk/training/classifier.py @@ -1,5 +1,4 @@ from dataclasses import dataclass -from typing import Optional import torch from torch import Tensor @@ -35,14 +34,14 @@ class Classifier(torch.nn.Module): def __init__( self, input_dim: int, - num_classes: int = 1, - device: Optional[str] = None, - dtype: Optional[torch.dtype] = None, + num_classes: int = 2, + device: str | torch.device | None = None, + dtype: torch.dtype | None = None, ): super().__init__() self.linear = torch.nn.Linear( - input_dim, num_classes, device=device, dtype=dtype + input_dim, num_classes if num_classes > 2 else 1, device=device, dtype=dtype ) self.linear.bias.data.zero_() self.linear.weight.data.zero_() @@ -85,7 +84,9 @@ def fit( num_classes = self.linear.out_features loss_fn = bce_with_logits if num_classes == 1 else cross_entropy loss = torch.inf - y = y.float() + y = y.to( + torch.get_default_dtype() if num_classes == 1 else torch.long, + ) def closure(): nonlocal loss @@ -148,11 +149,13 @@ def fit_cv( indices = torch.randperm(num_samples, device=x.device, generator=rng) l2_penalties = torch.logspace(-4, 4, num_penalties).tolist() - y = y.float() num_classes = self.linear.out_features loss_fn = bce_with_logits if num_classes == 1 else cross_entropy losses = x.new_zeros((k, num_penalties)) + y = y.to( + torch.get_default_dtype() if num_classes == 1 else torch.long, + ) for i in range(k): start, end = i * fold_size, (i + 1) * fold_size diff --git a/elk/training/eigen_reporter.py b/elk/training/eigen_reporter.py index 1d2f6a98..6fef0ef7 100644 --- a/elk/training/eigen_reporter.py +++ b/elk/training/eigen_reporter.py @@ -5,6 +5,7 @@ from warnings import warn import torch +from einops import rearrange, repeat from torch import Tensor, nn, optim from ..math_util import cov_mean_fused @@ -63,12 +64,13 @@ class EigenReporter(Reporter): def __init__( self, - in_features: int, cfg: EigenReporterConfig, + in_features: int, + num_classes: int = 2, device: Optional[str] = None, dtype: Optional[torch.dtype] = None, ): - super().__init__(in_features, cfg, device=device, dtype=dtype) + super().__init__(cfg, in_features, num_classes, device=device, dtype=dtype) # Learnable Platt scaling parameters self.bias = nn.Parameter(torch.zeros(cfg.num_heads, device=device, dtype=dtype)) @@ -96,9 +98,24 @@ def forward(self, x: Tensor) -> Tensor: raw_scores = x @ self.weight.mT return raw_scores.mul(self.scale).add(self.bias).squeeze(-1) - def predict(self, x_pos: Tensor, x_neg: Tensor) -> Tensor: - """Return the predicted log odds on the contrast pair `(x_pos, x_neg)`.""" - return 0.5 * (self(x_pos) - self(x_neg)) + def predict(self, *hiddens: Tensor) -> Tensor: + """Return the predicted logits on the contrast set `hiddens`.""" + # breakpoint() + if len(hiddens) == 1: + return self(hiddens[0]) + + elif len(hiddens) == 2: + return 0.5 * (self(hiddens[0]) - self(hiddens[1])) + else: + return torch.stack(list(map(self, hiddens)), dim=-1) + + def predict_prob(self, *hiddens: Tensor) -> Tensor: + """Return the predicted probabilities on the contrast set `hiddens`.""" + logits = self.predict(*hiddens) + if len(hiddens) == 2: + return logits.sigmoid() + else: + return logits.softmax(dim=-1) @property def contrastive_xcov(self) -> Tensor: @@ -128,47 +145,58 @@ def 
clear(self) -> None: self.n.zero_() @torch.no_grad() - def update(self, x_pos: Tensor, x_neg: Tensor) -> None: - # Sanity checks - assert x_pos.ndim == 3, "x_pos must be of shape [batch, num_variants, d]" - assert x_pos.shape == x_neg.shape, "x_pos and x_neg must have the same shape" + def update(self, *hiddens: Tensor) -> None: + k = len(hiddens) + assert k > 1, "Must provide at least two hidden states" - # Average across variants inside each cluster, computing the centroids. - pos_centroids, neg_centroids = x_pos.mean(1), x_neg.mean(1) + # Sanity checks + pivot, *rest = hiddens + assert pivot.ndim == 3, "hidden must be of shape [batch, num_variants, d]" + for h in rest: + assert h.shape == pivot.shape, "All hiddens must have the same shape" # We don't actually call super because we need access to the earlier estimate # of the population mean in order to update (cross-)covariances properly # super().update(x_pos, x_neg) - sample_n = pos_centroids.shape[0] + sample_n = pivot.shape[0] self.n += sample_n - # Update the running means; super().update() does this usually - neg_delta = neg_centroids - self.neg_mean - pos_delta = pos_centroids - self.pos_mean - self.neg_mean += neg_delta.sum(dim=0) / self.n - self.pos_mean += pos_delta.sum(dim=0) / self.n - - # *** Variance (inter-cluster) *** - # See code at https://bit.ly/3YC9BhH, as well as "Welford's online algorithm" - # in https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance. - # Post-mean update deltas are used to update the (co)variance - neg_delta2 = neg_centroids - self.neg_mean # [n, d] - pos_delta2 = pos_centroids - self.pos_mean # [n, d] - self.intercluster_cov_M2.addmm_(neg_delta.mT, neg_delta2) - self.intercluster_cov_M2.addmm_(pos_delta.mT, pos_delta2) - # *** Invariance (intra-cluster) *** # This is just a standard online *mean* update, since we're computing the # mean of covariance matrices, not the covariance matrix of means. - sample_invar = cov_mean_fused(x_pos) + cov_mean_fused(x_neg) + sample_invar = sum(map(cov_mean_fused, hiddens)) / k self.intracluster_cov += (sample_n / self.n) * ( sample_invar - self.intracluster_cov ) - # *** Negative covariance *** - self.contrastive_xcov_M2.addmm_(neg_delta.mT, pos_delta2) - self.contrastive_xcov_M2.addmm_(pos_delta.mT, neg_delta2) + # [n, v, d] -> [n, d] + centroids = [h.mean(1) for h in hiddens] + deltas, deltas2 = [], [] + + for i, h in enumerate(centroids): + # Update the running means; super().update() does this usually + delta = h - self.class_means[i] + self.class_means[i] += delta.sum(dim=0) / self.n + + # *** Variance (inter-cluster) *** + # See code at https://bit.ly/3YC9BhH and "Welford's online algorithm" + # in https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance. 
+ # Post-mean update deltas are used to update the (co)variance + delta2 = h - self.class_means[i] # [n, d] + self.intercluster_cov_M2.addmm_(delta.mT, delta2, alpha=1 / k) + deltas.append(delta) + deltas2.append(delta2) + + # *** Negative covariance (contrastive) *** + for i, d in enumerate(deltas): + for j, d_ in enumerate(deltas2): + # Compare to the other classes only + if i == j: + continue + + scale = 1 / (k * (k - 1)) + self.contrastive_xcov_M2.addmm_(d.mT, d_, alpha=scale) def fit_streaming(self) -> float: """Fit the probe using the current streaming statistics.""" @@ -180,7 +208,7 @@ def fit_streaming(self) -> float: try: L, Q = truncated_eigh(A, k=self.config.num_heads) - except ConvergenceError: + except (ConvergenceError, RuntimeError): warn( "Truncated eigendecomposition failed to converge. Falling back on " "PyTorch's dense eigensolver." @@ -194,50 +222,48 @@ def fit_streaming(self) -> float: def fit( self, - x_pos: Tensor, - x_neg: Tensor, + *hiddens: Tensor, labels: Optional[Tensor] = None, - *, - platt_scale: bool = True, ) -> float: - """Fit the probe to the contrast pair (x_pos, x_neg). + """Fit the probe to the contrast set `hiddens`. Args: - x_pos: The positive examples. - x_neg: The negative examples. + hiddens: The contrast set of hidden states. labels: The ground truth labels if available. - platt_scale: Whether to fit the scale and bias terms to data with LBFGS. - This is only used if labels are available. Returns: loss: Negative eigenvalue associated with the VINC direction. """ - assert x_pos.shape == x_neg.shape - self.update(x_pos, x_neg) + self.update(*hiddens) loss = self.fit_streaming() - if labels is not None and platt_scale: - self.platt_scale(labels, x_pos, x_neg) + + if labels is not None: + self.platt_scale(labels, *hiddens) return loss - def platt_scale( - self, labels: Tensor, x_pos: Tensor, x_neg: Tensor, max_iter: int = 100 - ): + def platt_scale(self, labels: Tensor, *hiddens: Tensor, max_iter: int = 100): """Fit the scale and bias terms to data with LBFGS.""" + pivot, *_ = hiddens opt = optim.LBFGS( [self.bias, self.scale], line_search_fn="strong_wolfe", max_iter=max_iter, - tolerance_change=torch.finfo(x_pos.dtype).eps, - tolerance_grad=torch.finfo(x_pos.dtype).eps, + tolerance_change=torch.finfo(pivot.dtype).eps, + tolerance_grad=torch.finfo(pivot.dtype).eps, ) - labels = labels.repeat_interleave(x_pos.shape[1]).float() + labels = repeat(labels, "n -> (n v)", v=pivot.shape[1]) def closure(): opt.zero_grad() - logits = self.predict(x_pos, x_neg).flatten() - loss = nn.functional.binary_cross_entropy_with_logits(logits, labels) + logits = rearrange(self.predict(*hiddens), "n v ... 
-> (n v) ...") + if len(logits.shape) == 1: + loss = nn.functional.binary_cross_entropy_with_logits( + logits, labels.float() + ) + else: + loss = nn.functional.cross_entropy(logits, labels.long()) loss.backward() return float(loss) diff --git a/elk/training/reporter.py b/elk/training/reporter.py index 9cdfb145..fd92b7ce 100644 --- a/elk/training/reporter.py +++ b/elk/training/reporter.py @@ -7,11 +7,13 @@ import torch import torch.nn as nn +from einops import rearrange, repeat from simple_parsing.helpers import Serializable from sklearn.metrics import roc_auc_score from torch import Tensor from ..calibration import CalibrationError +from ..metrics import mean_auc, to_one_hot from .classifier import Classifier @@ -66,13 +68,13 @@ class Reporter(nn.Module, ABC): """ n: Tensor - neg_mean: Tensor - pos_mean: Tensor + class_means: Tensor def __init__( self, - in_features: int, cfg: ReporterConfig, + in_features: int, + num_classes: int = 2, device: Optional[str] = None, dtype: Optional[torch.dtype] = None, ): @@ -81,10 +83,8 @@ def __init__( self.config = cfg self.register_buffer("n", torch.zeros((), device=device, dtype=torch.long)) self.register_buffer( - "neg_mean", torch.zeros(in_features, device=device, dtype=dtype) - ) - self.register_buffer( - "pos_mean", torch.zeros(in_features, device=device, dtype=dtype) + "class_means", + torch.zeros(num_classes, in_features, device=device, dtype=dtype), ) @classmethod @@ -139,15 +139,18 @@ def reset_parameters(self): """Reset the parameters of the probe.""" @torch.no_grad() - def update(self, x_pos: Tensor, x_neg: Tensor) -> None: + def update(self, *hiddens: Tensor) -> None: """Update the running mean of the positive and negative examples.""" - x_pos, x_neg = x_pos.flatten(0, -2), x_neg.flatten(0, -2) - self.n += x_pos.shape[0] + assert len(hiddens) > 1, "Must provide at least two hidden representations" + + # Flatten the hidden representations + hiddens = tuple(h.flatten(0, -2) for h in hiddens) + self.n += hiddens[0].shape[0] # Update the running means - self.neg_mean += (x_neg.sum(dim=0) - self.neg_mean) / self.n - self.pos_mean += (x_pos.sum(dim=0) - self.pos_mean) / self.n + for i, h in enumerate(hiddens): + self.class_means[i] += (h.sum(dim=0) - self.class_means[i]) / self.n # TODO: These methods will do something fancier in the future @classmethod @@ -162,55 +165,62 @@ def save(self, path: Path | str): @abstractmethod def fit( self, - x_pos: Tensor, - x_neg: Tensor, + *hiddens: Tensor, labels: Optional[Tensor] = None, ) -> float: ... @abstractmethod - def predict(self, x_pos: Tensor, x_neg: Tensor) -> Tensor: - """Pool the probe output on the contrast pair (x_pos, x_neg).""" + def predict(self, *hiddens: Tensor) -> Tensor: + """Return pooled logits for the contrast set `hiddens`.""" + + @abstractmethod + def predict_prob(self, *hiddens: Tensor) -> Tensor: + """Like `predict` but returns normalized probabilities, not logits.""" @torch.no_grad() - def score(self, labels: Tensor, x_pos: Tensor, x_neg: Tensor) -> EvalResult: - """Score the probe on the contrast pair (x_pos, x1). + def score(self, labels: Tensor, hiddens: Tensor) -> EvalResult: + """Score the probe on the contrast set `hiddens`. Args: - x_pos: The positive examples. - x_neg: The negative examples. - labels: The labels of the contrast pair. + labels: The labels of the contrast pair. + hiddens: The hidden representations of the contrast set. 
Returns: an instance of EvalResult containing the loss, accuracy, calibrated - accuracy, and AUROC of the probe on the contrast pair (x0, x1). + accuracy, and AUROC of the probe on `hiddens`. """ + pred_probs = self.predict_prob(hiddens) + (_, v, c) = pred_probs.shape - pred_probs = self.predict(x_pos, x_neg) + # makes `num_variants` copies of each label + Y = repeat(labels, "n -> (n v)", v=v).float() + to_one_hot(Y, n_classes=c).long().flatten() - # makes `num_variants` copies of each label, all within a single - # dimension of size `num_variants * n`, such that the labels align - # with pred_probs.flatten() - broadcast_labels = labels.repeat_interleave(pred_probs.shape[1]).float() - cal_err = ( - CalibrationError() - .update(broadcast_labels.cpu(), pred_probs.cpu()) - .compute() - ) + if c == 2: + cal_err = CalibrationError().update(Y.cpu(), pred_probs.cpu()).compute().ece + # Calibrated accuracy + cal_thresh = pred_probs.float().quantile(labels.float().mean()) + cal_preds = pred_probs.gt(cal_thresh).squeeze(1).to(torch.int) + cal_acc = cal_preds.flatten().eq(Y).float().mean().item() - # Calibrated accuracy - cal_thresh = pred_probs.float().quantile(labels.float().mean()) - cal_preds = pred_probs.gt(cal_thresh).squeeze(1).to(torch.int) - raw_preds = pred_probs.gt(0.5).squeeze(1).to(torch.int) + raw_preds = pred_probs.gt(0.5).squeeze(1).to(torch.int) + else: + # TODO: Implement calibration error for k > 2? + cal_acc = 0.0 + cal_err = 0.0 + + raw_preds = pred_probs.argmax(dim=-1) # roc_auc_score only takes flattened input - auroc = float(roc_auc_score(broadcast_labels.cpu(), pred_probs.cpu().flatten())) - cal_acc = cal_preds.flatten().eq(broadcast_labels).float().mean() - raw_acc = raw_preds.flatten().eq(broadcast_labels).float().mean() + auroc = mean_auc( + Y.cpu(), rearrange(pred_probs.cpu(), "n v ... 
-> (n v) ..."), curve="roc" + ) + raw_acc = raw_preds.flatten().eq(Y).float().mean() return EvalResult( - acc=torch.max(raw_acc, 1 - raw_acc).item(), - cal_acc=torch.max(cal_acc, 1 - cal_acc).item(), - auroc=max(auroc, 1 - auroc), - ece=cal_err.ece, + acc=raw_acc.item(), + cal_acc=cal_acc, + auroc=float(auroc), + ece=cal_err, ) diff --git a/elk/training/train.py b/elk/training/train.py index 9be8c589..d68909df 100644 --- a/elk/training/train.py +++ b/elk/training/train.py @@ -1,6 +1,5 @@ """Main training loop.""" -import warnings from dataclasses import dataclass from functools import partial from pathlib import Path @@ -8,18 +7,18 @@ import pandas as pd import torch +from einops import rearrange, repeat from simple_parsing import Serializable, field, subgroups -from sklearn.metrics import accuracy_score, roc_auc_score -from torch import Tensor from ..extraction.extraction import Extract +from ..metrics import accuracy, mean_auc from ..run import Run from ..training.baseline import evaluate_baseline, save_baseline, train_baseline from ..utils import select_usable_devices from ..utils.typing import assert_type from .ccs_reporter import CcsReporter, CcsReporterConfig from .eigen_reporter import EigenReporter, EigenReporterConfig -from .reporter import OptimConfig, Reporter, ReporterConfig +from .reporter import OptimConfig, ReporterConfig @dataclass @@ -84,30 +83,29 @@ def train_reporter( device = self.get_device(devices, world_size) - x0, x1, val_x0, val_x1, train_gt, val_gt, val_lm_preds = self.prepare_data( + train_h, val_h, train_gt, val_gt, val_lm_preds = self.prepare_data( device, layer ) - pseudo_auroc = self.get_pseudo_auroc(layer, x0, x1, val_x0, val_x1) + (_, v, c, d) = train_h.shape + # pseudo_auroc = self.get_pseudo_auroc(layer, x0, x1, val_x0, val_x1) if isinstance(self.cfg.net, CcsReporterConfig): - reporter = CcsReporter(x0.shape[-1], self.cfg.net, device=device) + reporter = CcsReporter(d, self.cfg.net, device=device) elif isinstance(self.cfg.net, EigenReporterConfig): - reporter = EigenReporter(x0.shape[-1], self.cfg.net, device=device) + reporter = EigenReporter(self.cfg.net, d, c, device=device) else: raise ValueError(f"Unknown reporter config type: {type(self.cfg.net)}") - train_loss = reporter.fit(x0, x1, train_gt) - val_result = reporter.score( - val_gt, - val_x0, - val_x1, - ) + train_loss = reporter.fit(*train_h.unbind(2), labels=train_gt) + val_result = reporter.score(val_gt, val_h) reporter_dir, lr_dir = self.create_models_dir(assert_type(Path, self.out_dir)) if val_lm_preds is not None: - val_gt_cpu = val_gt.repeat_interleave(val_lm_preds.shape[1]).float().cpu() - val_lm_auroc = float(roc_auc_score(val_gt_cpu, val_lm_preds.flatten())) - val_lm_acc = float(accuracy_score(val_gt_cpu, val_lm_preds.flatten() > 0.5)) + val_gt_cpu = repeat(val_gt, "n -> (n v)", v=v).cpu() + val_lm_preds = rearrange(val_lm_preds, "n v ... 
-> (n v) ...") + val_lm_auroc = mean_auc(val_gt_cpu, val_lm_preds, "roc") + + val_lm_acc = accuracy(val_gt_cpu, val_lm_preds) else: val_lm_auroc = None val_lm_acc = None @@ -115,7 +113,7 @@ def train_reporter( row = pd.Series( { "layer": layer, - "pseudo_auroc": pseudo_auroc, + # "pseudo_auroc": pseudo_auroc, "train_loss": train_loss, **val_result._asdict(), "lm_auroc": val_lm_auroc, @@ -124,9 +122,8 @@ def train_reporter( ) if not self.cfg.skip_baseline: - lr_model = train_baseline(x0, x1, train_gt, device=device) - - lr_auroc, lr_acc = evaluate_baseline(lr_model, val_x0, val_x1, val_gt) + lr_model = train_baseline(train_h, train_gt) + lr_auroc, lr_acc = evaluate_baseline(lr_model, val_h, val_gt) row["lr_auroc"] = lr_auroc row["lr_acc"] = lr_acc @@ -137,23 +134,23 @@ def train_reporter( return row - def get_pseudo_auroc( - self, layer: int, x0: Tensor, x1: Tensor, val_x0: Tensor, val_x1: Tensor - ): - """Check the separability of the pseudo-labels at a given layer.""" - - with torch.no_grad(): - pseudo_auroc = Reporter.check_separability( - train_pair=(x0, x1), val_pair=(val_x0, val_x1) - ) - if pseudo_auroc > 0.6: - warnings.warn( - f"The pseudo-labels at layer {layer} are linearly separable with " - f"an AUROC of {pseudo_auroc:.3f}. This may indicate that the " - f"algorithm will not converge to a good solution." - ) - - return pseudo_auroc + # def get_pseudo_auroc( + # self, layer: int, train_h: Tensor, val_h: Tensor + # ): + # """Check the separability of the pseudo-labels at a given layer.""" + # + # with torch.no_grad(): + # pseudo_auroc = Reporter.check_separability( + # train_pair=(x0, x1), val_pair=(val_x0, val_x1) + # ) + # if pseudo_auroc > 0.6: + # warnings.warn( + # f"The pseudo-labels at layer {layer} are linearly separable with " + # f"an AUROC of {pseudo_auroc:.3f}. This may indicate that the " + # f"algorithm will not converge to a good solution." 
+ # ) + # + # return pseudo_auroc def train(self): """Train a reporter on each layer of the network.""" diff --git a/elk/utils/__init__.py b/elk/utils/__init__.py index 1400a98d..deb002e4 100644 --- a/elk/utils/__init__.py +++ b/elk/utils/__init__.py @@ -12,17 +12,17 @@ from .typing import assert_type, float32_to_int16, int16_to_float32 __all__ = [ + "assert_type", "binarize", "convert_span", + "float32_to_int16", "get_columns_all_equal", "infer_label_column", "infer_num_classes", "instantiate_model", - "is_autoregressive", - "float32_to_int16", "int16_to_float32", + "is_autoregressive", + "pytree_map", "select_train_val_splits", "select_usable_devices", - "pytree_map", - "assert_type", ] diff --git a/elk/utils/data_utils.py b/elk/utils/data_utils.py index 393ee6a6..35a2b7f2 100644 --- a/elk/utils/data_utils.py +++ b/elk/utils/data_utils.py @@ -2,7 +2,7 @@ from bisect import bisect_left, bisect_right from operator import itemgetter from random import Random -from typing import Any, Iterable, List +from typing import Any, Iterable from datasets import ( ClassLabel, @@ -101,11 +101,12 @@ def infer_num_classes(label_feature: Any) -> int: ) -def get_layers(ds: DatasetDict) -> List[int]: +def get_layers(ds: DatasetDict) -> list[int]: """Get a list of indices of hidden layers given a `DatasetDict`.""" + train, _ = select_train_val_splits(ds.keys()) layers = [ int(feat[len("hidden_") :]) - for feat in ds["train"].features + for feat in ds[train].features if feat.startswith("hidden_") ] return layers diff --git a/pyproject.toml b/pyproject.toml index 16edd58e..f688416d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,6 +12,7 @@ license = {text = "MIT License"} dependencies = [ # Added distributed.split_dataset_by_node for IterableDatasets "datasets>=2.9.0", + "einops", # Introduced numpy.typing module "numpy>=1.20.0", # This version is old, but it's needed for certain HF tokenizers to work. 
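The training and baseline code above now treats the hidden states as a single tensor of shape (batch, variants, classes, dim) and uses einops together with one-hot labels to line the ground truth up with the flattened activations. The snippet below is a minimal, self-contained sketch of that alignment, not part of the patch; it assumes only torch and einops (the dependency added to pyproject.toml above), uses toy sizes, and re-implements a small helper that mirrors to_one_hot from elk/metrics.py.

# Standalone sketch, not part of the patch. Shape names follow the comments in
# train.py and baseline.py above: n = batch, v = variants, c = classes, d = hidden dim.
import torch
from einops import rearrange, repeat


def to_one_hot(labels: torch.Tensor, n_classes: int) -> torch.Tensor:
    # Same idea as elk/metrics.py: scatter ones into an (N, n_classes) zero tensor.
    out = labels.new_zeros(labels.size(0), n_classes)
    return out.scatter_(1, labels.unsqueeze(1).long(), 1)


n, v, c, d = 4, 3, 2, 8                             # toy sizes
hiddens = torch.randn(n, v, c, d)                   # one row per (example, variant, class)
labels = torch.randint(0, c, (n,))                  # one ground-truth label per example

Y = repeat(labels, "n -> (n v)", v=v)               # copy each label once per variant
X = rearrange(hiddens, "n v c d -> (n v c) d")      # flatten to a 2-D design matrix
Y_flat = to_one_hot(Y, n_classes=c).flatten()       # 0/1 target per (example, variant, class) row

assert X.shape == (n * v * c, d) and Y_flat.shape == (n * v * c,)

This is the same flattening train_baseline performs above before fitting a Classifier on X against the one-hot targets, and the same Y that evaluate_baseline compares predictions against.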
diff --git a/tests/test_eigen_reporter.py b/tests/test_eigen_reporter.py index 58dd6c13..a34fe2af 100644 --- a/tests/test_eigen_reporter.py +++ b/tests/test_eigen_reporter.py @@ -14,27 +14,27 @@ def test_eigen_reporter(): x_pos1, x_pos2 = x_pos.chunk(2, dim=0) x_neg1, x_neg2 = x_neg.chunk(2, dim=0) - reporter = EigenReporter(hidden_size, EigenReporterConfig(), dtype=torch.float64) + reporter = EigenReporter(EigenReporterConfig(), hidden_size, dtype=torch.float64) reporter.update(x_pos1, x_neg1) reporter.update(x_pos2, x_neg2) # Check that the streaming mean is correct pos_mu, neg_mu = x_pos.mean(dim=(0, 1)), x_neg.mean(dim=(0, 1)) - torch.testing.assert_close(reporter.pos_mean, pos_mu) - torch.testing.assert_close(reporter.neg_mean, neg_mu) + torch.testing.assert_close(reporter.class_means[0], pos_mu) + torch.testing.assert_close(reporter.class_means[1], neg_mu) # Check that the streaming covariance is correct pos_centroids, neg_centroids = x_pos.mean(dim=1), x_neg.mean(dim=1) - expected_var = batch_cov(pos_centroids) + batch_cov(neg_centroids) + expected_var = 0.5 * (batch_cov(pos_centroids) + batch_cov(neg_centroids)) torch.testing.assert_close(reporter.intercluster_cov, expected_var) # Check that the streaming invariance (intra-cluster variance) is correct - expected_invariance = cov_mean_fused(x_pos) + cov_mean_fused(x_neg) + expected_invariance = 0.5 * (cov_mean_fused(x_pos) + cov_mean_fused(x_neg)) torch.testing.assert_close(reporter.intracluster_cov, expected_invariance) # Check that the streaming negative covariance is correct cross_cov = (pos_centroids - pos_mu).mT @ (neg_centroids - neg_mu) / num_clusters - cross_cov = cross_cov + cross_cov.mT + cross_cov = 0.5 * (cross_cov + cross_cov.mT) torch.testing.assert_close(reporter.contrastive_xcov, cross_cov) assert reporter.n == num_clusters diff --git a/tests/test_samplers.py b/tests/test_samplers.py index 87c1ac0c..c5a142c1 100644 --- a/tests/test_samplers.py +++ b/tests/test_samplers.py @@ -41,7 +41,7 @@ def test_output_is_roughly_balanced(): ) col = infer_label_column(dataset.features) - reservoir = BalancedSampler(dataset) + reservoir = BalancedSampler(dataset, 2) # Count the number of samples for each label counter = Counter() From f5477442b0aed7906a0364ebe271379dbf691373 Mon Sep 17 00:00:00 2001 From: Nora Belrose Date: Mon, 10 Apr 2023 20:42:31 +0000 Subject: [PATCH 16/43] Revert changes to binarize --- elk/extraction/prompt_loading.py | 6 ++++-- elk/utils/data_utils.py | 8 +++----- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/elk/extraction/prompt_loading.py b/elk/extraction/prompt_loading.py index 7375af59..f3e5ac30 100644 --- a/elk/extraction/prompt_loading.py +++ b/elk/extraction/prompt_loading.py @@ -216,12 +216,15 @@ def qa_cat(q: str, a: str) -> str: # For sanity checking that prompts are unique prompt_counter = Counter() + new_label = rng.choice([0, 1]) if num_classes > 2 else example[label_column] for template in templates: choices = [] if num_classes > 2: - template, label = binarize(template, label, rng) + template = binarize( + template, example[label_column], assert_type(int, new_label), rng + ) for answer_idx in range(2): fake_example = example.copy() @@ -229,7 +232,6 @@ def qa_cat(q: str, a: str) -> str: q, a = template.apply(fake_example) text = qa_cat(q, a) - prompt_counter[text] += 1 if fewshot_iter is not None: # Infinite iterator so we don't need to worry about StopIteration diff --git a/elk/utils/data_utils.py b/elk/utils/data_utils.py index 393ee6a6..a98a7aae 100644 --- 
a/elk/utils/data_utils.py +++ b/elk/utils/data_utils.py @@ -111,9 +111,8 @@ def get_layers(ds: DatasetDict) -> List[int]: return layers -def binarize(template: Template, label: int, rng: Random) -> tuple[Template, int]: +def binarize(template: Template, label: int, new_label: int, rng: Random) -> Template: """Binarize a template with >2 answer choices, returning a new template and label. - Returns: `new_template`: A deepcopy of the original template with with 2 answer choices, one of @@ -132,12 +131,11 @@ def binarize(template: Template, label: int, rng: Random) -> tuple[Template, int true = answer_choices[label] false = rng.choice([c for c in answer_choices if c != true]) - # What order are we going to present the answer choices in? - new_label = rng.choice([0, 1]) + assert new_label in (0, 1) new_template = copy.deepcopy(template) new_template.answer_choices = ( f"{false} ||| {true}" if new_label else f"{true} ||| {false}" ) - return new_template, new_label + return new_template From ab1909f025d08818ccfe8bc1975f6dbebe9aa125 Mon Sep 17 00:00:00 2001 From: Nora Belrose Date: Mon, 10 Apr 2023 21:20:16 +0000 Subject: [PATCH 17/43] Stupid prompt_counter bug --- elk/extraction/prompt_loading.py | 1 + 1 file changed, 1 insertion(+) diff --git a/elk/extraction/prompt_loading.py b/elk/extraction/prompt_loading.py index f3e5ac30..c46d3d7b 100644 --- a/elk/extraction/prompt_loading.py +++ b/elk/extraction/prompt_loading.py @@ -232,6 +232,7 @@ def qa_cat(q: str, a: str) -> str: q, a = template.apply(fake_example) text = qa_cat(q, a) + prompt_counter[text] += 1 if fewshot_iter is not None: # Infinite iterator so we don't need to worry about StopIteration From f912ee6e66d100bb3fa47ae5c8c154958d76a2ae Mon Sep 17 00:00:00 2001 From: Nora Belrose Date: Mon, 10 Apr 2023 21:32:53 +0000 Subject: [PATCH 18/43] Remove stupid second set_start_method call --- elk/extraction/extraction.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/elk/extraction/extraction.py b/elk/extraction/extraction.py index 71a363e5..74f4e489 100644 --- a/elk/extraction/extraction.py +++ b/elk/extraction/extraction.py @@ -322,10 +322,6 @@ def get_splits() -> SplitDict: mp.set_start_method("spawn", force=True) # type: ignore[attr-defined] - import multiprocess as mp - - mp.set_start_method("spawn") # type: ignore[attr-defined] - ds = dict() for split, builder in builders.items(): builder.download_and_prepare(num_proc=len(devices)) From 83b480bbf28a52913672eabf91ef8f9a1cec1f14 Mon Sep 17 00:00:00 2001 From: Nora Belrose Date: Tue, 11 Apr 2023 02:05:43 +0000 Subject: [PATCH 19/43] Fix bugs in binary case --- elk/extraction/extraction.py | 3 +- elk/extraction/prompt_loading.py | 112 ++++++++++++++----------------- elk/metrics.py | 3 +- elk/training/ccs_reporter.py | 4 +- elk/training/eigen_reporter.py | 7 +- elk/training/reporter.py | 20 +++--- elk/training/train.py | 4 +- 7 files changed, 70 insertions(+), 83 deletions(-) diff --git a/elk/extraction/extraction.py b/elk/extraction/extraction.py index fcb595a2..84f43118 100644 --- a/elk/extraction/extraction.py +++ b/elk/extraction/extraction.py @@ -33,7 +33,6 @@ select_train_val_splits, select_usable_devices, ) -from .balanced_sampler import BalancedSampler from .generator import _GeneratorBuilder from .prompt_loading import PromptConfig, load_prompts @@ -126,7 +125,7 @@ def extract_hiddens( if rank == world_size - 1: max_examples += global_max_examples % world_size - for example in islice(BalancedSampler(prompt_ds, 3), max_examples): + for example in islice(prompt_ds, max_examples): 
num_variants = len(example["prompts"]) num_choices = len(example["prompts"][0]) diff --git a/elk/extraction/prompt_loading.py b/elk/extraction/prompt_loading.py index 0f891f96..213fd634 100644 --- a/elk/extraction/prompt_loading.py +++ b/elk/extraction/prompt_loading.py @@ -1,11 +1,11 @@ from collections import Counter from dataclasses import dataclass +from itertools import cycle from random import Random from typing import Any, Iterator, Literal, Optional from datasets import ( Dataset, - Features, load_dataset, ) from datasets.distributed import split_dataset_by_node @@ -18,7 +18,7 @@ infer_num_classes, select_train_val_splits, ) -from .balanced_sampler import FewShotSampler +from .balanced_sampler import BalancedSampler @dataclass @@ -95,10 +95,12 @@ def load_prompts( Returns: An iterable dataset of prompts. """ + class_counts = [] prompters = [] - raw_datasets = [] + datasets = [] train_datasets = [] rng = Random(seed) + assert num_shots == 0 # First load the datasets and prompters. We need to know the minimum number of # templates for any dataset in order to make sure we don't run out of prompts. @@ -112,29 +114,37 @@ def load_prompts( train_name, val_name = select_train_val_splits(ds_dict) split_name = val_name if split_type == "val" else train_name - # Note that when streaming we can only approximately shuffle the dataset - # using a buffer. Streaming shuffling is NOT an adequate shuffle for - # datasets like IMDB, which are sorted by label. - bad_streaming_datasets = ["imdb"] - assert not ( - stream and ds_name in bad_streaming_datasets - ), f"Streaming is not supported for {ds_name}." - split = ds_dict[split_name].shuffle(seed=seed) + ds = ds_dict[split_name].shuffle(seed=seed) train_ds = ds_dict[train_name].shuffle(seed=seed) + if not stream: - split = assert_type(Dataset, split) - split = split.to_iterable_dataset().cast(split.features) + ds = assert_type(Dataset, ds) + if world_size > 1: + ds = ds.shard(world_size, rank) + + ds = ds.to_iterable_dataset().cast(ds.features) - # only keep the datapoints relevant to the current process - if world_size > 1: + elif world_size > 1: # This prints to stdout which is slightly annoying - split = split_dataset_by_node( - dataset=split, rank=rank, world_size=world_size - ) + ds = split_dataset_by_node(dataset=ds, rank=rank, world_size=world_size) - raw_datasets.append(split) + label_column = infer_label_column(ds.features) + num_classes = infer_num_classes(ds.features[label_column]) + if label_column != "label": + ds = ds.rename_column(label_column, "label") + train_ds = train_ds.rename_column(label_column, "label") + + class_counts.append(num_classes) + datasets.append(ds) train_datasets.append(train_ds) + # Number of classes should be the same for all datasets + num_classes, *rest = class_counts + if not all(num_classes == x for x in rest): + raise ValueError( + f"# classes should be the same for all datasets, but got {class_counts}" + ) + min_num_templates = min(len(prompter.templates) for prompter in prompters) num_variants = ( min_num_templates @@ -145,51 +155,29 @@ def load_prompts( if rank == 0: print(f"Using {num_variants} variants of each prompt") - ds_iterators = [iter(ds) for ds in raw_datasets] - while True: # terminates when the first dataset runs out of examples - for ds_iterator, ds, train_ds, prompter in zip( - ds_iterators, raw_datasets, train_datasets, prompters - ): - label_column = infer_label_column(ds.features) - num_classes = infer_num_classes(ds.features[label_column]) - - # Remove everything except the label column 
- extra_cols = list(assert_type(Features, ds.features)) - extra_cols.remove(label_column) - - if label_column != "label": - ds = ds.rename_column(label_column, "label") - if num_shots > 0: - fewshot = FewShotSampler( - train_ds, # TODO: not iterator - num_shots=num_shots, - rng=rng, - ) - fewshot_iter = iter(fewshot) - else: - fewshot_iter = None - - try: - example = next(ds_iterator) - except StopIteration: - return - - example = _convert_to_prompts( - example, - label_column=label_column, - num_classes=num_classes, - num_variants=num_variants, - prompter=prompter, - rng=rng, - fewshot_iter=fewshot_iter, - ) + ds_iters = [iter(BalancedSampler(ds, num_classes)) for ds in datasets] + for ds_iter, ds, prompter in cycle(zip(ds_iters, datasets, prompters)): + try: + example = next(ds_iter) + except StopIteration: + return + + example = _convert_to_prompts( + example, + label_column="label", + num_classes=num_classes, + num_variants=num_variants, + prompter=prompter, + rng=rng, + fewshot_iter=None, + ) - # Add the builder and config name to the records directly to make - # sure we don't forget what dataset they came from. - example["builder_name"] = ds.info.builder_name - example["config_name"] = ds.info.config_name + # Add the builder and config name to the records directly to make + # sure we don't forget what dataset they came from. + example["builder_name"] = ds.info.builder_name + example["config_name"] = ds.info.config_name - yield example + yield example def _convert_to_prompts( diff --git a/elk/metrics.py b/elk/metrics.py index a57b37bd..94a394d8 100644 --- a/elk/metrics.py +++ b/elk/metrics.py @@ -1,7 +1,6 @@ from functools import partial from typing import Literal -import torch from sklearn.metrics import average_precision_score, roc_auc_score from torch import Tensor @@ -17,7 +16,7 @@ def to_one_hot(labels: Tensor, n_classes: int) -> Tensor: Returns: Tensor: A one-hot representation tensor of shape (N, n_classes). """ - one_hot_labels = torch.zeros(labels.size(0), n_classes, dtype=torch.float32) + one_hot_labels = labels.new_zeros(labels.size(0), n_classes) return one_hot_labels.scatter_(1, labels.unsqueeze(1).long(), 1) diff --git a/elk/training/ccs_reporter.py b/elk/training/ccs_reporter.py index 5ec2534a..a68ebdc2 100644 --- a/elk/training/ccs_reporter.py +++ b/elk/training/ccs_reporter.py @@ -83,12 +83,12 @@ class CcsReporter(Reporter): def __init__( self, - in_features: int, cfg: CcsReporterConfig, + in_features: int, device: Optional[str] = None, dtype: Optional[torch.dtype] = None, ): - super().__init__(in_features, cfg, device=device, dtype=dtype) + super().__init__(cfg, in_features, device=device, dtype=dtype) hidden_size = cfg.hidden_size or 4 * in_features // 3 diff --git a/elk/training/eigen_reporter.py b/elk/training/eigen_reporter.py index 6fef0ef7..317edae4 100644 --- a/elk/training/eigen_reporter.py +++ b/elk/training/eigen_reporter.py @@ -25,9 +25,9 @@ class EigenReporterConfig(ReporterConfig): of eigenvectors to compute from the VINC matrix. 
""" - var_weight: float = 1.0 - inv_weight: float = 5.0 - neg_cov_weight: float = 5.0 + var_weight: float = 0.2 + inv_weight: float = 1.0 + neg_cov_weight: float = 1.0 num_heads: int = 1 @@ -100,7 +100,6 @@ def forward(self, x: Tensor) -> Tensor: def predict(self, *hiddens: Tensor) -> Tensor: """Return the predicted logits on the contrast set `hiddens`.""" - # breakpoint() if len(hiddens) == 1: return self(hiddens[0]) diff --git a/elk/training/reporter.py b/elk/training/reporter.py index fd92b7ce..b936bb53 100644 --- a/elk/training/reporter.py +++ b/elk/training/reporter.py @@ -90,13 +90,13 @@ def __init__( @classmethod def check_separability( cls, - train_pair: tuple[Tensor, Tensor], - val_pair: tuple[Tensor, Tensor], + train_hiddens: Tensor, + val_hiddens: Tensor, ) -> float: """Measure how linearly separable the pseudo-labels are for a contrast pair. Args: - train_pair: A tuple of tensors, (x0, x1), where x0 and x1 are the + train_hiddens: Tensor of shape [n, ], where x0 and x1 are the contrastive representations. Used for training the classifier. val_pair: A tuple of tensors, (x0, x1), where x0 and x1 are the contrastive representations. Used for evaluating the classifier. @@ -104,8 +104,8 @@ def check_separability( Returns: The AUROC of a linear classifier fit on the pseudo-labels. """ - x0, x1 = train_pair - val_x0, val_x1 = val_pair + x0, x1 = train_hiddens + val_x0, val_x1 = val_hiddens pseudo_clf = Classifier(x0.shape[-1], device=x0.device) # type: ignore pseudo_train_labels = torch.cat( @@ -198,13 +198,15 @@ def score(self, labels: Tensor, hiddens: Tensor) -> EvalResult: to_one_hot(Y, n_classes=c).long().flatten() if c == 2: - cal_err = CalibrationError().update(Y.cpu(), pred_probs.cpu()).compute().ece + pos_probs = pred_probs[..., 0].flatten() + cal_err = CalibrationError().update(Y.cpu(), pos_probs.cpu()).compute().ece + # Calibrated accuracy - cal_thresh = pred_probs.float().quantile(labels.float().mean()) - cal_preds = pred_probs.gt(cal_thresh).squeeze(1).to(torch.int) + cal_thresh = pos_probs.float().quantile(labels.float().mean()) + cal_preds = pos_probs.gt(cal_thresh).to(torch.int) cal_acc = cal_preds.flatten().eq(Y).float().mean().item() - raw_preds = pred_probs.gt(0.5).squeeze(1).to(torch.int) + raw_preds = pos_probs.gt(0.5).to(torch.int) else: # TODO: Implement calibration error for k > 2? 
cal_acc = 0.0 diff --git a/elk/training/train.py b/elk/training/train.py index d68909df..931a4df5 100644 --- a/elk/training/train.py +++ b/elk/training/train.py @@ -90,14 +90,14 @@ def train_reporter( # pseudo_auroc = self.get_pseudo_auroc(layer, x0, x1, val_x0, val_x1) if isinstance(self.cfg.net, CcsReporterConfig): - reporter = CcsReporter(d, self.cfg.net, device=device) + reporter = CcsReporter(self.cfg.net, d, device=device) elif isinstance(self.cfg.net, EigenReporterConfig): reporter = EigenReporter(self.cfg.net, d, c, device=device) else: raise ValueError(f"Unknown reporter config type: {type(self.cfg.net)}") train_loss = reporter.fit(*train_h.unbind(2), labels=train_gt) - val_result = reporter.score(val_gt, val_h) + val_result = reporter.score(val_gt.to(device), val_h) reporter_dir, lr_dir = self.create_models_dir(assert_type(Path, self.out_dir)) if val_lm_preds is not None: From 3e6626205d8133b4fac52f2abe9bff475707e34d Mon Sep 17 00:00:00 2001 From: Nora Belrose Date: Tue, 11 Apr 2023 17:16:49 +0000 Subject: [PATCH 20/43] Various little refactors --- elk/training/ccs_reporter.py | 11 ++-- elk/training/eigen_reporter.py | 101 +++++++++++++++------------------ elk/training/preprocessing.py | 2 +- elk/training/reporter.py | 8 +-- elk/training/train.py | 2 +- 5 files changed, 58 insertions(+), 66 deletions(-) diff --git a/elk/training/ccs_reporter.py b/elk/training/ccs_reporter.py index a68ebdc2..0253857e 100644 --- a/elk/training/ccs_reporter.py +++ b/elk/training/ccs_reporter.py @@ -164,10 +164,11 @@ def forward(self, x: Tensor) -> Tensor: """Return the raw score output of the probe on `x`.""" return self.probe(x).squeeze(-1) - def predict(self, x_pos: Tensor, x_neg: Tensor) -> Tensor: - return self.predict_prob(x_pos, x_neg).logit() + def predict(self, hiddens: Tensor) -> Tensor: + return self.predict_prob(hiddens).logit() - def predict_prob(self, x_pos: Tensor, x_neg: Tensor) -> Tensor: + def predict_prob(self, hiddens: Tensor) -> Tensor: + x_pos, x_neg = hiddens.unbind(2) return 0.5 * (self(x_pos).sigmoid() + (1 - self(x_neg).sigmoid())) def loss( @@ -216,8 +217,7 @@ def loss( def fit( self, - x_pos: Tensor, - x_neg: Tensor, + hiddens: Tensor, labels: Optional[Tensor] = None, ) -> float: """Fit the probe to the contrast pair (x0, x1). @@ -236,6 +236,7 @@ def fit( """ # TODO: Implement normalization here to fix issue #96 # self.update(x_pos, x_neg) + x_pos, x_neg = hiddens.unbind(2) # Record the best acc, loss, and params found so far best_loss = torch.inf diff --git a/elk/training/eigen_reporter.py b/elk/training/eigen_reporter.py index 317edae4..b2d560b3 100644 --- a/elk/training/eigen_reporter.py +++ b/elk/training/eigen_reporter.py @@ -9,6 +9,7 @@ from torch import Tensor, nn, optim from ..math_util import cov_mean_fused +from ..metrics import to_one_hot from ..truncated_eigh import ConvergenceError, truncated_eigh from .reporter import Reporter, ReporterConfig @@ -19,18 +20,22 @@ class EigenReporterConfig(ReporterConfig): Args: var_weight: The weight of the variance term in the loss. - inv_weight: The weight of the invariance term in the loss. neg_cov_weight: The weight of the negative covariance term in the loss. num_heads: The number of reporter heads to fit. In other words, the number of eigenvectors to compute from the VINC matrix. 
""" - var_weight: float = 0.2 - inv_weight: float = 1.0 - neg_cov_weight: float = 1.0 + var_weight: float = 0.1 + neg_cov_weight: float = 0.5 num_heads: int = 1 + def __post_init__(self): + if not (0 <= self.neg_cov_weight <= 1): + raise ValueError("neg_cov_weight must be in [0, 1]") + if self.num_heads <= 0: + raise ValueError("num_heads must be positive") + class EigenReporter(Reporter): """A linear reporter whose weights are computed via eigendecomposition. @@ -93,24 +98,16 @@ def __init__( torch.zeros(cfg.num_heads, in_features, device=device, dtype=dtype), ) - def forward(self, x: Tensor) -> Tensor: + def forward(self, hiddens: Tensor) -> Tensor: """Return the predicted log odds on input `x`.""" - raw_scores = x @ self.weight.mT + raw_scores = hiddens @ self.weight.mT return raw_scores.mul(self.scale).add(self.bias).squeeze(-1) - def predict(self, *hiddens: Tensor) -> Tensor: - """Return the predicted logits on the contrast set `hiddens`.""" - if len(hiddens) == 1: - return self(hiddens[0]) - - elif len(hiddens) == 2: - return 0.5 * (self(hiddens[0]) - self(hiddens[1])) - else: - return torch.stack(list(map(self, hiddens)), dim=-1) + predict = forward - def predict_prob(self, *hiddens: Tensor) -> Tensor: + def predict_prob(self, hiddens: Tensor) -> Tensor: """Return the predicted probabilities on the contrast set `hiddens`.""" - logits = self.predict(*hiddens) + logits = self(hiddens) if len(hiddens) == 2: return logits.sigmoid() else: @@ -126,15 +123,15 @@ def intercluster_cov(self) -> Tensor: @property def confidence(self) -> Tensor: - return self.weight.mT @ self.intercluster_cov @ self.weight + return self.weight @ self.intercluster_cov @ self.weight.mT @property def invariance(self) -> Tensor: - return -self.weight.mT @ self.intracluster_cov @ self.weight + return -self.weight @ self.intracluster_cov @ self.weight.mT @property def consistency(self) -> Tensor: - return -self.weight.mT @ self.contrastive_xcov @ self.weight + return -self.weight @ self.contrastive_xcov @ self.weight.mT def clear(self) -> None: """Clear the running statistics of the reporter.""" @@ -144,36 +141,33 @@ def clear(self) -> None: self.n.zero_() @torch.no_grad() - def update(self, *hiddens: Tensor) -> None: - k = len(hiddens) - assert k > 1, "Must provide at least two hidden states" + def update(self, hiddens: Tensor) -> None: + (n, _, k, d) = hiddens.shape + + # Zero out shared info + hiddens = hiddens - hiddens.mean(dim=2, keepdim=True) # Sanity checks - pivot, *rest = hiddens - assert pivot.ndim == 3, "hidden must be of shape [batch, num_variants, d]" - for h in rest: - assert h.shape == pivot.shape, "All hiddens must have the same shape" + assert k > 1, "Must provide at least two hidden states" + assert hiddens.ndim == 4, "Must be of shape [batch, variants, choices, dim]" # We don't actually call super because we need access to the earlier estimate # of the population mean in order to update (cross-)covariances properly - # super().update(x_pos, x_neg) + # super().update(hiddens) - sample_n = pivot.shape[0] - self.n += sample_n + self.n += n # *** Invariance (intra-cluster) *** # This is just a standard online *mean* update, since we're computing the # mean of covariance matrices, not the covariance matrix of means. 
- sample_invar = sum(map(cov_mean_fused, hiddens)) / k - self.intracluster_cov += (sample_n / self.n) * ( - sample_invar - self.intracluster_cov - ) + intra_cov = cov_mean_fused(rearrange(hiddens, "n v k d -> (n k) v d")) + self.intracluster_cov += (n / self.n) * (intra_cov - self.intracluster_cov) - # [n, v, d] -> [n, d] - centroids = [h.mean(1) for h in hiddens] + # [n, v, k, d] -> [n, k, d] + centroids = hiddens.mean(1) deltas, deltas2 = [], [] - for i, h in enumerate(centroids): + for i, h in enumerate(centroids.unbind(1)): # Update the running means; super().update() does this usually delta = h - self.class_means[i] self.class_means[i] += delta.sum(dim=0) / self.n @@ -199,12 +193,12 @@ def update(self, *hiddens: Tensor) -> None: def fit_streaming(self) -> float: """Fit the probe using the current streaming statistics.""" + inv_weight = 1 - self.config.neg_cov_weight A = ( self.config.var_weight * self.intercluster_cov - - self.config.inv_weight * self.intracluster_cov + - inv_weight * self.intracluster_cov - self.config.neg_cov_weight * self.contrastive_xcov ) - try: L, Q = truncated_eigh(A, k=self.config.num_heads) except (ConvergenceError, RuntimeError): @@ -221,48 +215,45 @@ def fit_streaming(self) -> float: def fit( self, - *hiddens: Tensor, + hiddens: Tensor, labels: Optional[Tensor] = None, ) -> float: """Fit the probe to the contrast set `hiddens`. Args: - hiddens: The contrast set of hidden states. + hiddens: The contrast set of shape [batch, variants, choices, dim]. labels: The ground truth labels if available. Returns: loss: Negative eigenvalue associated with the VINC direction. """ - self.update(*hiddens) + self.update(hiddens) loss = self.fit_streaming() if labels is not None: - self.platt_scale(labels, *hiddens) + self.platt_scale(labels, hiddens) return loss - def platt_scale(self, labels: Tensor, *hiddens: Tensor, max_iter: int = 100): + def platt_scale(self, labels: Tensor, hiddens: Tensor, max_iter: int = 100): """Fit the scale and bias terms to data with LBFGS.""" - pivot, *_ = hiddens opt = optim.LBFGS( [self.bias, self.scale], line_search_fn="strong_wolfe", max_iter=max_iter, - tolerance_change=torch.finfo(pivot.dtype).eps, - tolerance_grad=torch.finfo(pivot.dtype).eps, + tolerance_change=torch.finfo(hiddens.dtype).eps, + tolerance_grad=torch.finfo(hiddens.dtype).eps, ) - labels = repeat(labels, "n -> (n v)", v=pivot.shape[1]) + (_, v, k, _) = hiddens.shape + labels = to_one_hot(repeat(labels, "n -> (n v)", v=v), k) def closure(): opt.zero_grad() - logits = rearrange(self.predict(*hiddens), "n v ... 
-> (n v) ...") - if len(logits.shape) == 1: - loss = nn.functional.binary_cross_entropy_with_logits( - logits, labels.float() - ) - else: - loss = nn.functional.cross_entropy(logits, labels.long()) + logits = rearrange(self(hiddens), "n v k -> (n v) k") + loss = nn.functional.binary_cross_entropy_with_logits( + logits, labels.float() + ) loss.backward() return float(loss) diff --git a/elk/training/preprocessing.py b/elk/training/preprocessing.py index 6081dcbb..840de1e7 100644 --- a/elk/training/preprocessing.py +++ b/elk/training/preprocessing.py @@ -43,7 +43,7 @@ def normalize( val_hiddens -= means if method == "elementwise": - scale = 1 / train_hiddens.norm(dim=0, keepdim=True) + scale = 1 / train_hiddens.std(dim=0, keepdim=True) elif method == "meanonly": scale = 1 else: diff --git a/elk/training/reporter.py b/elk/training/reporter.py index b936bb53..c4c09f79 100644 --- a/elk/training/reporter.py +++ b/elk/training/reporter.py @@ -165,17 +165,17 @@ def save(self, path: Path | str): @abstractmethod def fit( self, - *hiddens: Tensor, + hiddens: Tensor, labels: Optional[Tensor] = None, ) -> float: ... @abstractmethod - def predict(self, *hiddens: Tensor) -> Tensor: + def predict(self, hiddens: Tensor) -> Tensor: """Return pooled logits for the contrast set `hiddens`.""" @abstractmethod - def predict_prob(self, *hiddens: Tensor) -> Tensor: + def predict_prob(self, hiddens: Tensor) -> Tensor: """Like `predict` but returns normalized probabilities, not logits.""" @torch.no_grad() @@ -198,7 +198,7 @@ def score(self, labels: Tensor, hiddens: Tensor) -> EvalResult: to_one_hot(Y, n_classes=c).long().flatten() if c == 2: - pos_probs = pred_probs[..., 0].flatten() + pos_probs = pred_probs[..., 1].flatten() cal_err = CalibrationError().update(Y.cpu(), pos_probs.cpu()).compute().ece # Calibrated accuracy diff --git a/elk/training/train.py b/elk/training/train.py index 931a4df5..793a87ab 100644 --- a/elk/training/train.py +++ b/elk/training/train.py @@ -96,7 +96,7 @@ def train_reporter( else: raise ValueError(f"Unknown reporter config type: {type(self.cfg.net)}") - train_loss = reporter.fit(*train_h.unbind(2), labels=train_gt) + train_loss = reporter.fit(train_h, labels=train_gt) val_result = reporter.score(val_gt.to(device), val_h) reporter_dir, lr_dir = self.create_models_dir(assert_type(Path, self.out_dir)) From a8c21a6aec422610e05ddba8ac8623310a0a281c Mon Sep 17 00:00:00 2001 From: Nora Belrose Date: Tue, 11 Apr 2023 22:18:34 +0000 Subject: [PATCH 21/43] Remove .predict and .predict_prob on Reporter; trying to get SciQ to work --- elk/extraction/balanced_sampler.py | 8 +++- elk/extraction/extraction.py | 6 ++- elk/extraction/prompt_loading.py | 49 ++++++++++++++++------- elk/training/ccs_reporter.py | 7 ---- elk/training/eigen_reporter.py | 10 ----- elk/training/reporter.py | 63 +++++++++++++----------------- elk/training/train.py | 45 ++++++++++----------- 7 files changed, 96 insertions(+), 92 deletions(-) diff --git a/elk/extraction/balanced_sampler.py b/elk/extraction/balanced_sampler.py index ec011a47..446914fd 100644 --- a/elk/extraction/balanced_sampler.py +++ b/elk/extraction/balanced_sampler.py @@ -29,6 +29,7 @@ class BalancedSampler(TorchIterableDataset): num_classes: int buffer_size: int = 1000 buffers: dict[int, deque[dict]] = field(default_factory=dict, init=False) + label_col: str = "label" def __post_init__(self): # Initialize empty buffers @@ -38,7 +39,12 @@ def __post_init__(self): def __iter__(self): for sample in self.data: - label = sample["label"] + label = 
sample[self.label_col] + + # This whole class is a no-op if the label is not an integer + if not isinstance(label, int): + yield sample + continue # Add the sample to the buffer for its class label self.buffers[label].append(sample) diff --git a/elk/extraction/extraction.py b/elk/extraction/extraction.py index 84f43118..69f5de3c 100644 --- a/elk/extraction/extraction.py +++ b/elk/extraction/extraction.py @@ -99,6 +99,8 @@ def extract_hiddens( prompt_ds = load_prompts( *cfg.prompts.datasets, + label_column=cfg.prompts.label_column, + num_classes=cfg.prompts.num_classes, split_type=split_type, stream=cfg.prompts.stream, rank=rank, @@ -273,8 +275,8 @@ def get_splits() -> SplitDict: info = get_dataset_config_info(ds_name, config_name or None) ds_features = assert_type(Features, info.features) - label_col = infer_label_column(ds_features) - num_classes = infer_num_classes(ds_features[label_col]) + label_col = cfg.prompts.label_column or infer_label_column(ds_features) + num_classes = cfg.prompts.num_classes or infer_num_classes(ds_features[label_col]) layer_cols = { f"hidden_{layer}": Array3D( diff --git a/elk/extraction/prompt_loading.py b/elk/extraction/prompt_loading.py index 213fd634..80fd80b6 100644 --- a/elk/extraction/prompt_loading.py +++ b/elk/extraction/prompt_loading.py @@ -50,6 +50,7 @@ class PromptConfig(Serializable): data_dir: Optional[str] = None label_column: Optional[str] = None max_examples: list[int] = field(default_factory=lambda: [750, 250]) + num_classes: int = 0 num_shots: int = 0 num_variants: int = -1 seed: int = 42 @@ -71,6 +72,8 @@ def __post_init__(self): def load_prompts( *dataset_strings: str, + label_column: Optional[str] = None, + num_classes: int = 0, num_shots: int = 0, num_variants: int = -1, seed: int = 42, @@ -98,6 +101,7 @@ def load_prompts( class_counts = [] prompters = [] datasets = [] + label_cols = [] train_datasets = [] rng = Random(seed) assert num_shots == 0 @@ -128,14 +132,11 @@ def load_prompts( # This prints to stdout which is slightly annoying ds = split_dataset_by_node(dataset=ds, rank=rank, world_size=world_size) - label_column = infer_label_column(ds.features) - num_classes = infer_num_classes(ds.features[label_column]) - if label_column != "label": - ds = ds.rename_column(label_column, "label") - train_ds = train_ds.rename_column(label_column, "label") - + ds_label_col = label_column or infer_label_column(ds.features) + num_classes = num_classes or infer_num_classes(ds.features[ds_label_col]) class_counts.append(num_classes) datasets.append(ds) + label_cols.append(ds_label_col) train_datasets.append(train_ds) # Number of classes should be the same for all datasets @@ -155,8 +156,13 @@ def load_prompts( if rank == 0: print(f"Using {num_variants} variants of each prompt") - ds_iters = [iter(BalancedSampler(ds, num_classes)) for ds in datasets] - for ds_iter, ds, prompter in cycle(zip(ds_iters, datasets, prompters)): + ds_iters = [ + iter(BalancedSampler(ds, num_classes, label_col=label_col)) + for ds, label_col in zip(datasets, label_cols) + ] + for ds_iter, ds, label_col, prompter in cycle( + zip(ds_iters, datasets, label_cols, prompters) + ): try: example = next(ds_iter) except StopIteration: @@ -164,7 +170,7 @@ def load_prompts( example = _convert_to_prompts( example, - label_column="label", + label_column=label_col, num_classes=num_classes, num_variants=num_variants, prompter=prompter, @@ -190,7 +196,7 @@ def _convert_to_prompts( fewshot_iter: Optional[Iterator[list[dict]]] = None, ) -> dict[str, Any]: """Prompt-generating function to pass 
to `IterableDataset.map`.""" - label = assert_type(int, example[label_column]) + labels_are_strings = isinstance(example[label_column], str) prompts = [] templates = list(prompter.templates.values()) if num_variants < len(templates): @@ -203,15 +209,24 @@ def qa_cat(q: str, a: str) -> str: # For sanity checking that prompts are unique prompt_counter = Counter() + label_indices = set() + for template in templates: choices = [] + string_choices = template.get_answer_choices_list(example) + + label = example[label_column] + label_indices.add(string_choices.index(label) if labels_are_strings else label) for answer_idx in range(num_classes): fake_example = example.copy() - fake_example[label_column] = answer_idx + if labels_are_strings: + fake_example[label_column] = string_choices[answer_idx] + else: + fake_example[label_column] = answer_idx q, a = template.apply(fake_example) - text = qa_cat(q, a) + text = qa_cat(q, a or string_choices[answer_idx]) prompt_counter[text] += 1 if fewshot_iter is not None: @@ -238,8 +253,14 @@ def qa_cat(q: str, a: str) -> str: if dup_count > 1: raise ValueError(f'Prompt duplicated {dup_count} times! "{maybe_dup}"') + # Sanity check: label should be the same across all variants + if len(label_indices) > 1: + raise ValueError( + f"Label index should be the same all variants, but got {label_indices}" + ) + return dict( - label=label, + label=label_indices.pop(), prompts=prompts, - template_names=prompter.all_template_names, + template_names=[template.name for template in templates], ) diff --git a/elk/training/ccs_reporter.py b/elk/training/ccs_reporter.py index 0253857e..b7d08259 100644 --- a/elk/training/ccs_reporter.py +++ b/elk/training/ccs_reporter.py @@ -164,13 +164,6 @@ def forward(self, x: Tensor) -> Tensor: """Return the raw score output of the probe on `x`.""" return self.probe(x).squeeze(-1) - def predict(self, hiddens: Tensor) -> Tensor: - return self.predict_prob(hiddens).logit() - - def predict_prob(self, hiddens: Tensor) -> Tensor: - x_pos, x_neg = hiddens.unbind(2) - return 0.5 * (self(x_pos).sigmoid() + (1 - self(x_neg).sigmoid())) - def loss( self, logit0: Tensor, diff --git a/elk/training/eigen_reporter.py b/elk/training/eigen_reporter.py index b2d560b3..8c2d4afa 100644 --- a/elk/training/eigen_reporter.py +++ b/elk/training/eigen_reporter.py @@ -103,16 +103,6 @@ def forward(self, hiddens: Tensor) -> Tensor: raw_scores = hiddens @ self.weight.mT return raw_scores.mul(self.scale).add(self.bias).squeeze(-1) - predict = forward - - def predict_prob(self, hiddens: Tensor) -> Tensor: - """Return the predicted probabilities on the contrast set `hiddens`.""" - logits = self(hiddens) - if len(hiddens) == 2: - return logits.sigmoid() - else: - return logits.softmax(dim=-1) - @property def contrastive_xcov(self) -> Tensor: return self.contrastive_xcov_M2 / self.n diff --git a/elk/training/reporter.py b/elk/training/reporter.py index c4c09f79..9d33e02a 100644 --- a/elk/training/reporter.py +++ b/elk/training/reporter.py @@ -13,7 +13,7 @@ from torch import Tensor from ..calibration import CalibrationError -from ..metrics import mean_auc, to_one_hot +from ..metrics import to_one_hot from .classifier import Classifier @@ -96,42 +96,43 @@ def check_separability( """Measure how linearly separable the pseudo-labels are for a contrast pair. Args: - train_hiddens: Tensor of shape [n, ], where x0 and x1 are the - contrastive representations. Used for training the classifier. 
- val_pair: A tuple of tensors, (x0, x1), where x0 and x1 are the - contrastive representations. Used for evaluating the classifier. + train_hiddens: Contrast set of shape [n, v, k, d]. Used for training the + classifier. + val_hiddens: Contrast set of shape [n, v, k, d]. Used for evaluating the + classifier. Returns: The AUROC of a linear classifier fit on the pseudo-labels. """ - x0, x1 = train_hiddens - val_x0, val_x1 = val_hiddens + (n_train, v, k, d) = train_hiddens.shape + (n_val, _, k_val, d_val) = val_hiddens.shape + assert d == d_val, "Must have the same number of features in each split" + assert k == k_val == 2, "Must be a binary contrast set" - pseudo_clf = Classifier(x0.shape[-1], device=x0.device) # type: ignore + pseudo_clf = Classifier(d, device=train_hiddens.device) pseudo_train_labels = torch.cat( [ - x0.new_zeros(x0.shape[0]), - x0.new_ones(x0.shape[0]), + train_hiddens.new_zeros(n_train), + train_hiddens.new_ones(n_train), ] ).repeat_interleave( - x0.shape[1] + v ) # make num_variants copies of each pseudo-label + pseudo_val_labels = torch.cat( [ - val_x0.new_zeros(val_x0.shape[0]), - val_x0.new_ones(val_x0.shape[0]), + val_hiddens.new_zeros(n_val), + val_hiddens.new_ones(n_val), ] - ).repeat_interleave(val_x0.shape[1]) + ).repeat_interleave(v) pseudo_clf.fit( - # b v d -> (b v) d - torch.cat([x0, x1]).flatten(0, 1), + rearrange(train_hiddens, "n v k d -> (k n v) d"), pseudo_train_labels, ) with torch.no_grad(): pseudo_preds = pseudo_clf( - # b v d -> (b v) d - torch.cat([val_x0, val_x1]).flatten(0, 1) + rearrange(val_hiddens, "n v k d -> (k n v) d"), ) return float(roc_auc_score(pseudo_val_labels.cpu(), pseudo_preds.cpu())) @@ -170,35 +171,27 @@ def fit( ) -> float: ... - @abstractmethod - def predict(self, hiddens: Tensor) -> Tensor: - """Return pooled logits for the contrast set `hiddens`.""" - - @abstractmethod - def predict_prob(self, hiddens: Tensor) -> Tensor: - """Like `predict` but returns normalized probabilities, not logits.""" - @torch.no_grad() def score(self, labels: Tensor, hiddens: Tensor) -> EvalResult: """Score the probe on the contrast set `hiddens`. Args: labels: The labels of the contrast pair. - hiddens: The hidden representations of the contrast set. + hiddens: Contrast set of shape [n, v, k, d]. Returns: an instance of EvalResult containing the loss, accuracy, calibrated accuracy, and AUROC of the probe on `hiddens`. """ - pred_probs = self.predict_prob(hiddens) - (_, v, c) = pred_probs.shape + logits = self(hiddens) + (_, v, c) = logits.shape # makes `num_variants` copies of each label + logits = rearrange(logits, "n v c -> (n v) c") Y = repeat(labels, "n -> (n v)", v=v).float() - to_one_hot(Y, n_classes=c).long().flatten() if c == 2: - pos_probs = pred_probs[..., 1].flatten() + pos_probs = logits[..., 1].flatten().sigmoid() cal_err = CalibrationError().update(Y.cpu(), pos_probs.cpu()).compute().ece # Calibrated accuracy @@ -212,12 +205,10 @@ def score(self, labels: Tensor, hiddens: Tensor) -> EvalResult: cal_acc = 0.0 cal_err = 0.0 - raw_preds = pred_probs.argmax(dim=-1) + raw_preds = to_one_hot(logits.argmax(dim=-1), c).long() + Y = to_one_hot(Y, c).long().flatten() - # roc_auc_score only takes flattened input - auroc = mean_auc( - Y.cpu(), rearrange(pred_probs.cpu(), "n v ... 
-> (n v) ..."), curve="roc" - ) + auroc = roc_auc_score(Y.cpu(), logits.cpu().flatten()) raw_acc = raw_preds.flatten().eq(Y).float().mean() return EvalResult( diff --git a/elk/training/train.py b/elk/training/train.py index 793a87ab..35d29961 100644 --- a/elk/training/train.py +++ b/elk/training/train.py @@ -1,5 +1,6 @@ """Main training loop.""" +import warnings from dataclasses import dataclass from functools import partial from pathlib import Path @@ -9,16 +10,18 @@ import torch from einops import rearrange, repeat from simple_parsing import Serializable, field, subgroups +from sklearn.metrics import roc_auc_score +from torch import Tensor from ..extraction.extraction import Extract -from ..metrics import accuracy, mean_auc +from ..metrics import accuracy, to_one_hot from ..run import Run from ..training.baseline import evaluate_baseline, save_baseline, train_baseline from ..utils import select_usable_devices from ..utils.typing import assert_type from .ccs_reporter import CcsReporter, CcsReporterConfig from .eigen_reporter import EigenReporter, EigenReporterConfig -from .reporter import OptimConfig, ReporterConfig +from .reporter import OptimConfig, Reporter, ReporterConfig @dataclass @@ -87,7 +90,7 @@ def train_reporter( device, layer ) (_, v, c, d) = train_h.shape - # pseudo_auroc = self.get_pseudo_auroc(layer, x0, x1, val_x0, val_x1) + pseudo_auroc = self.get_pseudo_auroc(layer, train_h, val_h) if isinstance(self.cfg.net, CcsReporterConfig): reporter = CcsReporter(self.cfg.net, d, device=device) @@ -103,7 +106,9 @@ def train_reporter( if val_lm_preds is not None: val_gt_cpu = repeat(val_gt, "n -> (n v)", v=v).cpu() val_lm_preds = rearrange(val_lm_preds, "n v ... -> (n v) ...") - val_lm_auroc = mean_auc(val_gt_cpu, val_lm_preds, "roc") + val_lm_auroc = roc_auc_score( + to_one_hot(val_gt_cpu, c).long().flatten(), val_lm_preds.cpu().flatten() + ) val_lm_acc = accuracy(val_gt_cpu, val_lm_preds) else: @@ -113,7 +118,7 @@ def train_reporter( row = pd.Series( { "layer": layer, - # "pseudo_auroc": pseudo_auroc, + "pseudo_auroc": pseudo_auroc, "train_loss": train_loss, **val_result._asdict(), "lm_auroc": val_lm_auroc, @@ -134,23 +139,19 @@ def train_reporter( return row - # def get_pseudo_auroc( - # self, layer: int, train_h: Tensor, val_h: Tensor - # ): - # """Check the separability of the pseudo-labels at a given layer.""" - # - # with torch.no_grad(): - # pseudo_auroc = Reporter.check_separability( - # train_pair=(x0, x1), val_pair=(val_x0, val_x1) - # ) - # if pseudo_auroc > 0.6: - # warnings.warn( - # f"The pseudo-labels at layer {layer} are linearly separable with " - # f"an AUROC of {pseudo_auroc:.3f}. This may indicate that the " - # f"algorithm will not converge to a good solution." - # ) - # - # return pseudo_auroc + def get_pseudo_auroc(self, layer: int, train_h: Tensor, val_h: Tensor): + """Check the separability of the pseudo-labels at a given layer.""" + + with torch.no_grad(): + pseudo_auroc = Reporter.check_separability(train_h, val_h) + if pseudo_auroc > 0.6: + warnings.warn( + f"The pseudo-labels at layer {layer} are linearly separable with " + f"an AUROC of {pseudo_auroc:.3f}. This may indicate that the " + f"algorithm will not converge to a good solution." 
+                )
+
+        return pseudo_auroc
 
     def train(self):
         """Train a reporter on each layer of the network."""

From 5f478b1bbd426405198add45294037e62541e07e Mon Sep 17 00:00:00 2001
From: Nora Belrose
Date: Tue, 11 Apr 2023 22:45:22 +0000
Subject: [PATCH 22/43] Bugfix for Reporter.score on binary tasks

---
 elk/training/reporter.py | 10 ++++------
 1 file changed, 4 insertions(+), 6 deletions(-)

diff --git a/elk/training/reporter.py b/elk/training/reporter.py
index 9d33e02a..3fa37630 100644
--- a/elk/training/reporter.py
+++ b/elk/training/reporter.py
@@ -198,17 +198,15 @@ def score(self, labels: Tensor, hiddens: Tensor) -> EvalResult:
             cal_thresh = pos_probs.float().quantile(labels.float().mean())
             cal_preds = pos_probs.gt(cal_thresh).to(torch.int)
             cal_acc = cal_preds.flatten().eq(Y).float().mean().item()
-
-            raw_preds = pos_probs.gt(0.5).to(torch.int)
         else:
             # TODO: Implement calibration error for k > 2?
             cal_acc = 0.0
             cal_err = 0.0
 
-        raw_preds = to_one_hot(logits.argmax(dim=-1), c).long()
-        Y = to_one_hot(Y, c).long().flatten()
-
-        auroc = roc_auc_score(Y.cpu(), logits.cpu().flatten())
+        raw_preds = to_one_hot(logits.argmax(dim=-1), c).long()
+        auroc = roc_auc_score(
+            to_one_hot(Y, c).long().flatten().cpu(), logits.cpu().flatten()
+        )
         raw_acc = raw_preds.flatten().eq(Y).float().mean()
 
         return EvalResult(

From 97b26aca802b95f2f28c00e2a51a5b023884c398 Mon Sep 17 00:00:00 2001
From: Nora Belrose
Date: Wed, 12 Apr 2023 05:44:34 +0000
Subject: [PATCH 23/43] Fix bug where cached hidden states aren’t used when num_gpus is different
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 elk/extraction/generator.py | 15 +++++++++++++++
 1 file changed, 15 insertions(+)

diff --git a/elk/extraction/generator.py b/elk/extraction/generator.py
index fb4d03bc..5aa63bd8 100644
--- a/elk/extraction/generator.py
+++ b/elk/extraction/generator.py
@@ -2,6 +2,7 @@
 from typing import Any, Callable, Dict, Optional
 
 import datasets
+from datasets import Features
 from datasets.splits import NamedSplit
 
 
@@ -11,6 +12,20 @@ class _GeneratorConfig(datasets.BuilderConfig):
     gen_kwargs: dict[str, Any] = field(default_factory=dict)
     features: Optional[datasets.Features] = None
 
+    def create_config_id(
+        self, config_kwargs: dict, custom_features: Features | None
+    ) -> str:
+        # These are implementation details that don't need to be hashed into the id
+        # TODO: Make this customizable or something? Right now we're just hard-coding
+        # these values from extraction.py. OTOH maybe it's not terrible because this is
+        # a private class anyway.
+ gen_kwargs = config_kwargs.get("gen_kwargs", {}) + gen_kwargs.pop("device", None) + gen_kwargs.pop("rank", None) + gen_kwargs.pop("world_size", None) + + return super().create_config_id(config_kwargs, custom_features) + @dataclass class _SplitGenerator: From 11fda875e803967e1d4aa268b783b0500e775cde Mon Sep 17 00:00:00 2001 From: Nora Belrose Date: Wed, 12 Apr 2023 09:17:11 +0000 Subject: [PATCH 24/43] Actually works now --- elk/extraction/extraction.py | 15 +++++++++------ elk/extraction/generator.py | 21 ++++++++++++--------- elk/training/eigen_reporter.py | 2 +- 3 files changed, 22 insertions(+), 16 deletions(-) diff --git a/elk/extraction/extraction.py b/elk/extraction/extraction.py index 74f4e489..a9abd6a6 100644 --- a/elk/extraction/extraction.py +++ b/elk/extraction/extraction.py @@ -1,6 +1,7 @@ """Functions for extracting the hidden states of a model.""" import logging import os +from copy import copy from dataclasses import InitVar, dataclass from itertools import islice from typing import Any, Iterable, Literal, Optional @@ -22,7 +23,6 @@ from transformers import AutoConfig, AutoTokenizer from transformers.modeling_outputs import Seq2SeqLMOutput -# import torch.nn.functional as F from ..utils import ( assert_type, convert_span, @@ -87,10 +87,7 @@ def extract_hiddens( rank: int = 0, world_size: int = 1, ) -> Iterable[dict]: - """Run inference on a model with a set of prompts, yielding the hidden states. - - This is a lightweight, functional version of the `Extractor` API. - """ + """Run inference on a model with a set of prompts, yielding the hidden states.""" os.environ["TOKENIZERS_PARALLELISM"] = "false" # Silence datasets logging messages from all but the first process @@ -301,6 +298,12 @@ def get_splits() -> SplitDict: ) devices = select_usable_devices(num_gpus, min_memory=cfg.min_gpu_mem) + + # Prevent the GPU-related config options from invalidating the cache + _cfg = copy(cfg) + _cfg.min_gpu_mem = None + _cfg.num_gpus = -1 + builders = { split_name: _GeneratorBuilder( cache_dir=None, @@ -309,7 +312,7 @@ def get_splits() -> SplitDict: split_name=split_name, split_info=split_info, gen_kwargs=dict( - cfg=[cfg] * len(devices), + cfg=[_cfg] * len(devices), device=devices, rank=list(range(len(devices))), split_type=[split_name] * len(devices), diff --git a/elk/extraction/generator.py b/elk/extraction/generator.py index 5aa63bd8..e3cad0e5 100644 --- a/elk/extraction/generator.py +++ b/elk/extraction/generator.py @@ -1,3 +1,4 @@ +from copy import deepcopy from dataclasses import dataclass, field from typing import Any, Callable, Dict, Optional @@ -15,15 +16,17 @@ class _GeneratorConfig(datasets.BuilderConfig): def create_config_id( self, config_kwargs: dict, custom_features: Features | None ) -> str: - # These are implementation details that don't need to be hashed into the id - # TODO: Make this customizable or something? Right now we're just hard-coding - # these values from extraction.py. OTOH maybe it's not terrible because this is - # a private class anyway. - gen_kwargs = config_kwargs.get("gen_kwargs", {}) - gen_kwargs.pop("device", None) - gen_kwargs.pop("rank", None) - gen_kwargs.pop("world_size", None) - + config_kwargs = deepcopy(config_kwargs) + + # By default the values in gen_kwargs are lists of length world_size. We want + # to erase the world_size dimension so that the config id is the same no matter + # how many processes are used. We also remove the explicit device, rank, and + # world_size keys. 
+ config_kwargs["gen_kwargs"] = { + k: v[0] + for k, v in config_kwargs.get("gen_kwargs", {}).items() + if k not in ("device", "rank", "world_size") + } return super().create_config_id(config_kwargs, custom_features) diff --git a/elk/training/eigen_reporter.py b/elk/training/eigen_reporter.py index 1d2f6a98..821891cc 100644 --- a/elk/training/eigen_reporter.py +++ b/elk/training/eigen_reporter.py @@ -180,7 +180,7 @@ def fit_streaming(self) -> float: try: L, Q = truncated_eigh(A, k=self.config.num_heads) - except ConvergenceError: + except (ConvergenceError, RuntimeError): warn( "Truncated eigendecomposition failed to converge. Falling back on " "PyTorch's dense eigensolver." From da4c72f231be63337f46122d32f01633a7c8a689 Mon Sep 17 00:00:00 2001 From: Nora Belrose Date: Thu, 13 Apr 2023 06:06:35 +0000 Subject: [PATCH 25/43] Refactor handling of multiple datasets --- elk/__main__.py | 5 +- elk/evaluation/evaluate.py | 60 ++++++----- elk/extraction/balanced_sampler.py | 2 +- elk/extraction/extraction.py | 44 ++++---- elk/extraction/generator.py | 42 +++++--- elk/extraction/prompt_loading.py | 167 +++++++++++++---------------- elk/files.py | 28 ----- elk/run.py | 125 +++++++++++---------- elk/training/__init__.py | 6 +- elk/training/baseline.py | 14 --- elk/training/ccs_reporter.py | 18 +++- elk/training/eigen_reporter.py | 18 +++- elk/training/normalizer.py | 63 +++++++++++ elk/training/preprocessing.py | 55 ---------- elk/training/reporter.py | 32 +----- elk/training/train.py | 89 ++++++++------- elk/utils/__init__.py | 4 + elk/utils/data_utils.py | 21 +++- elk/{ => utils}/math_util.py | 0 tests/test_smoke_elicit.py | 4 +- 20 files changed, 406 insertions(+), 391 deletions(-) create mode 100644 elk/training/normalizer.py delete mode 100644 elk/training/preprocessing.py rename elk/{ => utils}/math_util.py (100%) diff --git a/elk/__main__.py b/elk/__main__.py index 5304f5aa..faa044f7 100644 --- a/elk/__main__.py +++ b/elk/__main__.py @@ -5,7 +5,6 @@ from simple_parsing import ArgumentParser from elk.evaluation.evaluate import Eval -from elk.extraction.extraction import Extract from elk.training.train import Elicit @@ -13,14 +12,14 @@ class Command: """Some top-level command""" - command: Elicit | Eval | Extract + command: Elicit | Eval def execute(self): return self.command.execute() def run(): - parser = ArgumentParser(add_help=False) + parser = ArgumentParser(add_help=False, add_config_path_arg=True) parser.add_arguments(Command, dest="run") args = parser.parse_args() run: Command = args.run diff --git a/elk/evaluation/evaluate.py b/elk/evaluation/evaluate.py index 6aca58f5..ad81182a 100644 --- a/elk/evaluation/evaluate.py +++ b/elk/evaluation/evaluate.py @@ -1,7 +1,7 @@ from dataclasses import dataclass from functools import partial from pathlib import Path -from typing import Callable, Literal, Optional +from typing import Callable import pandas as pd import torch @@ -11,7 +11,7 @@ from ..files import elk_reporter_dir from ..run import Run from ..training import Reporter -from ..training.baseline import evaluate_baseline, load_baseline +from ..training.baseline import evaluate_baseline from ..utils import select_usable_devices @@ -34,11 +34,11 @@ class Eval(Serializable): data: Extract source: str = field(positional=True) - normalization: Literal["legacy", "none", "elementwise", "meanonly"] = "meanonly" debug: bool = False - out_dir: Optional[Path] = None + out_dir: Path | None = None num_gpus: int = -1 + min_gpu_mem: int | None = None skip_baseline: bool = False concatenated_layer_offset: 
int = 0 @@ -58,11 +58,11 @@ class Evaluate(Run): def evaluate_reporter( self, layer: int, devices: list[str], world_size: int = 1 - ) -> pd.Series: + ) -> pd.DataFrame: """Evaluate a single reporter on a single layer.""" device = self.get_device(devices, world_size) - _, _, test_x0, test_x1, _, test_labels, _ = self.prepare_data( + _, _, _, val_output = self.prepare_data( device, layer, ) @@ -73,40 +73,42 @@ def evaluate_reporter( reporter: Reporter = torch.load(reporter_path, map_location=device) reporter.eval() - test_result = reporter.score( - test_labels, - test_x0, - test_x1, - ) - - stats_row = pd.Series( - { - "layer": layer, - **test_result._asdict(), - } - ) + row_buf = [] + for ds_name, (val_x0, val_x1, val_gt, _) in val_output.items(): + val_result = reporter.score( + val_gt, + val_x0, + val_x1, + ) - lr_dir = experiment_dir / "lr_models" - if not self.cfg.skip_baseline and lr_dir.exists(): - lr_model = load_baseline(lr_dir, layer) - lr_model.eval() - lr_auroc, lr_acc = evaluate_baseline( - lr_model.cuda(), test_x0.cuda(), test_x1.cuda(), test_labels + stats_row = pd.Series( + { + "dataset": ds_name, + "layer": layer, + **val_result._asdict(), + } ) - stats_row["lr_auroc"] = lr_auroc - stats_row["lr_acc"] = lr_acc + lr_dir = experiment_dir / "lr_models" + if not self.cfg.skip_baseline and lr_dir.exists(): + with open(lr_dir / f"layer_{layer}.pt", "rb") as f: + lr_model = torch.load(f, map_location=device).eval() + + lr_auroc, lr_acc = evaluate_baseline(lr_model, val_x0, val_x1, val_gt) + + stats_row["lr_auroc"] = lr_auroc + stats_row["lr_acc"] = lr_acc - return stats_row + return pd.DataFrame(row_buf) def evaluate(self): """Evaluate the reporter on all layers.""" devices = select_usable_devices( - self.cfg.num_gpus, min_memory=self.cfg.data.min_gpu_mem + self.cfg.num_gpus, min_memory=self.cfg.min_gpu_mem ) num_devices = len(devices) - func: Callable[[int], pd.Series] = partial( + func: Callable[[int], pd.DataFrame] = partial( self.evaluate_reporter, devices=devices, world_size=num_devices ) self.apply_to_layers(func=func, num_devices=num_devices) diff --git a/elk/extraction/balanced_sampler.py b/elk/extraction/balanced_sampler.py index 2ea4815e..4a287320 100644 --- a/elk/extraction/balanced_sampler.py +++ b/elk/extraction/balanced_sampler.py @@ -6,8 +6,8 @@ from datasets import Features, IterableDataset from torch.utils.data import IterableDataset as TorchIterableDataset -from ..math_util import stochastic_round_constrained from ..utils import infer_label_column +from ..utils.math_util import stochastic_round_constrained from ..utils.typing import assert_type diff --git a/elk/extraction/extraction.py b/elk/extraction/extraction.py index a9abd6a6..03105c74 100644 --- a/elk/extraction/extraction.py +++ b/elk/extraction/extraction.py @@ -4,7 +4,7 @@ from copy import copy from dataclasses import InitVar, dataclass from itertools import islice -from typing import Any, Iterable, Literal, Optional +from typing import Any, Iterable, Literal import torch from datasets import ( @@ -48,7 +48,6 @@ class Extract(Serializable): layer_stride: Shortcut for setting `layers` to `range(0, num_layers, stride)`. token_loc: The location of the token to extract hidden states from. Can be either "first", "last", or "mean". Defaults to "last". - min_gpu_mem: Minimum amount of free memory (in bytes) required to select a GPU. """ prompts: PromptConfig @@ -57,8 +56,6 @@ class Extract(Serializable): layers: tuple[int, ...] 
= () layer_stride: InitVar[int] = 1 token_loc: Literal["first", "last", "mean"] = "last" - min_gpu_mem: Optional[int] = None - num_gpus: int = -1 def __post_init__(self, layer_stride: int): if self.layers and layer_stride > 1: @@ -74,8 +71,16 @@ def __post_init__(self, layer_stride: int): ) self.layers = tuple(range(0, config.num_hidden_layers, layer_stride)) - def execute(self): - extract(cfg=self, num_gpus=self.num_gpus) + def explode(self) -> list["Extract"]: + """Explode this config into a list of configs, one for each layer.""" + copies = [] + + for prompt_cfg in self.prompts.explode(): + cfg = copy(self) + cfg.prompts = prompt_cfg + copies.append(cfg) + + return copies @torch.no_grad() @@ -94,8 +99,11 @@ def extract_hiddens( if rank != 0: logging.disable(logging.CRITICAL) + ds_names = cfg.prompts.datasets + assert len(ds_names) == 1, "Can only extract hiddens from one dataset at a time." + prompt_ds = load_prompts( - *cfg.prompts.datasets, + ds_names[0], split_type=split_type, stream=cfg.prompts.stream, rank=rank, @@ -240,14 +248,18 @@ def _extraction_worker(**kwargs): yield from extract_hiddens(**{k: v[0] for k, v in kwargs.items()}) -def extract(cfg: "Extract", num_gpus: int = -1) -> DatasetDict: +def extract( + cfg: "Extract", num_gpus: int = -1, min_gpu_mem: int | None = None +) -> DatasetDict: """Extract hidden states from a model and return a `DatasetDict` containing them.""" def get_splits() -> SplitDict: available_splits = assert_type(SplitDict, info.splits) train_name, val_name = select_train_val_splits(available_splits) - print(f"Using '{train_name}' for training and '{val_name}' for validation") - + print( + f"{info.builder_name}: using '{train_name}' for training and '{val_name}'" + f" for validation" + ) limit_list = cfg.prompts.max_examples return SplitDict( @@ -297,22 +309,18 @@ def get_splits() -> SplitDict: length=num_variants, ) - devices = select_usable_devices(num_gpus, min_memory=cfg.min_gpu_mem) - - # Prevent the GPU-related config options from invalidating the cache - _cfg = copy(cfg) - _cfg.min_gpu_mem = None - _cfg.num_gpus = -1 - + devices = select_usable_devices(num_gpus, min_memory=min_gpu_mem) builders = { split_name: _GeneratorBuilder( + builder_name=info.builder_name, + config_name=info.config_name, cache_dir=None, features=Features({**layer_cols, **other_cols}), generator=_extraction_worker, split_name=split_name, split_info=split_info, gen_kwargs=dict( - cfg=[_cfg] * len(devices), + cfg=[cfg] * len(devices), device=devices, rank=list(range(len(devices))), split_type=[split_name] * len(devices), diff --git a/elk/extraction/generator.py b/elk/extraction/generator.py index e3cad0e5..86e65e08 100644 --- a/elk/extraction/generator.py +++ b/elk/extraction/generator.py @@ -1,17 +1,22 @@ from copy import deepcopy from dataclasses import dataclass, field -from typing import Any, Callable, Dict, Optional - -import datasets -from datasets import Features +from typing import Any, Callable + +from datasets import ( + BuilderConfig, + DatasetInfo, + Features, + GeneratorBasedBuilder, + SplitInfo, +) from datasets.splits import NamedSplit @dataclass -class _GeneratorConfig(datasets.BuilderConfig): - generator: Optional[Callable] = None +class _GeneratorConfig(BuilderConfig): + generator: Callable | None = None gen_kwargs: dict[str, Any] = field(default_factory=dict) - features: Optional[datasets.Features] = None + features: Features | None = None def create_config_id( self, config_kwargs: dict, custom_features: Features | None @@ -37,28 +42,41 @@ class 
_SplitGenerator: """ name: str - split_info: datasets.SplitInfo - gen_kwargs: Dict = field(default_factory=dict) + split_info: SplitInfo + gen_kwargs: dict = field(default_factory=dict) def __post_init__(self): self.name = str(self.name) # Make sure we convert NamedSplits in strings NamedSplit(self.name) # check that it's a valid split name -class _GeneratorBuilder(datasets.GeneratorBasedBuilder): +class _GeneratorBuilder(GeneratorBasedBuilder): """Patched version of `datasets.Generator` allowing for splits besides `train`""" BUILDER_CONFIG_CLASS = _GeneratorConfig config: _GeneratorConfig - def __init__(self, split_name: str, split_info: datasets.SplitInfo, **kwargs): + def __init__( + self, + builder_name: str | None, + config_name: str | None, + split_name: str, + split_info: SplitInfo, + **kwargs, + ): self.split_name = split_name self.split_info = split_info super().__init__(**kwargs) + # Weirdly we need to set DatasetInfo.builder_name and DatasetInfo.config_name + # here, not in _info, because super().__init__ modifies them + self.info.builder_name = builder_name + self.info.config_name = config_name + def _info(self): - return datasets.DatasetInfo(features=self.config.features) + # Use the same builder and config name as the original builder + return DatasetInfo(features=self.config.features) def _split_generators(self, dl_manager): return [ diff --git a/elk/extraction/prompt_loading.py b/elk/extraction/prompt_loading.py index c46d3d7b..56be3baf 100644 --- a/elk/extraction/prompt_loading.py +++ b/elk/extraction/prompt_loading.py @@ -1,5 +1,7 @@ from collections import Counter +from copy import deepcopy from dataclasses import dataclass +from itertools import zip_longest from random import Random from typing import Any, Iterator, Literal, Optional @@ -26,9 +28,8 @@ class PromptConfig(Serializable): """ Args: - dataset: Space-delimited name of the HuggingFace dataset to use, e.g. + dataset: List of space-delimited names of the HuggingFace dataset to use, e.g. `"super_glue boolq"` or `"imdb"`. - balance: Whether to force class balance in the dataset using undersampling. data_dir: The directory to use for caching the dataset. Defaults to `~/.cache/huggingface/datasets`. label_column: The column containing the labels. 
By default, we infer this from @@ -47,9 +48,8 @@ class PromptConfig(Serializable): """ datasets: list[str] = field(positional=True) - balance: bool = False - data_dir: Optional[str] = None - label_column: Optional[str] = None + data_dirs: list[str] = field(default_factory=list) + label_columns: list[str] = field(default_factory=list) max_examples: list[int] = field(default_factory=lambda: [750, 250]) num_shots: int = 0 num_variants: int = -1 @@ -69,9 +69,30 @@ def __post_init__(self): if len(self.max_examples) == 1: self.max_examples *= 2 + def explode(self) -> list["PromptConfig"]: + """Explode the config into a list of configs, one for each dataset.""" + copies = [] + + # Broadcast the dataset name to all data_dirs and label_columns + if len(self.data_dirs) == 1: + self.data_dirs *= len(self.datasets) + if len(self.label_columns) == 1: + self.label_columns *= len(self.datasets) + + for ds, data_dir, col in zip_longest( + self.datasets, self.data_dirs, self.label_columns + ): + copy = deepcopy(self) + copy.datasets = [ds] + copy.data_dirs = [data_dir] if data_dir else [] + copy.label_columns = [col] if col else [] + copies.append(copy) + + return copies + def load_prompts( - *dataset_strings: str, + ds_string: str, num_shots: int = 0, num_variants: int = -1, seed: int = 42, @@ -83,7 +104,7 @@ def load_prompts( """Load a dataset full of prompts generated from the specified datasets. Args: - dataset_strings: Space-delimited names of the HuggingFace datasets to use, + ds_string: Space-delimited name of the HuggingFace datasets to use, e.g. `"super_glue boolq"` or `"imdb"`. num_shots: The number of examples to use in few-shot prompts. If zero, prompts are zero-shot. @@ -96,101 +117,65 @@ def load_prompts( Returns: An iterable dataset of prompts. """ - prompters = [] - raw_datasets = [] - train_datasets = [] - rng = Random(seed) + ds_name, _, config_name = ds_string.partition(" ") + prompter = DatasetTemplates(ds_name, config_name) - # First load the datasets and prompters. We need to know the minimum number of - # templates for any dataset in order to make sure we don't run out of prompts. - for ds_string in dataset_strings: - ds_name, _, config_name = ds_string.partition(" ") - prompters.append(DatasetTemplates(ds_name, config_name)) + ds_dict = assert_type( + dict, load_dataset(ds_name, config_name or None, streaming=stream) + ) + train_name, val_name = select_train_val_splits(ds_dict) + split_name = val_name if split_type == "val" else train_name - ds_dict = assert_type( - dict, load_dataset(ds_name, config_name or None, streaming=stream) - ) - train_name, val_name = select_train_val_splits(ds_dict) - split_name = val_name if split_type == "val" else train_name - - # Note that when streaming we can only approximately shuffle the dataset - # using a buffer. Streaming shuffling is NOT an adequate shuffle for - # datasets like IMDB, which are sorted by label. - bad_streaming_datasets = ["imdb"] - assert not ( - stream and ds_name in bad_streaming_datasets - ), f"Streaming is not supported for {ds_name}." 
- split = ds_dict[split_name].shuffle(seed=seed) - train_ds = ds_dict[train_name].shuffle(seed=seed) - if not stream: - split = assert_type(Dataset, split) - split = split.to_iterable_dataset().cast(split.features) - - # only keep the datapoints relevant to the current process - if world_size > 1: - # This prints to stdout which is slightly annoying - split = split_dataset_by_node( - dataset=split, rank=rank, world_size=world_size - ) + ds = ds_dict[split_name].shuffle(seed=seed) + train_ds = ds_dict[train_name].shuffle(seed=seed) + if not stream: + ds = assert_type(Dataset, ds) + ds = ds.to_iterable_dataset().cast(ds.features) - raw_datasets.append(split) - train_datasets.append(train_ds) + # only keep the datapoints relevant to the current process + if world_size > 1: + # This prints to stdout which is slightly annoying + ds = split_dataset_by_node(dataset=ds, rank=rank, world_size=world_size) - min_num_templates = min(len(prompter.templates) for prompter in prompters) + num_templates = len(prompter.templates) num_variants = ( - min_num_templates - if num_variants == -1 - else min(num_variants, min_num_templates) + num_templates if num_variants == -1 else min(num_variants, num_templates) ) assert num_variants > 0 if rank == 0: print(f"Using {num_variants} variants of each prompt") - ds_iterators = [iter(ds) for ds in raw_datasets] - while True: # terminates when the first dataset runs out of examples - for ds_iterator, ds, train_ds, prompter in zip( - ds_iterators, raw_datasets, train_datasets, prompters - ): - label_column = infer_label_column(ds.features) - num_classes = infer_num_classes(ds.features[label_column]) - - # Remove everything except the label column - extra_cols = list(assert_type(Features, ds.features)) - extra_cols.remove(label_column) - - if label_column != "label": - ds = ds.rename_column(label_column, "label") - if num_shots > 0: - fewshot = FewShotSampler( - train_ds, # TODO: not iterator - num_shots=num_shots, - rng=rng, - ) - fewshot_iter = iter(fewshot) - else: - fewshot_iter = None - - try: - example = next(ds_iterator) - except StopIteration: - return - - example = _convert_to_prompts( - example, - label_column=label_column, - num_classes=num_classes, - num_variants=num_variants, - prompter=prompter, - rng=rng, - fewshot_iter=fewshot_iter, - ) - - # Add the builder and config name to the records directly to make - # sure we don't forget what dataset they came from. 
- example["builder_name"] = ds.info.builder_name - example["config_name"] = ds.info.config_name + label_column = infer_label_column(ds.features) + num_classes = infer_num_classes(ds.features[label_column]) + rng = Random(seed) - yield example + if num_shots > 0: + fewshot = FewShotSampler( + train_ds, # TODO: not iterator + num_shots=num_shots, + rng=rng, + ) + fewshot_iter = iter(fewshot) + else: + fewshot_iter = None + + # Remove everything except the label column + extra_cols = list(assert_type(Features, ds.features)) + extra_cols.remove(label_column) + + if label_column != "label": + ds = ds.rename_column(label_column, "label") + + for example in ds: + yield _convert_to_prompts( + example, + label_column=label_column, + num_classes=num_classes, + num_variants=num_variants, + prompter=prompter, + rng=rng, + fewshot_iter=fewshot_iter, + ) def _convert_to_prompts( diff --git a/elk/files.py b/elk/files.py index a9da71e9..4435dee1 100644 --- a/elk/files.py +++ b/elk/files.py @@ -5,9 +5,6 @@ import random from pathlib import Path -import yaml -from simple_parsing import Serializable - def elk_reporter_dir() -> Path: """Return the directory where reporter checkpoints and logs are stored.""" @@ -41,28 +38,3 @@ def memorably_named_dir(parent: Path): out_dir = parent / sub_dir out_dir.mkdir(parents=True, exist_ok=True) return out_dir - - -def save_config(cfg: Serializable, out_dir: Path): - """Save the config to a file""" - - path = out_dir / "cfg.yaml" - with open(path, "w") as f: - cfg.dump_yaml(f) - - return path - - -def save_meta(dataset, out_dir: Path): - """Save the meta data to a file""" - - meta = { - "dataset_fingerprints": { - split: dataset[split]._fingerprint for split in dataset.keys() - } - } - path = out_dir / "metadata.yaml" - with open(path, "w") as meta_f: - yaml.dump(meta, meta_f) - - return path diff --git a/elk/run.py b/elk/run.py index af75c597..3679e32c 100644 --- a/elk/run.py +++ b/elk/run.py @@ -3,27 +3,27 @@ from abc import ABC from dataclasses import dataclass, field from pathlib import Path -from typing import ( - TYPE_CHECKING, - Callable, - Optional, - Union, -) +from typing import TYPE_CHECKING, Callable, Union import numpy as np import pandas as pd import torch import torch.multiprocessing as mp -from datasets import DatasetDict +import yaml +from datasets import DatasetDict, concatenate_datasets from torch import Tensor from tqdm import tqdm from .extraction import extract -from .files import elk_reporter_dir, memorably_named_dir, save_config, save_meta +from .files import elk_reporter_dir, memorably_named_dir from .logging import save_debug_log -from .training.preprocessing import normalize -from .utils import assert_type, int16_to_float32 -from .utils.data_utils import get_layers, select_train_val_splits +from .utils import ( + assert_type, + get_dataset_name, + get_layers, + int16_to_float32, + select_train_val_splits, +) if TYPE_CHECKING: from .evaluation.evaluate import Eval @@ -33,12 +33,14 @@ @dataclass class Run(ABC): cfg: Union["Elicit", "Eval"] - out_dir: Optional[Path] = None - dataset: DatasetDict = field(init=False) + out_dir: Path | None = None + datasets: list[DatasetDict] = field(init=False) def __post_init__(self): - # Extract the hidden states first if necessary - self.dataset = extract(self.cfg.data, num_gpus=self.cfg.num_gpus) + self.datasets = [ + extract(cfg, num_gpus=self.cfg.num_gpus, min_gpu_mem=self.cfg.min_gpu_mem) + for cfg in self.cfg.data.explode() + ] if self.out_dir is None: # Save in a memorably-named directory inside of @@ -52,8 
+54,21 @@ def __post_init__(self): print(f"Output directory at \033[1m{self.out_dir}\033[0m") self.out_dir.mkdir(parents=True, exist_ok=True) - save_config(self.cfg, self.out_dir) - save_meta(self.dataset, self.out_dir) + path = self.out_dir / "cfg.yaml" + with open(path, "w") as f: + self.cfg.dump_yaml(f) + + path = self.out_dir / "fingerprints.yaml" + with open(path, "w") as meta_f: + yaml.dump( + { + get_dataset_name(ds): { + split: ds[split]._fingerprint for split in ds.keys() + } + for ds in self.datasets + }, + meta_f, + ) def make_reproducible(self, seed: int): """Make the run reproducible by setting the random seed.""" @@ -69,56 +84,54 @@ def get_device(self, devices, world_size: int) -> str: device = devices[rank] return device - def prepare_data( - self, - device: str, - layer: int, - ) -> tuple: + def prepare_data(self, device: str, layer: int) -> tuple: """Prepare the data for training and validation.""" - with self.dataset.formatted_as("torch", device=device, dtype=torch.int16): - train_split, val_split = select_train_val_splits(self.dataset) - train, val = self.dataset[train_split], self.dataset[val_split] + train_sets = [] + val_output = {} + + # We handle train and val differently. We want to concatenate all of the + # train sets together, but we want to keep the val sets separate so that we can + # compute evaluation metrics separately for each dataset. + for ds in self.datasets: + train_split, val_split = select_train_val_splits(ds) + train_sets.append(ds[train_split]) - train_labels = assert_type(Tensor, train["label"]) + val = ds[val_split].with_format("torch", device=device, dtype=torch.int16) val_labels = assert_type(Tensor, val["label"]) + val_h = int16_to_float32(assert_type(torch.Tensor, val[f"hidden_{layer}"])) + val_x0, val_x1 = val_h.unbind(dim=-2) - # Note: currently we're just upcasting to float32 - # so we don't have to deal with - # grad scaling (which isn't supported for LBFGS), - # while the hidden states are - # saved in float16 to save disk space. - # In the future we could try to use mixed - # precision training in at least some cases. 
- train_h, val_h = normalize( - int16_to_float32(assert_type(torch.Tensor, train[f"hidden_{layer}"])), - int16_to_float32(assert_type(torch.Tensor, val[f"hidden_{layer}"])), - method=self.cfg.normalization, - ) + with val.formatted_as("numpy"): + has_preds = "model_preds" in val.features + val_lm_preds = val["model_preds"] if has_preds else None - x0, x1 = train_h.unbind(dim=-2) - val_x0, val_x1 = val_h.unbind(dim=-2) + ds_name = get_dataset_name(ds) + val_output[ds_name] = (val_x0, val_x1, val_labels, val_lm_preds) - with self.dataset.formatted_as("numpy"): - has_preds = "model_preds" in val.features - val_lm_preds = val["model_preds"] if has_preds else None + train = concatenate_datasets(train_sets).with_format( + "torch", device=device, dtype=torch.int16 + ) - return x0, x1, val_x0, val_x1, train_labels, val_labels, val_lm_preds + train_labels = assert_type(Tensor, train["label"]) + train_h = int16_to_float32(assert_type(torch.Tensor, train[f"hidden_{layer}"])) + x0, x1 = train_h.unbind(dim=-2) + + return x0, x1, train_labels, val_output def concatenate(self, layers): """Concatenate hidden states from a previous layer.""" for layer in range(self.cfg.concatenated_layer_offset, len(layers)): - layers[layer] = layers[layer] + [ - layers[layer][0] - self.cfg.concatenated_layer_offset - ] + layers[layer] += [layers[layer][0] - self.cfg.concatenated_layer_offset] + return layers def apply_to_layers( self, - func: Callable[[int], pd.Series], + func: Callable[[int], pd.DataFrame], num_devices: int, ): - """Apply a function to each layer of the dataset in parallel + """Apply a function to each layer of the datasets in parallel and writes the results to a CSV file. Args: @@ -128,7 +141,8 @@ def apply_to_layers( """ self.out_dir = assert_type(Path, self.out_dir) - layers: list[int] = get_layers(self.dataset) + layers, *rest = [get_layers(ds) for ds in self.datasets] + assert all(x == layers for x in rest), "All datasets must have the same layers" if self.cfg.concatenated_layer_offset > 0: layers = self.concatenate(layers) @@ -137,14 +151,15 @@ def apply_to_layers( ctx = mp.get_context("spawn") with ctx.Pool(num_devices) as pool, open(self.out_dir / "eval.csv", "w") as f: mapper = pool.imap_unordered if num_devices > 1 else map - row_buf = [] + df_buf = [] try: for row in tqdm(mapper(func, layers), total=len(layers)): - row_buf.append(row) + df_buf.append(row) finally: # Make sure the CSV is written even if we crash or get interrupted - df = pd.DataFrame(row_buf).sort_values(by="layer") - df.to_csv(f, index=False) + if df_buf: + df = pd.concat(df_buf).sort_values(by="layer") + df.to_csv(f, index=False) if self.cfg.debug: - save_debug_log(self.dataset, self.out_dir) + save_debug_log(self.datasets, self.out_dir) diff --git a/elk/training/__init__.py b/elk/training/__init__.py index 41264179..6428c3c5 100644 --- a/elk/training/__init__.py +++ b/elk/training/__init__.py @@ -1,13 +1,15 @@ from .ccs_reporter import CcsReporter, CcsReporterConfig from .eigen_reporter import EigenReporter, EigenReporterConfig +from .normalizer import Normalizer from .reporter import OptimConfig, Reporter, ReporterConfig __all__ = [ - "Reporter", - "ReporterConfig", "CcsReporter", "CcsReporterConfig", "EigenReporter", "EigenReporterConfig", + "Normalizer", "OptimConfig", + "Reporter", + "ReporterConfig", ] diff --git a/elk/training/baseline.py b/elk/training/baseline.py index 2c9542a6..6a40b7bd 100644 --- a/elk/training/baseline.py +++ b/elk/training/baseline.py @@ -1,5 +1,3 @@ -import pickle -from pathlib import Path from 
typing import Tuple import torch @@ -9,8 +7,6 @@ from ..utils.typing import assert_type from .classifier import Classifier -# TODO: Create class for baseline? - def evaluate_baseline( lr_model: Classifier, val_x0: Tensor, val_x1: Tensor, val_labels: Tensor @@ -50,13 +46,3 @@ def train_baseline( lr_model.fit_cv(X.view(-1, d), train_labels_aug) return lr_model - - -def save_baseline(lr_dir: Path, layer: int, lr_model: Classifier): - with open(lr_dir / f"layer_{layer}.pt", "wb") as file: - pickle.dump(lr_model, file) - - -def load_baseline(lr_dir: Path, layer: int) -> Classifier: - with open(lr_dir / f"layer_{layer}.pt", "rb") as file: - return pickle.load(file) diff --git a/elk/training/ccs_reporter.py b/elk/training/ccs_reporter.py index 258f2bc4..bf98d29a 100644 --- a/elk/training/ccs_reporter.py +++ b/elk/training/ccs_reporter.py @@ -13,6 +13,7 @@ from ..parsing import parse_loss from ..utils.typing import assert_type from .losses import LOSSES +from .normalizer import Normalizer from .reporter import Reporter, ReporterConfig @@ -85,13 +86,17 @@ def __init__( self, in_features: int, cfg: CcsReporterConfig, - device: Optional[str] = None, - dtype: Optional[torch.dtype] = None, + device: str | torch.device | None = None, + dtype: torch.dtype | None = None, ): - super().__init__(in_features, cfg, device=device, dtype=dtype) + super().__init__() + self.config = cfg hidden_size = cfg.hidden_size or 4 * in_features // 3 + self.neg_norm = Normalizer((in_features,), device=device, dtype=dtype) + self.pos_norm = Normalizer((in_features,), device=device, dtype=dtype) + self.probe = nn.Sequential( nn.Linear( in_features, @@ -165,6 +170,8 @@ def forward(self, x: Tensor) -> Tensor: return self.probe(x).squeeze(-1) def predict(self, x_pos: Tensor, x_neg: Tensor) -> Tensor: + x_pos = self.pos_norm(x_pos) + x_neg = self.neg_norm(x_neg) return 0.5 * (self(x_pos).sigmoid() + (1 - self(x_neg).sigmoid())) def loss( @@ -231,8 +238,9 @@ def fit( ValueError: If `optimizer` is not "adam" or "lbfgs". RuntimeError: If the best loss is not finite. 
""" - # TODO: Implement normalization here to fix issue #96 - # self.update(x_pos, x_neg) + # Fit normalizers + self.pos_norm.fit(x_pos) + self.neg_norm.fit(x_neg) # Record the best acc, loss, and params found so far best_loss = torch.inf diff --git a/elk/training/eigen_reporter.py b/elk/training/eigen_reporter.py index 821891cc..3d640056 100644 --- a/elk/training/eigen_reporter.py +++ b/elk/training/eigen_reporter.py @@ -7,8 +7,8 @@ import torch from torch import Tensor, nn, optim -from ..math_util import cov_mean_fused from ..truncated_eigh import ConvergenceError, truncated_eigh +from ..utils.math_util import cov_mean_fused from .reporter import Reporter, ReporterConfig @@ -59,6 +59,8 @@ class EigenReporter(Reporter): intracluster_cov: Tensor # invariance contrastive_xcov_M2: Tensor # negative covariance n: Tensor + neg_mean: Tensor + pos_mean: Tensor weight: Tensor def __init__( @@ -68,12 +70,22 @@ def __init__( device: Optional[str] = None, dtype: Optional[torch.dtype] = None, ): - super().__init__(in_features, cfg, device=device, dtype=dtype) + super().__init__() + self.config = cfg # Learnable Platt scaling parameters self.bias = nn.Parameter(torch.zeros(cfg.num_heads, device=device, dtype=dtype)) self.scale = nn.Parameter(torch.ones(cfg.num_heads, device=device, dtype=dtype)) + # Running statistics + self.register_buffer("n", torch.zeros((), device=device, dtype=torch.long)) + self.register_buffer( + "neg_mean", torch.zeros(in_features, device=device, dtype=dtype) + ) + self.register_buffer( + "pos_mean", torch.zeros(in_features, device=device, dtype=dtype) + ) + self.register_buffer( "contrastive_xcov_M2", torch.zeros(in_features, in_features, device=device, dtype=dtype), @@ -86,6 +98,8 @@ def __init__( "intracluster_cov", torch.zeros(in_features, in_features, device=device, dtype=dtype), ) + + # Reporter weights self.register_buffer( "weight", torch.zeros(cfg.num_heads, in_features, device=device, dtype=dtype), diff --git a/elk/training/normalizer.py b/elk/training/normalizer.py new file mode 100644 index 00000000..7cb04b97 --- /dev/null +++ b/elk/training/normalizer.py @@ -0,0 +1,63 @@ +from typing import Literal + +import torch +from torch import Tensor, nn + + +class Normalizer(nn.Module): + """Basically `BatchNorm` with a less annoying default axis ordering.""" + + mean: Tensor + std: Tensor | None + + def __init__( + self, + normalized_shape: tuple[int, ...], + *, + device: str | torch.device | None = None, + dtype: torch.dtype | None = None, + eps: float = 1e-5, + mode: Literal["none", "meanonly", "full"] = "full", + ): + super().__init__() + + self.eps = eps + self.mode = mode + self.normalized_shape = normalized_shape + + self.register_buffer( + "mean", torch.zeros(*normalized_shape, device=device, dtype=dtype) + ) + self.register_buffer( + "std", torch.ones_like(self.mean) if mode == "full" else None + ) + + def forward(self, x: Tensor) -> Tensor: + """Normalize `x` using the stored mean and standard deviation.""" + if self.mode == "none": + return x + elif self.std is None: + return x - self.mean + else: + return (x - self.mean) / self.std + + def fit(self, x: Tensor) -> None: + """Update the stored mean and standard deviation.""" + + # Check shape + num_dims = len(self.normalized_shape) + if x.shape[-num_dims:] != self.normalized_shape: + raise ValueError( + f"Expected trailing sizes {self.normalized_shape} but got " + f"{x.shape[-num_dims:]}" + ) + + if self.mode == "none": + return + + dims = [i for i in range(x.ndim - num_dims)] + if self.std is None: + 
torch.mean(x, dim=dims, out=self.mean) + else: + variance, self.mean = torch.var_mean(x, dim=dims) + torch.sqrt(variance + self.eps, out=self.std) diff --git a/elk/training/preprocessing.py b/elk/training/preprocessing.py deleted file mode 100644 index 6081dcbb..00000000 --- a/elk/training/preprocessing.py +++ /dev/null @@ -1,55 +0,0 @@ -"""Preprocessing functions for training.""" - -from typing import Literal - -import torch - - -def normalize( - train_hiddens: torch.Tensor, - val_hiddens: torch.Tensor, - method: Literal["legacy", "none", "elementwise", "meanonly"] = "legacy", -) -> tuple[torch.Tensor, torch.Tensor]: - """Normalize the hidden states. - - Normalize the hidden states with the specified method. - The "legacy" method is the same as the original ELK implementation. - The "elementwise" method normalizes each element. - The "meanonly" method normalizes the mean. - - Args: - train_hiddens: The hidden states for the training set. - val_hiddens: The hidden states for the validation set. - method: The normalization method to use. - - Returns: - tuple containing the training and validation hidden states. - """ - if method == "none": - return train_hiddens, val_hiddens - elif method == "legacy": - master = torch.cat([train_hiddens, val_hiddens], dim=0).float() - means = master.mean(dim=0) - - train_hiddens -= means - val_hiddens -= means - - scale = master.shape[-1] ** 0.5 / master.norm(dim=-1).mean() - train_hiddens *= scale - val_hiddens *= scale - else: - means = train_hiddens.float().mean(dim=0) - train_hiddens -= means - val_hiddens -= means - - if method == "elementwise": - scale = 1 / train_hiddens.norm(dim=0, keepdim=True) - elif method == "meanonly": - scale = 1 - else: - raise NotImplementedError(f"Scale method '{method}' is not supported.") - - train_hiddens *= scale - val_hiddens *= scale - - return train_hiddens, val_hiddens diff --git a/elk/training/reporter.py b/elk/training/reporter.py index 9cdfb145..b168acb5 100644 --- a/elk/training/reporter.py +++ b/elk/training/reporter.py @@ -66,26 +66,7 @@ class Reporter(nn.Module, ABC): """ n: Tensor - neg_mean: Tensor - pos_mean: Tensor - - def __init__( - self, - in_features: int, - cfg: ReporterConfig, - device: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - ): - super().__init__() - - self.config = cfg - self.register_buffer("n", torch.zeros((), device=device, dtype=torch.long)) - self.register_buffer( - "neg_mean", torch.zeros(in_features, device=device, dtype=dtype) - ) - self.register_buffer( - "pos_mean", torch.zeros(in_features, device=device, dtype=dtype) - ) + config: ReporterConfig @classmethod def check_separability( @@ -138,17 +119,6 @@ def check_separability( def reset_parameters(self): """Reset the parameters of the probe.""" - @torch.no_grad() - def update(self, x_pos: Tensor, x_neg: Tensor) -> None: - """Update the running mean of the positive and negative examples.""" - - x_pos, x_neg = x_pos.flatten(0, -2), x_neg.flatten(0, -2) - self.n += x_pos.shape[0] - - # Update the running means - self.neg_mean += (x_neg.sum(dim=0) - self.neg_mean) / self.n - self.pos_mean += (x_pos.sum(dim=0) - self.pos_mean) / self.n - # TODO: These methods will do something fancier in the future @classmethod def load(cls, path: Path | str): diff --git a/elk/training/train.py b/elk/training/train.py index 9be8c589..351b7374 100644 --- a/elk/training/train.py +++ b/elk/training/train.py @@ -4,7 +4,7 @@ from dataclasses import dataclass from functools import partial from pathlib import Path -from typing import Callable, 
Literal, Optional +from typing import Callable, Optional import pandas as pd import torch @@ -14,7 +14,7 @@ from ..extraction.extraction import Extract from ..run import Run -from ..training.baseline import evaluate_baseline, save_baseline, train_baseline +from ..training.baseline import evaluate_baseline, train_baseline from ..utils import select_usable_devices from ..utils.typing import assert_type from .ccs_reporter import CcsReporter, CcsReporterConfig @@ -47,7 +47,7 @@ class Elicit(Serializable): optim: OptimConfig = field(default_factory=OptimConfig) num_gpus: int = -1 - normalization: Literal["legacy", "none", "elementwise", "meanonly"] = "meanonly" + min_gpu_mem: int | None = None skip_baseline: bool = False concatenated_layer_offset: int = 0 # if nonzero, appends the hidden states of layer concatenated_layer_offset before @@ -78,16 +78,14 @@ def train_reporter( layer: int, devices: list[str], world_size: int = 1, - ) -> pd.Series: + ) -> pd.DataFrame: """Train a single reporter on a single layer.""" self.make_reproducible(seed=self.cfg.net.seed + layer) - device = self.get_device(devices, world_size) - x0, x1, val_x0, val_x1, train_gt, val_gt, val_lm_preds = self.prepare_data( - device, layer - ) - pseudo_auroc = self.get_pseudo_auroc(layer, x0, x1, val_x0, val_x1) + x0, x1, train_gt, val_output = self.prepare_data(device, layer) + reporter_dir, lr_dir = self.create_models_dir(assert_type(Path, self.out_dir)) + # pseudo_auroc = self.get_pseudo_auroc(layer, x0, x1, val_x0, val_x1) if isinstance(self.cfg.net, CcsReporterConfig): reporter = CcsReporter(x0.shape[-1], self.cfg.net, device=device) @@ -96,46 +94,55 @@ def train_reporter( else: raise ValueError(f"Unknown reporter config type: {type(self.cfg.net)}") + # Fit reporter train_loss = reporter.fit(x0, x1, train_gt) - val_result = reporter.score( - val_gt, - val_x0, - val_x1, - ) + with open(reporter_dir / f"layer_{layer}.pt", "wb") as file: + torch.save(reporter, file) - reporter_dir, lr_dir = self.create_models_dir(assert_type(Path, self.out_dir)) - if val_lm_preds is not None: - val_gt_cpu = val_gt.repeat_interleave(val_lm_preds.shape[1]).float().cpu() - val_lm_auroc = float(roc_auc_score(val_gt_cpu, val_lm_preds.flatten())) - val_lm_acc = float(accuracy_score(val_gt_cpu, val_lm_preds.flatten() > 0.5)) - else: - val_lm_auroc = None - val_lm_acc = None - - row = pd.Series( - { - "layer": layer, - "pseudo_auroc": pseudo_auroc, - "train_loss": train_loss, - **val_result._asdict(), - "lm_auroc": val_lm_auroc, - "lm_acc": val_lm_acc, - } - ) + # Fit baseline logistic regression model + lr_model = train_baseline(x0, x1, train_gt, device=device) + with open(lr_dir / f"layer_{layer}.pt", "wb") as file: + torch.save(lr_model, file) + + row_buf = [] + for ds_name, (val_x0, val_x1, val_gt, val_lm_preds) in val_output.items(): + val_result = reporter.score( + val_gt, + val_x0, + val_x1, + ) - if not self.cfg.skip_baseline: - lr_model = train_baseline(x0, x1, train_gt, device=device) + if val_lm_preds is not None: + val_gt_cpu = ( + val_gt.repeat_interleave(val_lm_preds.shape[1]).float().cpu() + ) + val_lm_auroc = float(roc_auc_score(val_gt_cpu, val_lm_preds.flatten())) + val_lm_acc = float( + accuracy_score(val_gt_cpu, val_lm_preds.flatten() > 0.5) + ) + else: + val_lm_auroc = None + val_lm_acc = None + + row = pd.Series( + { + "dataset": ds_name, + "layer": layer, + # "pseudo_auroc": pseudo_auroc, + "train_loss": train_loss, + **val_result._asdict(), + "lm_auroc": val_lm_auroc, + "lm_acc": val_lm_acc, + } + ) lr_auroc, lr_acc = 
evaluate_baseline(lr_model, val_x0, val_x1, val_gt) row["lr_auroc"] = lr_auroc row["lr_acc"] = lr_acc - save_baseline(lr_dir, layer, lr_model) - - with open(reporter_dir / f"layer_{layer}.pt", "wb") as file: - torch.save(reporter, file) + row_buf.append(row) - return row + return pd.DataFrame(row_buf) def get_pseudo_auroc( self, layer: int, x0: Tensor, x1: Tensor, val_x0: Tensor, val_x1: Tensor @@ -159,7 +166,7 @@ def train(self): """Train a reporter on each layer of the network.""" devices = select_usable_devices(self.cfg.num_gpus) num_devices = len(devices) - func: Callable[[int], pd.Series] = partial( + func: Callable[[int], pd.DataFrame] = partial( self.train_reporter, devices=devices, world_size=num_devices ) self.apply_to_layers(func=func, num_devices=num_devices) diff --git a/elk/utils/__init__.py b/elk/utils/__init__.py index 1400a98d..7d4bbf2f 100644 --- a/elk/utils/__init__.py +++ b/elk/utils/__init__.py @@ -2,6 +2,8 @@ binarize, convert_span, get_columns_all_equal, + get_dataset_name, + get_layers, infer_label_column, infer_num_classes, select_train_val_splits, @@ -15,6 +17,8 @@ "binarize", "convert_span", "get_columns_all_equal", + "get_dataset_name", + "get_layers", "infer_label_column", "infer_num_classes", "instantiate_model", diff --git a/elk/utils/data_utils.py b/elk/utils/data_utils.py index a98a7aae..3f5f7457 100644 --- a/elk/utils/data_utils.py +++ b/elk/utils/data_utils.py @@ -2,7 +2,7 @@ from bisect import bisect_left, bisect_right from operator import itemgetter from random import Random -from typing import Any, Iterable, List +from typing import Any, Iterable from datasets import ( ClassLabel, @@ -44,6 +44,23 @@ def get_columns_all_equal(dataset: DatasetDict) -> list[str]: return pivot +def get_dataset_name(dataset: DatasetDict) -> str: + """Get the name of a `DatasetDict`.""" + builder_name, *rest = [ds.builder_name for ds in dataset.values()] + if not all(name == builder_name for name in rest): + raise ValueError( + f"All splits must have the same name; got {[builder_name, *rest]}" + ) + + config_name, *rest = [ds.config_name for ds in dataset.values()] + if not all(name == config_name for name in rest): + raise ValueError( + f"All splits must have the same config name; got {[config_name, *rest]}" + ) + + return builder_name + " " + config_name if config_name else builder_name + + def select_train_val_splits(raw_splits: Iterable[str]) -> tuple[str, str]: """Return splits to use for train and validation, given an Iterable of splits.""" @@ -101,7 +118,7 @@ def infer_num_classes(label_feature: Any) -> int: ) -def get_layers(ds: DatasetDict) -> List[int]: +def get_layers(ds: DatasetDict) -> list[int]: """Get a list of indices of hidden layers given a `DatasetDict`.""" layers = [ int(feat[len("hidden_") :]) diff --git a/elk/math_util.py b/elk/utils/math_util.py similarity index 100% rename from elk/math_util.py rename to elk/utils/math_util.py diff --git a/tests/test_smoke_elicit.py b/tests/test_smoke_elicit.py index 5bbcfef3..9680781c 100644 --- a/tests/test_smoke_elicit.py +++ b/tests/test_smoke_elicit.py @@ -14,10 +14,10 @@ def test_smoke_elicit_run_tiny_gpt2_ccs(tmp_path: Path): data=Extract( model=model_path, prompts=PromptConfig(datasets=[dataset_name], max_examples=[10]), - min_gpu_mem=min_mem, # run on all layers, tiny-gpt only has 2 layers ), num_gpus=2, + min_gpu_mem=min_mem, net=CcsReporterConfig(), out_dir=tmp_path, ) @@ -38,10 +38,10 @@ def test_smoke_elicit_run_tiny_gpt2_eigen(tmp_path: Path): data=Extract( model=model_path, 
prompts=PromptConfig(datasets=[dataset_name], max_examples=[10]), - min_gpu_mem=min_mem, # run on all layers, tiny-gpt only has 2 layers ), num_gpus=2, + min_gpu_mem=min_mem, net=EigenReporterConfig(), out_dir=tmp_path, ) From e1675f7827bc0a12a51df690987713aaa6e556c0 Mon Sep 17 00:00:00 2001 From: Nora Belrose Date: Thu, 13 Apr 2023 19:43:28 +0000 Subject: [PATCH 26/43] Various fixes --- elk/evaluation/evaluate.py | 10 ++-- elk/extraction/extraction.py | 12 +++-- elk/run.py | 47 +++++++---------- elk/training/baseline.py | 48 ----------------- elk/training/eigen_reporter.py | 14 ++--- elk/training/supervised.py | 47 +++++++++++++++++ elk/training/train.py | 94 +++++++++++++++++++++++----------- elk/utils/__init__.py | 2 + elk/utils/data_utils.py | 11 +++- elk/utils/hf_utils.py | 11 ++-- pyproject.toml | 1 + 11 files changed, 165 insertions(+), 132 deletions(-) delete mode 100644 elk/training/baseline.py create mode 100644 elk/training/supervised.py diff --git a/elk/evaluation/evaluate.py b/elk/evaluation/evaluate.py index ad81182a..186e10a4 100644 --- a/elk/evaluation/evaluate.py +++ b/elk/evaluation/evaluate.py @@ -11,7 +11,7 @@ from ..files import elk_reporter_dir from ..run import Run from ..training import Reporter -from ..training.baseline import evaluate_baseline +from ..training.supervised import evaluate_supervised from ..utils import select_usable_devices @@ -61,11 +61,7 @@ def evaluate_reporter( ) -> pd.DataFrame: """Evaluate a single reporter on a single layer.""" device = self.get_device(devices, world_size) - - _, _, _, val_output = self.prepare_data( - device, - layer, - ) + val_output = self.prepare_data(device, layer, "val") experiment_dir = elk_reporter_dir() / self.cfg.source @@ -94,7 +90,7 @@ def evaluate_reporter( with open(lr_dir / f"layer_{layer}.pt", "rb") as f: lr_model = torch.load(f, map_location=device).eval() - lr_auroc, lr_acc = evaluate_baseline(lr_model, val_x0, val_x1, val_gt) + lr_auroc, lr_acc = evaluate_supervised(lr_model, val_x0, val_x1, val_gt) stats_row["lr_auroc"] = lr_auroc stats_row["lr_acc"] = lr_acc diff --git a/elk/extraction/extraction.py b/elk/extraction/extraction.py index 03105c74..e022e5ce 100644 --- a/elk/extraction/extraction.py +++ b/elk/extraction/extraction.py @@ -23,6 +23,7 @@ from transformers import AutoConfig, AutoTokenizer from transformers.modeling_outputs import Seq2SeqLMOutput +from ..promptsource import DatasetTemplates from ..utils import ( assert_type, convert_span, @@ -257,8 +258,9 @@ def get_splits() -> SplitDict: available_splits = assert_type(SplitDict, info.splits) train_name, val_name = select_train_val_splits(available_splits) print( - f"{info.builder_name}: using '{train_name}' for training and '{val_name}'" - f" for validation" + # Cyan color for dataset name + f"\033[36m{info.builder_name}\033[0m: using '{train_name}' for training and" + f" '{val_name}' for validation" ) limit_list = cfg.prompts.max_examples @@ -275,11 +277,15 @@ def get_splits() -> SplitDict: ) model_cfg = AutoConfig.from_pretrained(cfg.model) - num_variants = cfg.prompts.num_variants ds_name, _, config_name = cfg.prompts.datasets[0].partition(" ") info = get_dataset_config_info(ds_name, config_name or None) + num_variants = cfg.prompts.num_variants + if num_variants < 0: + prompter = DatasetTemplates(ds_name, config_name) + num_variants = len(prompter.templates) + layer_cols = { f"hidden_{layer}": Array3D( dtype="int16", diff --git a/elk/run.py b/elk/run.py index 3679e32c..2ea30eb0 100644 --- a/elk/run.py +++ b/elk/run.py @@ -3,14 +3,14 @@ from 
abc import ABC from dataclasses import dataclass, field from pathlib import Path -from typing import TYPE_CHECKING, Callable, Union +from typing import TYPE_CHECKING, Callable, Literal, Union import numpy as np import pandas as pd import torch import torch.multiprocessing as mp import yaml -from datasets import DatasetDict, concatenate_datasets +from datasets import DatasetDict from torch import Tensor from tqdm import tqdm @@ -84,40 +84,29 @@ def get_device(self, devices, world_size: int) -> str: device = devices[rank] return device - def prepare_data(self, device: str, layer: int) -> tuple: - """Prepare the data for training and validation.""" + def prepare_data( + self, device: str, layer: int, split_type: Literal["train", "val"] + ) -> dict[str, tuple[Tensor, Tensor, Tensor, np.ndarray | None]]: + """Prepare data for the specified layer and split type.""" + out = {} - train_sets = [] - val_output = {} - - # We handle train and val differently. We want to concatenate all of the - # train sets together, but we want to keep the val sets separate so that we can - # compute evaluation metrics separately for each dataset. for ds in self.datasets: - train_split, val_split = select_train_val_splits(ds) - train_sets.append(ds[train_split]) + train_name, val_name = select_train_val_splits(ds) + key = train_name if split_type == "train" else val_name - val = ds[val_split].with_format("torch", device=device, dtype=torch.int16) - val_labels = assert_type(Tensor, val["label"]) - val_h = int16_to_float32(assert_type(torch.Tensor, val[f"hidden_{layer}"])) - val_x0, val_x1 = val_h.unbind(dim=-2) + split = ds[key].with_format("torch", device=device, dtype=torch.int16) + labels = assert_type(Tensor, split["label"]) + val_h = int16_to_float32(assert_type(Tensor, split[f"hidden_{layer}"])) + x0, x1 = val_h.unbind(dim=-2) - with val.formatted_as("numpy"): - has_preds = "model_preds" in val.features - val_lm_preds = val["model_preds"] if has_preds else None + with split.formatted_as("numpy"): + has_preds = "model_preds" in split.features + lm_preds = split["model_preds"] if has_preds else None ds_name = get_dataset_name(ds) - val_output[ds_name] = (val_x0, val_x1, val_labels, val_lm_preds) - - train = concatenate_datasets(train_sets).with_format( - "torch", device=device, dtype=torch.int16 - ) - - train_labels = assert_type(Tensor, train["label"]) - train_h = int16_to_float32(assert_type(torch.Tensor, train[f"hidden_{layer}"])) - x0, x1 = train_h.unbind(dim=-2) + out[ds_name] = (x0, x1, labels, lm_preds) - return x0, x1, train_labels, val_output + return out def concatenate(self, layers): """Concatenate hidden states from a previous layer.""" diff --git a/elk/training/baseline.py b/elk/training/baseline.py deleted file mode 100644 index 6a40b7bd..00000000 --- a/elk/training/baseline.py +++ /dev/null @@ -1,48 +0,0 @@ -from typing import Tuple - -import torch -from sklearn.metrics import accuracy_score, roc_auc_score -from torch import Tensor - -from ..utils.typing import assert_type -from .classifier import Classifier - - -def evaluate_baseline( - lr_model: Classifier, val_x0: Tensor, val_x1: Tensor, val_labels: Tensor -) -> Tuple[float, float]: - X = torch.cat([val_x0, val_x1]) - d = X.shape[-1] - X_val = X.view(-1, d) - with torch.no_grad(): - lr_preds = lr_model(X_val).sigmoid().cpu() - - val_labels_aug = ( - torch.cat([val_labels, 1 - val_labels]).repeat_interleave(val_x0.shape[1]) - ).cpu() - - lr_acc = accuracy_score(val_labels_aug, lr_preds > 0.5) - lr_auroc = roc_auc_score(val_labels_aug, lr_preds) - 
- return assert_type(float, lr_auroc), assert_type(float, lr_acc) - - -def train_baseline( - x0: Tensor, - x1: Tensor, - train_labels: Tensor, - device: str, -) -> Classifier: - # repeat_interleave makes `num_variants` copies of each label, all within a - # single dimension of size `num_variants * 2 * n`, such that the labels align - # with X.view(-1, X.shape[-1]) - train_labels_aug = torch.cat([train_labels, 1 - train_labels]).repeat_interleave( - x0.shape[1] - ) - - X = torch.cat([x0, x1]).squeeze() - d = X.shape[-1] - lr_model = Classifier(d, device=device) - lr_model.fit_cv(X.view(-1, d), train_labels_aug) - - return lr_model diff --git a/elk/training/eigen_reporter.py b/elk/training/eigen_reporter.py index 3d640056..5449dd06 100644 --- a/elk/training/eigen_reporter.py +++ b/elk/training/eigen_reporter.py @@ -2,12 +2,11 @@ from dataclasses import dataclass from typing import Optional -from warnings import warn import torch from torch import Tensor, nn, optim -from ..truncated_eigh import ConvergenceError, truncated_eigh +from ..truncated_eigh import truncated_eigh from ..utils.math_util import cov_mean_fused from .reporter import Reporter, ReporterConfig @@ -184,7 +183,7 @@ def update(self, x_pos: Tensor, x_neg: Tensor) -> None: self.contrastive_xcov_M2.addmm_(neg_delta.mT, pos_delta2) self.contrastive_xcov_M2.addmm_(pos_delta.mT, neg_delta2) - def fit_streaming(self) -> float: + def fit_streaming(self, truncated: bool = False) -> float: """Fit the probe using the current streaming statistics.""" A = ( self.config.var_weight * self.intercluster_cov @@ -192,14 +191,9 @@ def fit_streaming(self) -> float: - self.config.neg_cov_weight * self.contrastive_xcov ) - try: + if truncated: L, Q = truncated_eigh(A, k=self.config.num_heads) - except (ConvergenceError, RuntimeError): - warn( - "Truncated eigendecomposition failed to converge. Falling back on " - "PyTorch's dense eigensolver." 
- ) - + else: L, Q = torch.linalg.eigh(A) L, Q = L[-self.config.num_heads :], Q[:, -self.config.num_heads :] diff --git a/elk/training/supervised.py b/elk/training/supervised.py new file mode 100644 index 00000000..b8f580f3 --- /dev/null +++ b/elk/training/supervised.py @@ -0,0 +1,47 @@ +import torch +from einops import rearrange, repeat +from sklearn.metrics import accuracy_score, roc_auc_score +from torch import Tensor + +from ..utils import assert_type +from .classifier import Classifier + + +def evaluate_supervised( + lr_model: Classifier, val_x0: Tensor, val_x1: Tensor, val_labels: Tensor +) -> tuple[float, float]: + X = torch.cat([val_x0, val_x1]) + d = X.shape[-1] + X_val = X.view(-1, d) + with torch.no_grad(): + lr_preds = lr_model(X_val).sigmoid().cpu() + + val_labels_aug = ( + torch.cat([val_labels, 1 - val_labels]).repeat_interleave(val_x0.shape[1]) + ).cpu() + + lr_acc = accuracy_score(val_labels_aug, lr_preds > 0.5) + lr_auroc = roc_auc_score(val_labels_aug, lr_preds) + + return assert_type(float, lr_auroc), assert_type(float, lr_acc) + + +def train_supervised(data: dict[str, tuple], device: str) -> Classifier: + Xs, train_labels = [], [] + + for x0, x1, labels, _ in data.values(): + (_, v, _) = x0.shape + x0 = rearrange(x0, "n v d -> (n v) d") + x1 = rearrange(x1, "n v d -> (n v) d") + + labels = repeat(labels, "n -> (n v)", v=v) + labels = torch.cat([labels, 1 - labels]) + + Xs.append(torch.cat([x0, x1]).squeeze()) + train_labels.append(labels) + + X, train_labels = torch.cat(Xs), torch.cat(train_labels) + lr_model = Classifier(X.shape[-1], device=device) + lr_model.fit_cv(X, train_labels) + + return lr_model diff --git a/elk/training/train.py b/elk/training/train.py index 351b7374..111e2fd9 100644 --- a/elk/training/train.py +++ b/elk/training/train.py @@ -1,6 +1,5 @@ """Main training loop.""" -import warnings from dataclasses import dataclass from functools import partial from pathlib import Path @@ -8,18 +7,19 @@ import pandas as pd import torch +from einops import rearrange, repeat from simple_parsing import Serializable, field, subgroups from sklearn.metrics import accuracy_score, roc_auc_score -from torch import Tensor from ..extraction.extraction import Extract from ..run import Run -from ..training.baseline import evaluate_baseline, train_baseline +from ..training.supervised import evaluate_supervised, train_supervised from ..utils import select_usable_devices from ..utils.typing import assert_type from .ccs_reporter import CcsReporter, CcsReporterConfig +from .classifier import Classifier from .eigen_reporter import EigenReporter, EigenReporterConfig -from .reporter import OptimConfig, Reporter, ReporterConfig +from .reporter import OptimConfig, ReporterConfig @dataclass @@ -34,7 +34,7 @@ class Elicit(Serializable): "use all available GPUs". normalization: The normalization method to use. Defaults to "meanonly". See `elk.training.preprocessing.normalize()` for details. - skip_baseline: Whether to skip training the baseline classifier. Defaults to + skip_baseline: Whether to skip training the supervised classifier. Defaults to False. debug: When in debug mode, a useful log file is saved to the memorably-named output directory. Defaults to False. 
@@ -83,34 +83,65 @@ def train_reporter( self.make_reproducible(seed=self.cfg.net.seed + layer) device = self.get_device(devices, world_size) - x0, x1, train_gt, val_output = self.prepare_data(device, layer) + train_dict = self.prepare_data(device, layer, "train") + val_dict = self.prepare_data(device, layer, "val") + + # Can't figure out a way to make this line less ugly + hidden_size = next(iter(train_dict.values()))[0].shape[-1] reporter_dir, lr_dir = self.create_models_dir(assert_type(Path, self.out_dir)) - # pseudo_auroc = self.get_pseudo_auroc(layer, x0, x1, val_x0, val_x1) + pseudo_clf = self.get_pseudo_classifier(train_dict, device) if isinstance(self.cfg.net, CcsReporterConfig): - reporter = CcsReporter(x0.shape[-1], self.cfg.net, device=device) + assert len(train_dict) == 1, "CCS only supports single-task training" + + reporter = CcsReporter(hidden_size, self.cfg.net, device=device) + (x0, x1, labels, _) = next(iter(train_dict.values())) + train_loss = reporter.fit(x0, x1, labels) + elif isinstance(self.cfg.net, EigenReporterConfig): - reporter = EigenReporter(x0.shape[-1], self.cfg.net, device=device) + # To enable training on multiple tasks with different numbers of variants, + # we update the statistics in a streaming fashion and then fit + reporter = EigenReporter(hidden_size, self.cfg.net, device=device) + for ds_name, (x0, x1, labels, _) in train_dict.items(): + reporter.update(x0, x1) + + train_loss = reporter.fit_streaming() else: raise ValueError(f"Unknown reporter config type: {type(self.cfg.net)}") - # Fit reporter - train_loss = reporter.fit(x0, x1, train_gt) + # Save reporter checkpoint to disk with open(reporter_dir / f"layer_{layer}.pt", "wb") as file: torch.save(reporter, file) - # Fit baseline logistic regression model - lr_model = train_baseline(x0, x1, train_gt, device=device) + # Fit supervised logistic regression model + lr_model = train_supervised(train_dict, device=device) with open(lr_dir / f"layer_{layer}.pt", "wb") as file: torch.save(lr_model, file) row_buf = [] - for ds_name, (val_x0, val_x1, val_gt, val_lm_preds) in val_output.items(): + for ds_name, (val_x0, val_x1, val_gt, val_lm_preds) in val_dict.items(): val_result = reporter.score( val_gt, val_x0, val_x1, ) + with torch.no_grad(): + (n, v, d) = val_x0.shape + + pseudo_preds = pseudo_clf( + # b v d -> (b v) d + torch.cat([val_x0, val_x1]).flatten(0, 1) + ) + pseudo_labels = torch.cat( + [ + val_x0.new_zeros(n), + val_x0.new_ones(n), + ] + ) + pseudo_labels = repeat(pseudo_labels, "n -> (n v)", v=v) + pseudo_auroc = float( + roc_auc_score(pseudo_labels.cpu(), pseudo_preds.cpu()) + ) if val_lm_preds is not None: val_gt_cpu = ( @@ -128,7 +159,7 @@ def train_reporter( { "dataset": ds_name, "layer": layer, - # "pseudo_auroc": pseudo_auroc, + "pseudo_auroc": pseudo_auroc, "train_loss": train_loss, **val_result._asdict(), "lm_auroc": val_lm_auroc, @@ -136,7 +167,7 @@ def train_reporter( } ) - lr_auroc, lr_acc = evaluate_baseline(lr_model, val_x0, val_x1, val_gt) + lr_auroc, lr_acc = evaluate_supervised(lr_model, val_x0, val_x1, val_gt) row["lr_auroc"] = lr_auroc row["lr_acc"] = lr_acc @@ -144,23 +175,26 @@ def train_reporter( return pd.DataFrame(row_buf) - def get_pseudo_auroc( - self, layer: int, x0: Tensor, x1: Tensor, val_x0: Tensor, val_x1: Tensor - ): + def get_pseudo_classifier(self, data: dict[str, tuple], device: str) -> Classifier: """Check the separability of the pseudo-labels at a given layer.""" - with torch.no_grad(): - pseudo_auroc = Reporter.check_separability( - train_pair=(x0, x1), 
val_pair=(val_x0, val_x1) - ) - if pseudo_auroc > 0.6: - warnings.warn( - f"The pseudo-labels at layer {layer} are linearly separable with " - f"an AUROC of {pseudo_auroc:.3f}. This may indicate that the " - f"algorithm will not converge to a good solution." - ) + x0s, x1s = [], [] + for x0, x1, _, _ in data.values(): + x0s.append(rearrange(x0, "n v d -> (n v) d")) + x1s.append(rearrange(x1, "n v d -> (n v) d")) + + # Simple de-meaning normalization + X0 = torch.cat(x0s) + X1 = torch.cat(x1s) + X0 -= X0.mean(dim=0) + X1 -= X1.mean(dim=0) + + X = torch.cat([X0, X1]) + Y = torch.cat([X0.new_zeros(X0.shape[0]), X0.new_ones(X1.shape[0])]) - return pseudo_auroc + pseudo_clf = Classifier(X.shape[-1], device=device) + pseudo_clf.fit(X, Y) + return pseudo_clf def train(self): """Train a reporter on each layer of the network.""" diff --git a/elk/utils/__init__.py b/elk/utils/__init__.py index 7d4bbf2f..569850f8 100644 --- a/elk/utils/__init__.py +++ b/elk/utils/__init__.py @@ -4,6 +4,7 @@ get_columns_all_equal, get_dataset_name, get_layers, + has_multiple_configs, infer_label_column, infer_num_classes, select_train_val_splits, @@ -19,6 +20,7 @@ "get_columns_all_equal", "get_dataset_name", "get_layers", + "has_multiple_configs", "infer_label_column", "infer_num_classes", "instantiate_model", diff --git a/elk/utils/data_utils.py b/elk/utils/data_utils.py index 3f5f7457..0fbd7353 100644 --- a/elk/utils/data_utils.py +++ b/elk/utils/data_utils.py @@ -1,5 +1,6 @@ import copy from bisect import bisect_left, bisect_right +from functools import cache from operator import itemgetter from random import Random from typing import Any, Iterable @@ -10,6 +11,7 @@ Features, Split, Value, + get_dataset_config_names, ) from ..promptsource.templates import Template @@ -58,7 +60,14 @@ def get_dataset_name(dataset: DatasetDict) -> str: f"All splits must have the same config name; got {[config_name, *rest]}" ) - return builder_name + " " + config_name if config_name else builder_name + include_config = config_name and has_multiple_configs(builder_name) + return builder_name + " " + config_name if include_config else builder_name + + +@cache +def has_multiple_configs(ds_name: str) -> bool: + """Return whether a dataset has multiple configs.""" + return len(get_dataset_config_names(ds_name)) > 1 def select_train_val_splits(raw_splits: Iterable[str]) -> tuple[str, str]: diff --git a/elk/utils/hf_utils.py b/elk/utils/hf_utils.py index 4c3ab331..4e97b6ee 100644 --- a/elk/utils/hf_utils.py +++ b/elk/utils/hf_utils.py @@ -1,8 +1,6 @@ import transformers from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel -from .typing import assert_type - # Ordered by preference _AUTOREGRESSIVE_SUFFIXES = [ # Encoder-decoder models @@ -16,7 +14,9 @@ def instantiate_model(model_str: str, **kwargs) -> PreTrainedModel: """Instantiate a model string with the appropriate `Auto` class.""" model_cfg = AutoConfig.from_pretrained(model_str) - archs = assert_type(list, model_cfg.architectures) + archs = model_cfg.architectures + if not isinstance(archs, list): + return AutoModel.from_pretrained(model_str, **kwargs) for suffix in _AUTOREGRESSIVE_SUFFIXES: # Check if any of the architectures in the config end with the suffix. 
@@ -31,7 +31,10 @@ def instantiate_model(model_str: str, **kwargs) -> PreTrainedModel: def is_autoregressive(model_cfg: PretrainedConfig) -> bool: """Check if a model config is autoregressive.""" - archs = assert_type(list, model_cfg.architectures) + archs = model_cfg.architectures + if not isinstance(archs, list): + return False + return any( arch_str.endswith(suffix) for arch_str in archs diff --git a/pyproject.toml b/pyproject.toml index 16edd58e..f688416d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,6 +12,7 @@ license = {text = "MIT License"} dependencies = [ # Added distributed.split_dataset_by_node for IterableDatasets "datasets>=2.9.0", + "einops", # Introduced numpy.typing module "numpy>=1.20.0", # This version is old, but it's needed for certain HF tokenizers to work. From 14987e1690b3f7da83ff4d6b52b909713255fabb Mon Sep 17 00:00:00 2001 From: Nora Belrose Date: Thu, 13 Apr 2023 20:08:08 +0000 Subject: [PATCH 27/43] Fix math tests --- elk/utils/__init__.py | 12 ++++++++---- tests/test_eigen_reporter.py | 2 +- tests/test_math.py | 2 +- 3 files changed, 10 insertions(+), 6 deletions(-) diff --git a/elk/utils/__init__.py b/elk/utils/__init__.py index 569850f8..13656933 100644 --- a/elk/utils/__init__.py +++ b/elk/utils/__init__.py @@ -11,12 +11,17 @@ ) from .gpu_utils import select_usable_devices from .hf_utils import instantiate_model, is_autoregressive +from .math_util import batch_cov, cov_mean_fused, stochastic_round_constrained from .tree_utils import pytree_map from .typing import assert_type, float32_to_int16, int16_to_float32 __all__ = [ + "assert_type", + "batch_cov", "binarize", "convert_span", + "cov_mean_fused", + "float32_to_int16", "get_columns_all_equal", "get_dataset_name", "get_layers", @@ -24,11 +29,10 @@ "infer_label_column", "infer_num_classes", "instantiate_model", - "is_autoregressive", - "float32_to_int16", "int16_to_float32", + "is_autoregressive", + "pytree_map", "select_train_val_splits", "select_usable_devices", - "pytree_map", - "assert_type", + "stochastic_round_constrained", ] diff --git a/tests/test_eigen_reporter.py b/tests/test_eigen_reporter.py index 58dd6c13..bf00377d 100644 --- a/tests/test_eigen_reporter.py +++ b/tests/test_eigen_reporter.py @@ -1,7 +1,7 @@ import torch -from elk.math_util import batch_cov, cov_mean_fused from elk.training import EigenReporter, EigenReporterConfig +from elk.utils import batch_cov, cov_mean_fused def test_eigen_reporter(): diff --git a/tests/test_math.py b/tests/test_math.py index ee81914e..34984d8f 100644 --- a/tests/test_math.py +++ b/tests/test_math.py @@ -6,7 +6,7 @@ from hypothesis import given from hypothesis import strategies as st -from elk.math_util import batch_cov, cov_mean_fused, stochastic_round_constrained +from elk.utils import batch_cov, cov_mean_fused, stochastic_round_constrained def test_cov_mean_fused(): From 88683faef002b4c91a0d0752c51c16688fa769cd Mon Sep 17 00:00:00 2001 From: Nora Belrose Date: Thu, 13 Apr 2023 20:16:25 +0000 Subject: [PATCH 28/43] Fix smoke tests --- tests/test_smoke_elicit.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/tests/test_smoke_elicit.py b/tests/test_smoke_elicit.py index 9680781c..e61568ba 100644 --- a/tests/test_smoke_elicit.py +++ b/tests/test_smoke_elicit.py @@ -25,7 +25,13 @@ def test_smoke_elicit_run_tiny_gpt2_ccs(tmp_path: Path): # get the files in the tmp_path files: list[Path] = list(tmp_path.iterdir()) created_file_names = {file.name for file in files} - expected_files = ["cfg.yaml", "metadata.yaml", 
"lr_models", "reporters", "eval.csv"] + expected_files = [ + "cfg.yaml", + "fingerprints.yaml", + "lr_models", + "reporters", + "eval.csv", + ] for file in expected_files: assert file in created_file_names @@ -49,6 +55,12 @@ def test_smoke_elicit_run_tiny_gpt2_eigen(tmp_path: Path): # get the files in the tmp_path files: list[Path] = list(tmp_path.iterdir()) created_file_names = {file.name for file in files} - expected_files = ["cfg.yaml", "metadata.yaml", "lr_models", "reporters", "eval.csv"] + expected_files = [ + "cfg.yaml", + "fingerprints.yaml", + "lr_models", + "reporters", + "eval.csv", + ] for file in expected_files: assert file in created_file_names From a6c382e2deaadaa45dc6cee212e96c0c481657c7 Mon Sep 17 00:00:00 2001 From: Nora Belrose Date: Thu, 13 Apr 2023 20:29:35 +0000 Subject: [PATCH 29/43] All tests working ostensibly --- tests/test_load_prompts.py | 32 ++++++++++++++------------------ 1 file changed, 14 insertions(+), 18 deletions(-) diff --git a/tests/test_load_prompts.py b/tests/test_load_prompts.py index a5d238fd..2e03c379 100644 --- a/tests/test_load_prompts.py +++ b/tests/test_load_prompts.py @@ -1,4 +1,4 @@ -from itertools import cycle, islice +from itertools import islice from typing import Literal import pytest @@ -10,26 +10,22 @@ @pytest.mark.filterwarnings("ignore:Unable to find a decoding function") def test_load_prompts(): def test_single_split(cfg: PromptConfig, split_type: Literal["train", "val"]): - prompt_ds = load_prompts( - *cfg.datasets, - split_type=split_type, - ) - prompters = [] - - for ds in cfg.datasets: - ds_name, _, config_name = ds.partition(" ") + for cfg in cfg.explode(): + ds_string = cfg.datasets[0] + prompt_ds = load_prompts(ds_string, split_type=split_type) + + ds_name, _, config_name = ds_string.partition(" ") prompter = DatasetTemplates(ds_name, config_name or None) - prompters.append(prompter) - limit = cfg.max_examples[0 if split_type == "train" else 1] - for prompter, record in zip(cycle(prompters), islice(prompt_ds, limit)): - true_template_names = prompter.all_template_names - returned_template_names = record["template_names"] + limit = cfg.max_examples[0 if split_type == "train" else 1] + for record in islice(prompt_ds, limit): + true_template_names = prompter.all_template_names + returned_template_names = record["template_names"] - # check for using the same templates - assert set(true_template_names) == set(returned_template_names) - # check for them being in the same order - assert true_template_names == true_template_names + # check for using the same templates + assert set(true_template_names) == set(returned_template_names) + # check for them being in the same order + assert true_template_names == true_template_names # the case where the dataset has 2 classes # this dataset is small From ecc53cb2429ee8c10a3a34d1e86b847bebeda858 Mon Sep 17 00:00:00 2001 From: Nora Belrose Date: Thu, 13 Apr 2023 21:07:37 +0000 Subject: [PATCH 30/43] Make CCS normalization customizable --- elk/training/ccs_reporter.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/elk/training/ccs_reporter.py b/elk/training/ccs_reporter.py index bf98d29a..5facf7d3 100644 --- a/elk/training/ccs_reporter.py +++ b/elk/training/ccs_reporter.py @@ -35,6 +35,7 @@ class CcsReporterConfig(ReporterConfig): Example: --loss 1.0*consistency_squared 0.5*prompt_var corresponds to the loss function 1.0*consistency_squared + 0.5*prompt_var. Defaults to "ccs_prompt_var". + normalization: The kind of normalization to apply to the hidden states. 
num_layers: The number of layers in the MLP. Defaults to 1. pre_ln: Whether to include a LayerNorm module before the first linear layer. Defaults to False. @@ -54,6 +55,7 @@ class CcsReporterConfig(ReporterConfig): init: Literal["default", "pca", "spherical", "zero"] = "default" loss: list[str] = field(default_factory=lambda: ["ccs"]) loss_dict: dict[str, float] = field(default_factory=dict, init=False) + normalization: Literal["none", "meanonly", "full"] = "full" num_layers: int = 1 pre_ln: bool = False seed: int = 42 @@ -94,8 +96,12 @@ def __init__( hidden_size = cfg.hidden_size or 4 * in_features // 3 - self.neg_norm = Normalizer((in_features,), device=device, dtype=dtype) - self.pos_norm = Normalizer((in_features,), device=device, dtype=dtype) + self.neg_norm = Normalizer( + (in_features,), device=device, dtype=dtype, mode=cfg.normalization + ) + self.pos_norm = Normalizer( + (in_features,), device=device, dtype=dtype, mode=cfg.normalization + ) self.probe = nn.Sequential( nn.Linear( From 18c7f4c68e1dcaf77ab5dea2d690a9f6be202f50 Mon Sep 17 00:00:00 2001 From: Alex Mallen Date: Thu, 13 Apr 2023 18:59:42 -0400 Subject: [PATCH 31/43] log each dataset individually --- elk/logging.py | 67 ++++++++++++++++++++++++++++---------------------- 1 file changed, 37 insertions(+), 30 deletions(-) diff --git a/elk/logging.py b/elk/logging.py index 706055bd..263272eb 100644 --- a/elk/logging.py +++ b/elk/logging.py @@ -1,9 +1,9 @@ import logging -from .utils import select_train_val_splits +from .utils import get_dataset_name, select_train_val_splits -def save_debug_log(ds, out_dir): +def save_debug_log(datasets, out_dir): """ Save a debug log to the output directory. This is useful for debugging training issues. @@ -16,32 +16,39 @@ def save_debug_log(ds, out_dir): filemode="w", ) - train_split, val_split = select_train_val_splits(ds) - text_inputs = ds[val_split][0]["text_inputs"] - template_ids = ds[val_split][0]["variant_ids"] - label = ds[val_split][0]["label"] - - # log the train size and val size - logging.info(f"Train size: {len(ds[train_split])}") - logging.info(f"Val size: {len(ds[val_split])}") - - templates_text = f"{len(text_inputs)} templates used:\n" - trailing_whitespace = False - for (text0, text1), id in zip(text_inputs, template_ids): - templates_text += ( - f'***---TEMPLATE "{id}"---***\n' - f"{'false' if label else 'true'}:\n" - f'"""{text0}"""\n' - f"{'true' if label else 'false'}:\n" - f'"""{text1}"""\n\n\n' + for ds in datasets: + logging.info( + "=========================================\n" + f"Dataset: {get_dataset_name(ds)}\n" + "=========================================" ) - if text0[-1].isspace() or text1[-1].isspace(): - trailing_whitespace = True - if trailing_whitespace: - logging.warning( - "Some inputs to the model have trailing whitespace! " - "Check that the jinja templates are not adding " - "trailing whitespace. If `token_loc` is 'last', this " - "will extract hidden states from the whitespace token." 
- ) - logging.info(templates_text) + + train_split, val_split = select_train_val_splits(ds) + text_inputs = ds[val_split][0]["text_inputs"] + template_ids = ds[val_split][0]["variant_ids"] + label = ds[val_split][0]["label"] + + # log the train size and val size + logging.info(f"Train size: {len(ds[train_split])}") + logging.info(f"Val size: {len(ds[val_split])}") + + templates_text = f"{len(text_inputs)} templates used:\n" + trailing_whitespace = False + for (text0, text1), id in zip(text_inputs, template_ids): + templates_text += ( + f'***---TEMPLATE "{id}"---***\n' + f"{'false' if label else 'true'}:\n" + f'"""{text0}"""\n' + f"{'true' if label else 'false'}:\n" + f'"""{text1}"""\n\n\n' + ) + if text0[-1].isspace() or text1[-1].isspace(): + trailing_whitespace = True + if trailing_whitespace: + logging.warning( + "Some inputs to the model have trailing whitespace! " + "Check that the jinja templates are not adding " + "trailing whitespace. If `token_loc` is 'last', this " + "will extract hidden states from the whitespace token." + ) + logging.info(templates_text) From 51736494a924b810607abdd95834ee27c8242934 Mon Sep 17 00:00:00 2001 From: Nora Belrose Date: Thu, 13 Apr 2023 23:15:05 +0000 Subject: [PATCH 32/43] Fix label_column bug --- elk/extraction/extraction.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/elk/extraction/extraction.py b/elk/extraction/extraction.py index 602faf2d..17bacb87 100644 --- a/elk/extraction/extraction.py +++ b/elk/extraction/extraction.py @@ -101,15 +101,16 @@ def extract_hiddens( if rank != 0: logging.disable(logging.CRITICAL) - ds_names = cfg.prompts.datasets + p_cfg = cfg.prompts + ds_names = p_cfg.datasets assert len(ds_names) == 1, "Can only extract hiddens from one dataset at a time." 
prompt_ds = load_prompts( ds_names[0], - label_column=cfg.prompts.label_columns[0], - num_classes=cfg.prompts.num_classes, + label_column=p_cfg.label_columns[0] if p_cfg.label_columns else None, + num_classes=p_cfg.num_classes, split_type=split_type, - stream=cfg.prompts.stream, + stream=p_cfg.stream, rank=rank, world_size=world_size, ) # this dataset is already sharded, buqt hasn't been truncated to max_examples @@ -127,7 +128,7 @@ def extract_hiddens( # Iterating over questions layer_indices = cfg.layers or tuple(range(model.config.num_hidden_layers)) - global_max_examples = cfg.prompts.max_examples[0 if split_type == "train" else 1] + global_max_examples = p_cfg.max_examples[0 if split_type == "train" else 1] # break `max_examples` among the processes roughly equally max_examples = global_max_examples // world_size # the last process gets the remainder (which is usually small) @@ -286,7 +287,11 @@ def get_splits() -> SplitDict: info = get_dataset_config_info(ds_name, config_name or None) ds_features = assert_type(Features, info.features) - label_col = cfg.prompts.label_columns[0] or infer_label_column(ds_features) + label_col = ( + cfg.prompts.label_columns[0] + if cfg.prompts.label_columns + else infer_label_column(ds_features) + ) num_classes = cfg.prompts.num_classes or infer_num_classes(ds_features[label_col]) num_variants = cfg.prompts.num_variants if num_variants < 0: From 3e6c39c3bffb670f81fcef4546f4f007cbe94d54 Mon Sep 17 00:00:00 2001 From: Nora Belrose Date: Fri, 14 Apr 2023 03:51:28 +0000 Subject: [PATCH 33/43] GLUE MNLI works on Deberta --- elk/evaluation/evaluate.py | 1 + elk/metrics.py | 44 +---------------- elk/training/eigen_reporter.py | 16 +++++-- elk/training/reporter.py | 12 ++--- elk/training/supervised.py | 28 +++++------ elk/training/train.py | 87 +++++++++++++++++++--------------- 6 files changed, 82 insertions(+), 106 deletions(-) diff --git a/elk/evaluation/evaluate.py b/elk/evaluation/evaluate.py index 2150fbde..a55ce550 100644 --- a/elk/evaluation/evaluate.py +++ b/elk/evaluation/evaluate.py @@ -40,6 +40,7 @@ class Eval(Serializable): num_gpus: int = -1 min_gpu_mem: int | None = None skip_baseline: bool = False + concatenated_layer_offset: int = 0 def execute(self): datasets = self.data.prompts.datasets diff --git a/elk/metrics.py b/elk/metrics.py index 94a394d8..46a6113a 100644 --- a/elk/metrics.py +++ b/elk/metrics.py @@ -1,7 +1,3 @@ -from functools import partial -from typing import Literal - -from sklearn.metrics import average_precision_score, roc_auc_score from torch import Tensor @@ -37,42 +33,4 @@ def accuracy(y_true: Tensor, y_pred: Tensor) -> float: else: hard_preds = y_pred.argmax(-1) - return hard_preds.eq(y_true).float().mean().item() - - -def mean_auc(y_true: Tensor, y_scores: Tensor, curve: Literal["roc", "pr"]) -> float: - """ - Compute the mean area under the receiver operating curve (AUROC) or - precision-recall curve (average precision or mAP) for binary or multi-class - classification problems. - - Args: - y_true: Ground truth tensor of shape (N,) or (N, n_classes). - y_scores: Predicted probability tensor of shape (N,) for binary - or (N, n_classes) for multi-class. - curve: Type of curve to compute the mean AUC. Either 'pr' for - precision-recall curve or 'roc' for receiver operating - characteristic curve. Defaults to 'pr'. - - Returns: - float: Either mean AUROC or mean average precision (mAP). 
- """ - score_fn = { - "pr": average_precision_score, - "roc": partial(roc_auc_score, multi_class="ovo"), - }.get(curve, None) - - if score_fn is None: - raise ValueError("Invalid curve type. Supported values are 'pr' and 'roc'.") - - if len(y_scores.shape) == 1 or y_scores.shape[1] == 1: - return float(score_fn(y_true, y_scores.squeeze(1))) - else: - n_classes = y_scores.shape[1] - y_true_one_hot = to_one_hot(y_true, n_classes) - - return score_fn(y_true_one_hot, y_scores) - # return np.array([ - # score_fn(y_true_one_hot[:, i], y_scores[:, i]) - # for i in range(n_classes) - # ]).mean() + return hard_preds.cpu().eq(y_true.cpu()).float().mean().item() diff --git a/elk/training/eigen_reporter.py b/elk/training/eigen_reporter.py index 9d7b713d..873a3559 100644 --- a/elk/training/eigen_reporter.py +++ b/elk/training/eigen_reporter.py @@ -227,13 +227,22 @@ def fit( loss = self.fit_streaming() if labels is not None: + (_, v, k, _) = hiddens.shape + hiddens = rearrange(hiddens, "n v k d -> (n v k) d") + labels = to_one_hot(repeat(labels, "n -> (n v)", v=v), k).flatten() + self.platt_scale(labels, hiddens) return loss def platt_scale(self, labels: Tensor, hiddens: Tensor, max_iter: int = 100): - """Fit the scale and bias terms to data with LBFGS.""" + """Fit the scale and bias terms to data with LBFGS. + Args: + labels: Binary labels of shape [batch]. + hiddens: Hidden states of shape [batch, dim]. + max_iter: Maximum number of iterations for LBFGS. + """ opt = optim.LBFGS( [self.bias, self.scale], line_search_fn="strong_wolfe", @@ -241,14 +250,11 @@ def platt_scale(self, labels: Tensor, hiddens: Tensor, max_iter: int = 100): tolerance_change=torch.finfo(hiddens.dtype).eps, tolerance_grad=torch.finfo(hiddens.dtype).eps, ) - (_, v, k, _) = hiddens.shape - labels = to_one_hot(repeat(labels, "n -> (n v)", v=v), k) def closure(): opt.zero_grad() - logits = rearrange(self(hiddens), "n v k -> (n v) k") loss = nn.functional.binary_cross_entropy_with_logits( - logits, labels.float() + self(hiddens), labels.float() ) loss.backward() diff --git a/elk/training/reporter.py b/elk/training/reporter.py index 3fa8b772..c552f68d 100644 --- a/elk/training/reporter.py +++ b/elk/training/reporter.py @@ -13,7 +13,7 @@ from torch import Tensor from ..calibration import CalibrationError -from ..metrics import to_one_hot +from ..metrics import accuracy, to_one_hot from .classifier import Classifier @@ -165,13 +165,13 @@ def score(self, labels: Tensor, hiddens: Tensor) -> EvalResult: cal_err = 0.0 raw_preds = to_one_hot(logits.argmax(dim=-1), c).long() - auroc = roc_auc_score( - to_one_hot(Y, c).long().flatten().cpu(), logits.cpu().flatten() - ) - raw_acc = raw_preds.flatten().eq(Y).float().mean() + Y = to_one_hot(Y, c).long().flatten() + + auroc = roc_auc_score(Y.cpu(), logits.cpu().flatten()) + raw_acc = accuracy(Y, raw_preds.flatten()) return EvalResult( - acc=raw_acc.item(), + acc=float(raw_acc), cal_acc=cal_acc, auroc=float(auroc), ece=cal_err, diff --git a/elk/training/supervised.py b/elk/training/supervised.py index 8b6a723a..e0e9f441 100644 --- a/elk/training/supervised.py +++ b/elk/training/supervised.py @@ -1,8 +1,9 @@ import torch from einops import rearrange, repeat -from sklearn.metrics import accuracy_score, roc_auc_score +from sklearn.metrics import roc_auc_score from torch import Tensor +from ..metrics import accuracy, to_one_hot from ..utils import assert_type from .classifier import Classifier @@ -11,16 +12,16 @@ def evaluate_supervised( lr_model: Classifier, val_h: Tensor, val_labels: Tensor ) -> 
tuple[float, float]: (n, v, k, d) = val_h.shape - X_val = val_h.view(-1, d) + with torch.no_grad(): - lr_preds = lr_model(X_val).sigmoid().cpu() + logits = rearrange(lr_model(val_h).cpu().squeeze(), "n v k -> (n v) k") + raw_preds = to_one_hot(logits.argmax(dim=-1), k).long() - val_labels_aug = ( - torch.cat([val_labels, 1 - val_labels]).repeat_interleave(v) - ).cpu() + labels = repeat(val_labels, "n -> (n v)", v=v) + labels = to_one_hot(labels, k).flatten() - lr_acc = accuracy_score(val_labels_aug, lr_preds > 0.5) - lr_auroc = roc_auc_score(val_labels_aug, lr_preds) + lr_acc = accuracy(labels, raw_preds.flatten()) + lr_auroc = roc_auc_score(labels.cpu(), logits.cpu().flatten()) return assert_type(float, lr_auroc), assert_type(float, lr_acc) @@ -28,15 +29,14 @@ def evaluate_supervised( def train_supervised(data: dict[str, tuple], device: str) -> Classifier: Xs, train_labels = [], [] - for x0, x1, labels, _ in data.values(): - (_, v, _) = x0.shape - x0 = rearrange(x0, "n v d -> (n v) d") - x1 = rearrange(x1, "n v d -> (n v) d") + for train_h, labels, _ in data.values(): + (_, v, k, _) = train_h.shape + train_h = rearrange(train_h, "n v k d -> (n v k) d") labels = repeat(labels, "n -> (n v)", v=v) - labels = torch.cat([labels, 1 - labels]) + labels = to_one_hot(labels, k).flatten() - Xs.append(torch.cat([x0, x1]).squeeze()) + Xs.append(train_h) train_labels.append(labels) X, train_labels = torch.cat(Xs), torch.cat(train_labels) diff --git a/elk/training/train.py b/elk/training/train.py index 4d512377..a787c7b0 100644 --- a/elk/training/train.py +++ b/elk/training/train.py @@ -86,26 +86,37 @@ def train_reporter( train_dict = self.prepare_data(device, layer, "train") val_dict = self.prepare_data(device, layer, "val") - # Can't figure out a way to make this line less ugly - hidden_size = next(iter(train_dict.values()))[0].shape[-1] - reporter_dir, lr_dir = self.create_models_dir(assert_type(Path, self.out_dir)) - pseudo_clf = self.get_pseudo_classifier(train_dict, device) + (train_h, train_labels, _), *rest = train_dict.values() + (n, v, k, d) = train_h.shape + + if not all(other_h.shape[2] == k for other_h, _, _ in rest): + raise ValueError("All datasets must have the same number of classes") + reporter_dir, lr_dir = self.create_models_dir(assert_type(Path, self.out_dir)) if isinstance(self.cfg.net, CcsReporterConfig): assert len(train_dict) == 1, "CCS only supports single-task training" - reporter = CcsReporter(self.cfg.net, hidden_size, device=device) - (train_h, labels, _) = next(iter(train_dict.values())) - train_loss = reporter.fit(train_h, labels) + reporter = CcsReporter(self.cfg.net, d, device=device) + train_loss = reporter.fit(train_h, train_labels) elif isinstance(self.cfg.net, EigenReporterConfig): # To enable training on multiple tasks with different numbers of variants, # we update the statistics in a streaming fashion and then fit - reporter = EigenReporter(self.cfg.net, hidden_size, device=device) - for ds_name, (val_h, labels, _) in train_dict.items(): - reporter.update(val_h) + reporter = EigenReporter(self.cfg.net, d, k, device=device) + + hidden_list, label_list = [], [] + for ds_name, (train_h, train_labels, _) in train_dict.items(): + hidden_list.append(train_h) + label_list.append(train_labels) + reporter.update(train_h) train_loss = reporter.fit_streaming() + reporter.platt_scale( + to_one_hot( + repeat(torch.cat(label_list), "n -> (n v)", v=v), k + ).flatten(), + rearrange(torch.cat(hidden_list), "n v k d -> (n v k) d"), + ) else: raise ValueError(f"Unknown reporter 
config type: {type(self.cfg.net)}") @@ -122,22 +133,25 @@ def train_reporter( for ds_name, (val_h, val_gt, val_lm_preds) in val_dict.items(): val_result = reporter.score(val_gt, val_h) with torch.no_grad(): - (n, v, k, d) = val_h.shape - - pseudo_preds = pseudo_clf( - # n v k d -> (n v k) d - rearrange(val_h, "n v k d -> (n v k) d") - ) - pseudo_labels = torch.cat( - [ - val_h.new_zeros(n), - val_h.new_ones(n), - ] - ) - pseudo_labels = repeat(pseudo_labels, "n -> (n v)", v=v) - pseudo_auroc = float( - roc_auc_score(pseudo_labels.cpu(), pseudo_preds.cpu()) - ) + if k == 2: + pseudo_clf = self.get_pseudo_classifier(train_dict, device) + pseudo_preds = pseudo_clf( + # n v k d -> (n v k) d + rearrange(val_h, "n v k d -> (n v k) d") + ) + pseudo_labels = torch.cat( + [ + val_h.new_zeros(n), + val_h.new_ones(n), + ] + ) + pseudo_labels = repeat(pseudo_labels, "n -> (n v)", v=v) + pseudo_auroc = float( + roc_auc_score(pseudo_labels.cpu(), pseudo_preds.cpu()) + ) + else: + # We don't bother with computing the pseudo-AUROC for multi-class + pseudo_auroc = None if val_lm_preds is not None: val_gt_cpu = repeat(val_gt, "n -> (n v)", v=v).cpu() @@ -174,21 +188,18 @@ def train_reporter( def get_pseudo_classifier(self, data: dict[str, tuple], device: str) -> Classifier: """Check the separability of the pseudo-labels at a given layer.""" - x0s, x1s = [], [] - for x0, x1, _, _ in data.values(): - x0s.append(rearrange(x0, "n v d -> (n v) d")) - x1s.append(rearrange(x1, "n v d -> (n v) d")) + X = torch.cat( + [rearrange(h, "n v k d -> (n v) k d") for h, _, _ in data.values()] + ) + (N, k, d) = X.shape + assert k == 2, "Pseudo-labels should be binary" # Simple de-meaning normalization - X0 = torch.cat(x0s) - X1 = torch.cat(x1s) - X0 -= X0.mean(dim=0) - X1 -= X1.mean(dim=0) - - X = torch.cat([X0, X1]) - Y = torch.cat([X0.new_zeros(X0.shape[0]), X0.new_ones(X1.shape[0])]) + X -= X.mean(dim=0) + X = rearrange(X, "N k d -> (N k) d") + Y = torch.cat([X.new_zeros(N), X.new_ones(N)]) - pseudo_clf = Classifier(X.shape[-1], device=device) + pseudo_clf = Classifier(d, device=device) pseudo_clf.fit(X, Y) return pseudo_clf From 1e9ce06bd1ca50b345ec14b5321c1aea587fc42d Mon Sep 17 00:00:00 2001 From: Nora Belrose Date: Fri, 14 Apr 2023 05:05:44 +0000 Subject: [PATCH 34/43] Move pseudo AUROC stuff to CcsReporter --- elk/evaluation/evaluate.py | 12 +++--- elk/training/ccs_reporter.py | 75 ++++++++++++++++++++++++++++------- elk/training/reporter.py | 49 ----------------------- elk/training/train.py | 76 +++++++++++------------------------- 4 files changed, 90 insertions(+), 122 deletions(-) diff --git a/elk/evaluation/evaluate.py b/elk/evaluation/evaluate.py index 186e10a4..5cd4dff1 100644 --- a/elk/evaluation/evaluate.py +++ b/elk/evaluation/evaluate.py @@ -28,6 +28,8 @@ class Eval(Serializable): `elk.training.preprocessing.normalize()` for details. num_gpus: The number of GPUs to use. Defaults to -1, which means "use all available GPUs". + skip_supervised: Whether to skip training the supervised classifier. Defaults to + False. debug: When in debug mode, a useful log file is saved to the memorably-named output directory. Defaults to False. 
""" @@ -35,12 +37,12 @@ class Eval(Serializable): data: Extract source: str = field(positional=True) + concatenated_layer_offset: int = 0 debug: bool = False - out_dir: Path | None = None - num_gpus: int = -1 min_gpu_mem: int | None = None - skip_baseline: bool = False - concatenated_layer_offset: int = 0 + num_gpus: int = -1 + out_dir: Path | None = None + skip_supervised: bool = False def execute(self): datasets = self.data.prompts.datasets @@ -86,7 +88,7 @@ def evaluate_reporter( ) lr_dir = experiment_dir / "lr_models" - if not self.cfg.skip_baseline and lr_dir.exists(): + if not self.cfg.skip_supervised and lr_dir.exists(): with open(lr_dir / f"layer_{layer}.pt", "rb") as f: lr_model = torch.load(f, map_location=device).eval() diff --git a/elk/training/ccs_reporter.py b/elk/training/ccs_reporter.py index 5facf7d3..b3043766 100644 --- a/elk/training/ccs_reporter.py +++ b/elk/training/ccs_reporter.py @@ -7,11 +7,13 @@ import torch import torch.nn as nn +from sklearn.metrics import roc_auc_score from torch import Tensor from torch.nn.functional import binary_cross_entropy as bce from ..parsing import parse_loss from ..utils.typing import assert_type +from .classifier import Classifier from .losses import LOSSES from .normalizer import Normalizer from .reporter import Reporter, ReporterConfig @@ -55,7 +57,6 @@ class CcsReporterConfig(ReporterConfig): init: Literal["default", "pca", "spherical", "zero"] = "default" loss: list[str] = field(default_factory=lambda: ["ccs"]) loss_dict: dict[str, float] = field(default_factory=dict, init=False) - normalization: Literal["none", "meanonly", "full"] = "full" num_layers: int = 1 pre_ln: bool = False seed: int = 42 @@ -96,12 +97,8 @@ def __init__( hidden_size = cfg.hidden_size or 4 * in_features // 3 - self.neg_norm = Normalizer( - (in_features,), device=device, dtype=dtype, mode=cfg.normalization - ) - self.pos_norm = Normalizer( - (in_features,), device=device, dtype=dtype, mode=cfg.normalization - ) + self.neg_norm = Normalizer((in_features,), device=device, dtype=dtype) + self.pos_norm = Normalizer((in_features,), device=device, dtype=dtype) self.probe = nn.Sequential( nn.Linear( @@ -131,6 +128,56 @@ def __init__( ) ) + def check_separability( + self, + train_pair: tuple[Tensor, Tensor], + val_pair: tuple[Tensor, Tensor], + ) -> float: + """Measure how linearly separable the pseudo-labels are for a contrast pair. + + Args: + train_pair: A tuple of tensors, (x0, x1), where x0 and x1 are the + contrastive representations. Used for training the classifier. + val_pair: A tuple of tensors, (x0, x1), where x0 and x1 are the + contrastive representations. Used for evaluating the classifier. + + Returns: + The AUROC of a linear classifier fit on the pseudo-labels. 
+ """ + _x0, _x1 = train_pair + _val_x0, _val_x1 = val_pair + + x0, x1 = self.neg_norm(_x0), self.pos_norm(_x1) + val_x0, val_x1 = self.neg_norm(_val_x0), self.pos_norm(_val_x1) + + pseudo_clf = Classifier(x0.shape[-1], device=x0.device) # type: ignore + pseudo_train_labels = torch.cat( + [ + x0.new_zeros(x0.shape[0]), + x0.new_ones(x0.shape[0]), + ] + ).repeat_interleave( + x0.shape[1] + ) # make num_variants copies of each pseudo-label + pseudo_val_labels = torch.cat( + [ + val_x0.new_zeros(val_x0.shape[0]), + val_x0.new_ones(val_x0.shape[0]), + ] + ).repeat_interleave(val_x0.shape[1]) + + pseudo_clf.fit( + # b v d -> (b v) d + torch.cat([x0, x1]).flatten(0, 1), + pseudo_train_labels, + ) + with torch.no_grad(): + pseudo_preds = pseudo_clf( + # b v d -> (b v) d + torch.cat([val_x0, val_x1]).flatten(0, 1) + ) + return float(roc_auc_score(pseudo_val_labels.cpu(), pseudo_preds.cpu())) + def unsupervised_loss(self, logit0: Tensor, logit1: Tensor) -> Tensor: loss = sum( LOSSES[name](logit0, logit1, coef) @@ -175,9 +222,9 @@ def forward(self, x: Tensor) -> Tensor: """Return the raw score output of the probe on `x`.""" return self.probe(x).squeeze(-1) - def predict(self, x_pos: Tensor, x_neg: Tensor) -> Tensor: - x_pos = self.pos_norm(x_pos) + def predict(self, x_neg: Tensor, x_pos: Tensor) -> Tensor: x_neg = self.neg_norm(x_neg) + x_pos = self.pos_norm(x_pos) return 0.5 * (self(x_pos).sigmoid() + (1 - self(x_neg).sigmoid())) def loss( @@ -226,14 +273,14 @@ def loss( def fit( self, - x_pos: Tensor, x_neg: Tensor, + x_pos: Tensor, labels: Optional[Tensor] = None, ) -> float: - """Fit the probe to the contrast pair (x0, x1). + """Fit the probe to the contrast pair (neg, pos). Args: - contrast_pair: A tuple of tensors, (x0, x1), where x0 and x1 are the + contrast_pair: A tuple of tensors, (neg, pos), where x0 and x1 are the contrastive representations. labels: The labels of the contrast pair. Defaults to None. @@ -280,8 +327,8 @@ def fit( def train_loop_adam( self, - x_pos: Tensor, x_neg: Tensor, + x_pos: Tensor, labels: Optional[Tensor] = None, ) -> float: """Adam train loop, returning the final loss. Modifies params in-place.""" @@ -302,8 +349,8 @@ def train_loop_adam( def train_loop_lbfgs( self, - x_pos: Tensor, x_neg: Tensor, + x_pos: Tensor, labels: Optional[Tensor] = None, ) -> float: """LBFGS train loop, returning the final loss. Modifies params in-place.""" diff --git a/elk/training/reporter.py b/elk/training/reporter.py index b168acb5..ea607650 100644 --- a/elk/training/reporter.py +++ b/elk/training/reporter.py @@ -12,7 +12,6 @@ from torch import Tensor from ..calibration import CalibrationError -from .classifier import Classifier class EvalResult(NamedTuple): @@ -68,54 +67,6 @@ class Reporter(nn.Module, ABC): n: Tensor config: ReporterConfig - @classmethod - def check_separability( - cls, - train_pair: tuple[Tensor, Tensor], - val_pair: tuple[Tensor, Tensor], - ) -> float: - """Measure how linearly separable the pseudo-labels are for a contrast pair. - - Args: - train_pair: A tuple of tensors, (x0, x1), where x0 and x1 are the - contrastive representations. Used for training the classifier. - val_pair: A tuple of tensors, (x0, x1), where x0 and x1 are the - contrastive representations. Used for evaluating the classifier. - - Returns: - The AUROC of a linear classifier fit on the pseudo-labels. 
- """ - x0, x1 = train_pair - val_x0, val_x1 = val_pair - - pseudo_clf = Classifier(x0.shape[-1], device=x0.device) # type: ignore - pseudo_train_labels = torch.cat( - [ - x0.new_zeros(x0.shape[0]), - x0.new_ones(x0.shape[0]), - ] - ).repeat_interleave( - x0.shape[1] - ) # make num_variants copies of each pseudo-label - pseudo_val_labels = torch.cat( - [ - val_x0.new_zeros(val_x0.shape[0]), - val_x0.new_ones(val_x0.shape[0]), - ] - ).repeat_interleave(val_x0.shape[1]) - - pseudo_clf.fit( - # b v d -> (b v) d - torch.cat([x0, x1]).flatten(0, 1), - pseudo_train_labels, - ) - with torch.no_grad(): - pseudo_preds = pseudo_clf( - # b v d -> (b v) d - torch.cat([val_x0, val_x1]).flatten(0, 1) - ) - return float(roc_auc_score(pseudo_val_labels.cpu(), pseudo_preds.cpu())) - def reset_parameters(self): """Reset the parameters of the probe.""" diff --git a/elk/training/train.py b/elk/training/train.py index 111e2fd9..66d65ef6 100644 --- a/elk/training/train.py +++ b/elk/training/train.py @@ -3,11 +3,10 @@ from dataclasses import dataclass from functools import partial from pathlib import Path -from typing import Callable, Optional +from typing import Callable import pandas as pd import torch -from einops import rearrange, repeat from simple_parsing import Serializable, field, subgroups from sklearn.metrics import accuracy_score, roc_auc_score @@ -17,7 +16,6 @@ from ..utils import select_usable_devices from ..utils.typing import assert_type from .ccs_reporter import CcsReporter, CcsReporterConfig -from .classifier import Classifier from .eigen_reporter import EigenReporter, EigenReporterConfig from .reporter import OptimConfig, ReporterConfig @@ -34,7 +32,7 @@ class Elicit(Serializable): "use all available GPUs". normalization: The normalization method to use. Defaults to "meanonly". See `elk.training.preprocessing.normalize()` for details. - skip_baseline: Whether to skip training the supervised classifier. Defaults to + skip_supervised: Whether to skip training the supervised classifier. Defaults to False. debug: When in debug mode, a useful log file is saved to the memorably-named output directory. Defaults to False. 
@@ -46,13 +44,12 @@ class Elicit(Serializable): ) optim: OptimConfig = field(default_factory=OptimConfig) - num_gpus: int = -1 - min_gpu_mem: int | None = None - skip_baseline: bool = False concatenated_layer_offset: int = 0 - # if nonzero, appends the hidden states of layer concatenated_layer_offset before debug: bool = False - out_dir: Optional[Path] = None + min_gpu_mem: int | None = None + num_gpus: int = -1 + out_dir: Path | None = None + skip_supervised: bool = False def execute(self): train_run = Train(cfg=self, out_dir=self.out_dir) @@ -89,7 +86,6 @@ def train_reporter( # Can't figure out a way to make this line less ugly hidden_size = next(iter(train_dict.values()))[0].shape[-1] reporter_dir, lr_dir = self.create_models_dir(assert_type(Path, self.out_dir)) - pseudo_clf = self.get_pseudo_classifier(train_dict, device) if isinstance(self.cfg.net, CcsReporterConfig): assert len(train_dict) == 1, "CCS only supports single-task training" @@ -98,6 +94,11 @@ def train_reporter( (x0, x1, labels, _) = next(iter(train_dict.values())) train_loss = reporter.fit(x0, x1, labels) + (val_x0, val_x1, val_gt, _) = next(iter(val_dict.values())) + pseudo_auroc = reporter.check_separability( + train_pair=(x0, x1), val_pair=(val_x0, val_x1) + ) + elif isinstance(self.cfg.net, EigenReporterConfig): # To enable training on multiple tasks with different numbers of variants, # we update the statistics in a streaming fashion and then fit @@ -105,6 +106,7 @@ def train_reporter( for ds_name, (x0, x1, labels, _) in train_dict.items(): reporter.update(x0, x1) + pseudo_auroc = None train_loss = reporter.fit_streaming() else: raise ValueError(f"Unknown reporter config type: {type(self.cfg.net)}") @@ -114,9 +116,12 @@ def train_reporter( torch.save(reporter, file) # Fit supervised logistic regression model - lr_model = train_supervised(train_dict, device=device) - with open(lr_dir / f"layer_{layer}.pt", "wb") as file: - torch.save(lr_model, file) + if not self.cfg.skip_supervised: + lr_model = train_supervised(train_dict, device=device) + with open(lr_dir / f"layer_{layer}.pt", "wb") as file: + torch.save(lr_model, file) + else: + lr_model = None row_buf = [] for ds_name, (val_x0, val_x1, val_gt, val_lm_preds) in val_dict.items(): @@ -125,23 +130,6 @@ def train_reporter( val_x0, val_x1, ) - with torch.no_grad(): - (n, v, d) = val_x0.shape - - pseudo_preds = pseudo_clf( - # b v d -> (b v) d - torch.cat([val_x0, val_x1]).flatten(0, 1) - ) - pseudo_labels = torch.cat( - [ - val_x0.new_zeros(n), - val_x0.new_ones(n), - ] - ) - pseudo_labels = repeat(pseudo_labels, "n -> (n v)", v=v) - pseudo_auroc = float( - roc_auc_score(pseudo_labels.cpu(), pseudo_preds.cpu()) - ) if val_lm_preds is not None: val_gt_cpu = ( @@ -167,35 +155,15 @@ def train_reporter( } ) - lr_auroc, lr_acc = evaluate_supervised(lr_model, val_x0, val_x1, val_gt) + if lr_model is not None: + row["lr_auroc"], row["lr_acc"] = evaluate_supervised( + lr_model, val_x0, val_x1, val_gt + ) - row["lr_auroc"] = lr_auroc - row["lr_acc"] = lr_acc row_buf.append(row) return pd.DataFrame(row_buf) - def get_pseudo_classifier(self, data: dict[str, tuple], device: str) -> Classifier: - """Check the separability of the pseudo-labels at a given layer.""" - - x0s, x1s = [], [] - for x0, x1, _, _ in data.values(): - x0s.append(rearrange(x0, "n v d -> (n v) d")) - x1s.append(rearrange(x1, "n v d -> (n v) d")) - - # Simple de-meaning normalization - X0 = torch.cat(x0s) - X1 = torch.cat(x1s) - X0 -= X0.mean(dim=0) - X1 -= X1.mean(dim=0) - - X = torch.cat([X0, X1]) - Y = 
torch.cat([X0.new_zeros(X0.shape[0]), X0.new_ones(X1.shape[0])]) - - pseudo_clf = Classifier(X.shape[-1], device=device) - pseudo_clf.fit(X, Y) - return pseudo_clf - def train(self): """Train a reporter on each layer of the network.""" devices = select_usable_devices(self.cfg.num_gpus) From 35a8f3479ca5f049e465f52649d920f4389dd2cb Mon Sep 17 00:00:00 2001 From: Nora Belrose Date: Fri, 14 Apr 2023 05:14:35 +0000 Subject: [PATCH 35/43] Make 'datasets' and 'label_columns' config options more opinionated --- elk/extraction/prompt_loading.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/elk/extraction/prompt_loading.py b/elk/extraction/prompt_loading.py index b13dff69..6e10002a 100644 --- a/elk/extraction/prompt_loading.py +++ b/elk/extraction/prompt_loading.py @@ -76,8 +76,19 @@ def explode(self) -> list["PromptConfig"]: # Broadcast the dataset name to all data_dirs and label_columns if len(self.data_dirs) == 1: self.data_dirs *= len(self.datasets) + elif self.data_dirs and len(self.data_dirs) != len(self.datasets): + raise ValueError( + "data_dirs should be a list of length 0, 1, or len(datasets)," + f"but got {len(self.data_dirs)}" + ) + if len(self.label_columns) == 1: self.label_columns *= len(self.datasets) + elif self.label_columns and len(self.label_columns) != len(self.datasets): + raise ValueError( + "label_columns should be a list of length 0, 1, or len(datasets)," + f"but got {len(self.label_columns)}" + ) for ds, data_dir, col in zip_longest( self.datasets, self.data_dirs, self.label_columns From 615bbb15b86acbe49ef3481e24a5cdc57ad14e0f Mon Sep 17 00:00:00 2001 From: Nora Belrose Date: Fri, 14 Apr 2023 05:24:33 +0000 Subject: [PATCH 36/43] tiny spacing change --- elk/extraction/prompt_loading.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/elk/extraction/prompt_loading.py b/elk/extraction/prompt_loading.py index 6e10002a..44a41c82 100644 --- a/elk/extraction/prompt_loading.py +++ b/elk/extraction/prompt_loading.py @@ -79,7 +79,7 @@ def explode(self) -> list["PromptConfig"]: elif self.data_dirs and len(self.data_dirs) != len(self.datasets): raise ValueError( "data_dirs should be a list of length 0, 1, or len(datasets)," - f"but got {len(self.data_dirs)}" + f" but got {len(self.data_dirs)}" ) if len(self.label_columns) == 1: @@ -87,7 +87,7 @@ def explode(self) -> list["PromptConfig"]: elif self.label_columns and len(self.label_columns) != len(self.datasets): raise ValueError( "label_columns should be a list of length 0, 1, or len(datasets)," - f"but got {len(self.label_columns)}" + f" but got {len(self.label_columns)}" ) for ds, data_dir, col in zip_longest( From f021404d66984a25f8f7a19d4af365bb69a44f82 Mon Sep 17 00:00:00 2001 From: Nora Belrose Date: Fri, 14 Apr 2023 06:06:20 +0000 Subject: [PATCH 37/43] Allow for toggling CV --- elk/evaluation/evaluate.py | 3 +-- elk/training/supervised.py | 7 +++++-- elk/training/train.py | 15 +++++++++------ 3 files changed, 15 insertions(+), 10 deletions(-) diff --git a/elk/evaluation/evaluate.py b/elk/evaluation/evaluate.py index 5cd4dff1..8fe8e169 100644 --- a/elk/evaluation/evaluate.py +++ b/elk/evaluation/evaluate.py @@ -28,8 +28,7 @@ class Eval(Serializable): `elk.training.preprocessing.normalize()` for details. num_gpus: The number of GPUs to use. Defaults to -1, which means "use all available GPUs". - skip_supervised: Whether to skip training the supervised classifier. Defaults to - False. + skip_supervised: Whether to skip evaluation of the supervised classifier. 
debug: When in debug mode, a useful log file is saved to the memorably-named output directory. Defaults to False. """ diff --git a/elk/training/supervised.py b/elk/training/supervised.py index b8f580f3..fac7152c 100644 --- a/elk/training/supervised.py +++ b/elk/training/supervised.py @@ -26,7 +26,7 @@ def evaluate_supervised( return assert_type(float, lr_auroc), assert_type(float, lr_acc) -def train_supervised(data: dict[str, tuple], device: str) -> Classifier: +def train_supervised(data: dict[str, tuple], device: str, cv: bool) -> Classifier: Xs, train_labels = [], [] for x0, x1, labels, _ in data.values(): @@ -42,6 +42,9 @@ def train_supervised(data: dict[str, tuple], device: str) -> Classifier: X, train_labels = torch.cat(Xs), torch.cat(train_labels) lr_model = Classifier(X.shape[-1], device=device) - lr_model.fit_cv(X, train_labels) + if cv: + lr_model.fit_cv(X, train_labels) + else: + lr_model.fit(X, train_labels) return lr_model diff --git a/elk/training/train.py b/elk/training/train.py index 66d65ef6..403f8ce9 100644 --- a/elk/training/train.py +++ b/elk/training/train.py @@ -3,7 +3,7 @@ from dataclasses import dataclass from functools import partial from pathlib import Path -from typing import Callable +from typing import Callable, Literal import pandas as pd import torch @@ -32,8 +32,9 @@ class Elicit(Serializable): "use all available GPUs". normalization: The normalization method to use. Defaults to "meanonly". See `elk.training.preprocessing.normalize()` for details. - skip_supervised: Whether to skip training the supervised classifier. Defaults to - False. + supervised: Whether to train a supervised classifier, and if so, whether to + use cross-validation. Defaults to "single", which means to train a single + classifier on the training data. "cv" means to use cross-validation. debug: When in debug mode, a useful log file is saved to the memorably-named output directory. Defaults to False. """ @@ -49,7 +50,7 @@ class Elicit(Serializable): min_gpu_mem: int | None = None num_gpus: int = -1 out_dir: Path | None = None - skip_supervised: bool = False + supervised: Literal["none", "single", "cv"] = "single" def execute(self): train_run = Train(cfg=self, out_dir=self.out_dir) @@ -116,8 +117,10 @@ def train_reporter( torch.save(reporter, file) # Fit supervised logistic regression model - if not self.cfg.skip_supervised: - lr_model = train_supervised(train_dict, device=device) + if self.cfg.supervised != "none": + lr_model = train_supervised( + train_dict, device=device, cv=self.cfg.supervised == "cv" + ) with open(lr_dir / f"layer_{layer}.pt", "wb") as file: torch.save(lr_model, file) else: From 99f01c3596c9c1d29f5e68338ad1cefe9b3b89ec Mon Sep 17 00:00:00 2001 From: Nora Belrose Date: Fri, 14 Apr 2023 07:21:42 +0000 Subject: [PATCH 38/43] Remove duplicate dbpedia template --- .../templates/dbpedia_14/templates.yaml | 19 ------------------- elk/training/eigen_reporter.py | 3 --- tests/test_eigen_reporter.py | 15 +++++++-------- 3 files changed, 7 insertions(+), 30 deletions(-) diff --git a/elk/promptsource/templates/dbpedia_14/templates.yaml b/elk/promptsource/templates/dbpedia_14/templates.yaml index c16b0e70..4a19bb61 100644 --- a/elk/promptsource/templates/dbpedia_14/templates.yaml +++ b/elk/promptsource/templates/dbpedia_14/templates.yaml @@ -53,25 +53,6 @@ templates: original_task: true name: burns_3 reference: Burns et al. 
- 03fa401f-3329-48fa-be4a-1b6725292ee6: !Template - answer_choices: Company ||| Educational Institution ||| Artist ||| Athlete ||| - Office Holder ||| Mean Of Transportation ||| Building ||| Natural Place ||| - Village ||| Animal ||| Plant ||| Album ||| Film ||| Written Work - id: 03fa401f-3329-48fa-be4a-1b6725292ee6 - jinja: 'Consider the following example: '''''' {{content}} '''''' - - Which is the topic of this example, choice 1: {{answer_choices[label]}}, or - choice 2: {{answer_choices[1 - label]}} ||| - - {{answer_choices[label]}}' - metadata: !TemplateMetadata - choices_in_prompt: true - languages: - - en - metrics: [] - original_task: true - name: burns_4 - reference: Burns et al. 04fa401f-3329-48fa-be4a-1b6725292ee6: !Template answer_choices: Company ||| Educational Institution ||| Artist ||| Athlete ||| Office Holder ||| Mean Of Transportation ||| Building ||| Natural Place ||| diff --git a/elk/training/eigen_reporter.py b/elk/training/eigen_reporter.py index 873a3559..ca8905ac 100644 --- a/elk/training/eigen_reporter.py +++ b/elk/training/eigen_reporter.py @@ -144,9 +144,6 @@ def clear(self) -> None: def update(self, hiddens: Tensor) -> None: (n, _, k, d) = hiddens.shape - # Zero out shared info - hiddens = hiddens - hiddens.mean(dim=2, keepdim=True) - # Sanity checks assert k > 1, "Must provide at least two hidden states" assert hiddens.ndim == 4, "Must be of shape [batch, variants, choices, dim]" diff --git a/tests/test_eigen_reporter.py b/tests/test_eigen_reporter.py index e576069a..44416bb9 100644 --- a/tests/test_eigen_reporter.py +++ b/tests/test_eigen_reporter.py @@ -9,19 +9,18 @@ def test_eigen_reporter(): hidden_size = 10 num_clusters = 100 - x_pos = torch.randn(num_clusters, cluster_size, hidden_size, dtype=torch.float64) - x_neg = torch.randn(num_clusters, cluster_size, hidden_size, dtype=torch.float64) - x_pos1, x_pos2 = x_pos.chunk(2, dim=0) - x_neg1, x_neg2 = x_neg.chunk(2, dim=0) + x = torch.randn(num_clusters, cluster_size, 2, hidden_size, dtype=torch.float64) + x1, x2 = x.chunk(2, dim=0) reporter = EigenReporter(EigenReporterConfig(), hidden_size, dtype=torch.float64) - reporter.update(x_pos1, x_neg1) - reporter.update(x_pos2, x_neg2) + reporter.update(x1) + reporter.update(x2) # Check that the streaming mean is correct + x_neg, x_pos = x.unbind(2) pos_mu, neg_mu = x_pos.mean(dim=(0, 1)), x_neg.mean(dim=(0, 1)) - torch.testing.assert_close(reporter.class_means[0], pos_mu) - torch.testing.assert_close(reporter.class_means[1], neg_mu) + torch.testing.assert_close(reporter.class_means[0], neg_mu) + torch.testing.assert_close(reporter.class_means[1], pos_mu) # Check that the streaming covariance is correct pos_centroids, neg_centroids = x_pos.mean(dim=1), x_neg.mean(dim=1) From d16c96b943b72b1f33e74f255518690f2f4e5d30 Mon Sep 17 00:00:00 2001 From: Nora Belrose Date: Sat, 15 Apr 2023 20:22:47 +0000 Subject: [PATCH 39/43] Training on datasets with different numbers of classes now works --- elk/training/__init__.py | 2 + elk/training/eigen_reporter.py | 57 +++++++++++++++++---------- elk/training/train.py | 40 ++++++++++--------- tests/test_eigen_reporter.py | 72 ++++++++++++++++++++++++---------- 4 files changed, 111 insertions(+), 60 deletions(-) diff --git a/elk/training/__init__.py b/elk/training/__init__.py index 6428c3c5..ce6e4d48 100644 --- a/elk/training/__init__.py +++ b/elk/training/__init__.py @@ -1,4 +1,5 @@ from .ccs_reporter import CcsReporter, CcsReporterConfig +from .classifier import Classifier from .eigen_reporter import EigenReporter, 
EigenReporterConfig from .normalizer import Normalizer from .reporter import OptimConfig, Reporter, ReporterConfig @@ -6,6 +7,7 @@ __all__ = [ "CcsReporter", "CcsReporterConfig", + "Classifier", "EigenReporter", "EigenReporterConfig", "Normalizer", diff --git a/elk/training/eigen_reporter.py b/elk/training/eigen_reporter.py index ca8905ac..e45b0215 100644 --- a/elk/training/eigen_reporter.py +++ b/elk/training/eigen_reporter.py @@ -40,20 +40,27 @@ class EigenReporter(Reporter): """A linear reporter whose weights are computed via eigendecomposition. Args: - in_features: The number of input features. cfg: The reporter configuration. + in_features: The number of input features. + num_classes: The number of classes for tracking the running means. If `None`, + we don't track the running means at all, and the semantics of `update()` + are a bit different. In particular, each call to `update()` is treated as a + new dataset, with a potentially different number of classes. The covariance + matrices are simply averaged over each batch of data passed to `update()`, + instead of being updated with Welford's algorithm. This is useful for + training a single reporter on multiple datasets, where the number of + classes may vary. Attributes: config: The reporter configuration. - intercluster_cov_M2: The running sum of the covariance matrices of the - centroids of the positive and negative clusters. + intercluster_cov_M2: The unnormalized covariance matrix averaged over all + classes. intracluster_cov: The running mean of the covariance matrices within each cluster. This doesn't need to be a running sum because it's doesn't use Welford's algorithm. - contrastive_xcov_M2: The running sum of the cross-covariance between the - centroids of the positive and negative clusters. - n: The running sum of the number of samples in the positive and negative - clusters. + contrastive_xcov_M2: Average of the unnormalized cross-covariance matrices + across all pairs of classes (k, k'). + n: The running sum of the number of clusters processed by `update()`. weight: The reporter weight matrix. Guaranteed to always be orthogonal, and the columns are sorted in descending order of eigenvalue magnitude. 
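    A usage sketch of the num_classes=None mode described above (illustrative only;
    ds_a_hiddens and ds_b_hiddens are hypothetical [n, v, k, d] tensors that may
    have different k):

        reporter = EigenReporter(cfg, in_features, num_classes=None, device=device)
        for hiddens in (ds_a_hiddens, ds_b_hiddens):
            reporter.update(hiddens)      # covariances are averaged over each batch
        train_loss = reporter.fit_streaming()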
""" @@ -64,16 +71,17 @@ class EigenReporter(Reporter): intracluster_cov: Tensor # invariance contrastive_xcov_M2: Tensor # negative covariance n: Tensor - class_means: Tensor + class_means: Tensor | None weight: Tensor def __init__( self, cfg: EigenReporterConfig, in_features: int, - num_classes: int = 2, - device: Optional[str] = None, - dtype: Optional[torch.dtype] = None, + num_classes: int | None = 2, + *, + device: str | torch.device | None = None, + dtype: torch.dtype | None = None, ): super().__init__() self.config = cfg @@ -86,7 +94,11 @@ def __init__( self.register_buffer("n", torch.zeros((), device=device, dtype=torch.long)) self.register_buffer( "class_means", - torch.zeros(num_classes, in_features, device=device, dtype=dtype), + ( + torch.zeros(num_classes, in_features, device=device, dtype=dtype) + if num_classes is not None + else None + ), ) self.register_buffer( @@ -148,10 +160,6 @@ def update(self, hiddens: Tensor) -> None: assert k > 1, "Must provide at least two hidden states" assert hiddens.ndim == 4, "Must be of shape [batch, variants, choices, dim]" - # We don't actually call super because we need access to the earlier estimate - # of the population mean in order to update (cross-)covariances properly - # super().update(hiddens) - self.n += n # *** Invariance (intra-cluster) *** @@ -164,21 +172,28 @@ def update(self, hiddens: Tensor) -> None: centroids = hiddens.mean(1) deltas, deltas2 = [], [] + # Iterating over classes for i, h in enumerate(centroids.unbind(1)): - # Update the running means; super().update() does this usually - delta = h - self.class_means[i] - self.class_means[i] += delta.sum(dim=0) / self.n + # Update the running means if needed + if self.class_means is not None: + delta = h - self.class_means[i] + self.class_means[i] += delta.sum(dim=0) / self.n + + # Post-mean update deltas are used to update the (co)variance + delta2 = h - self.class_means[i] # [n, d] + else: + delta = h - h.mean(dim=0) + delta2 = delta # *** Variance (inter-cluster) *** # See code at https://bit.ly/3YC9BhH and "Welford's online algorithm" # in https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance. 
- # Post-mean update deltas are used to update the (co)variance - delta2 = h - self.class_means[i] # [n, d] self.intercluster_cov_M2.addmm_(delta.mT, delta2, alpha=1 / k) deltas.append(delta) deltas2.append(delta2) # *** Negative covariance (contrastive) *** + # Iterating over pairs of classes (k, k') where k != k' for i, d in enumerate(deltas): for j, d_ in enumerate(deltas2): # Compare to the other classes only diff --git a/elk/training/train.py b/elk/training/train.py index 297b9d21..510f3cc4 100644 --- a/elk/training/train.py +++ b/elk/training/train.py @@ -85,46 +85,48 @@ def train_reporter( train_dict = self.prepare_data(device, layer, "train") val_dict = self.prepare_data(device, layer, "val") - (train_h, train_labels, _), *rest = train_dict.values() - (n, v, k, d) = train_h.shape + (first_train_h, train_labels, _), *rest = train_dict.values() + d = first_train_h.shape[-1] + if not all(other_h.shape[-1] == d for other_h, _, _ in rest): + raise ValueError("All datasets must have the same hidden state size") - if not all(other_h.shape[2] == k for other_h, _, _ in rest): - raise ValueError("All datasets must have the same number of classes") - - # Can't figure out a way to make this line less ugly - next(iter(train_dict.values()))[0].shape[-1] reporter_dir, lr_dir = self.create_models_dir(assert_type(Path, self.out_dir)) if isinstance(self.cfg.net, CcsReporterConfig): assert len(train_dict) == 1, "CCS only supports single-task training" reporter = CcsReporter(self.cfg.net, d, device=device) - train_loss = reporter.fit(train_h, train_labels) + train_loss = reporter.fit(first_train_h, train_labels) (val_h, val_gt, _) = next(iter(val_dict.values())) - x0, x1 = train_h.unbind(2) + x0, x1 = first_train_h.unbind(2) val_x0, val_x1 = val_h.unbind(2) pseudo_auroc = reporter.check_separability( train_pair=(x0, x1), val_pair=(val_x0, val_x1) ) elif isinstance(self.cfg.net, EigenReporterConfig): - # To enable training on multiple tasks with different numbers of variants, - # we update the statistics in a streaming fashion and then fit - reporter = EigenReporter(self.cfg.net, d, k, device=device) + # We set num_classes to None to enable training on datasets with different + # numbers of classes. Under the hood, this causes the covariance statistics + # to be simply averaged across all batches passed to update(). + reporter = EigenReporter(self.cfg.net, d, num_classes=None, device=device) hidden_list, label_list = [], [] for ds_name, (train_h, train_labels, _) in train_dict.items(): - hidden_list.append(train_h) - label_list.append(train_labels) + (_, v, k, _) = train_h.shape + + # Datasets can have different numbers of variants and different numbers + # of classes, so we need to flatten them here before concatenating + hidden_list.append(rearrange(train_h, "n v k d -> (n v k) d")) + label_list.append( + to_one_hot(repeat(train_labels, "n -> (n v)", v=v), k).flatten() + ) reporter.update(train_h) pseudo_auroc = None train_loss = reporter.fit_streaming() reporter.platt_scale( - to_one_hot( - repeat(torch.cat(label_list), "n -> (n v)", v=v), k - ).flatten(), - rearrange(torch.cat(hidden_list), "n v k d -> (n v k) d"), + torch.cat(label_list), + torch.cat(hidden_list), ) else: raise ValueError(f"Unknown reporter config type: {type(self.cfg.net)}") @@ -148,6 +150,8 @@ def train_reporter( val_result = reporter.score(val_gt, val_h) if val_lm_preds is not None: + (_, v, k, _) = val_h.shape + val_gt_cpu = repeat(val_gt, "n -> (n v)", v=v).cpu() val_lm_preds = rearrange(val_lm_preds, "n v ... 
-> (n v) ...") val_lm_auroc = roc_auc_score( diff --git a/tests/test_eigen_reporter.py b/tests/test_eigen_reporter.py index 44416bb9..8977de01 100644 --- a/tests/test_eigen_reporter.py +++ b/tests/test_eigen_reporter.py @@ -1,39 +1,69 @@ +import pytest import torch from elk.training import EigenReporter, EigenReporterConfig from elk.utils import batch_cov, cov_mean_fused -def test_eigen_reporter(): +@pytest.mark.parametrize("track_class_means", [True, False]) +def test_eigen_reporter(track_class_means: bool): cluster_size = 5 hidden_size = 10 - num_clusters = 100 + N = 100 - x = torch.randn(num_clusters, cluster_size, 2, hidden_size, dtype=torch.float64) + x = torch.randn(N, cluster_size, 2, hidden_size, dtype=torch.float64) x1, x2 = x.chunk(2, dim=0) + x_neg, x_pos = x.unbind(2) - reporter = EigenReporter(EigenReporterConfig(), hidden_size, dtype=torch.float64) + reporter = EigenReporter( + EigenReporterConfig(), + hidden_size, + dtype=torch.float64, + num_classes=2 if track_class_means else None, + ) reporter.update(x1) reporter.update(x2) - # Check that the streaming mean is correct - x_neg, x_pos = x.unbind(2) - pos_mu, neg_mu = x_pos.mean(dim=(0, 1)), x_neg.mean(dim=(0, 1)) - torch.testing.assert_close(reporter.class_means[0], neg_mu) - torch.testing.assert_close(reporter.class_means[1], pos_mu) + if track_class_means: + # Check that the streaming mean is correct + neg_mu, pos_mu = x_neg.mean(dim=(0, 1)), x_pos.mean(dim=(0, 1)) - # Check that the streaming covariance is correct - pos_centroids, neg_centroids = x_pos.mean(dim=1), x_neg.mean(dim=1) - expected_var = 0.5 * (batch_cov(pos_centroids) + batch_cov(neg_centroids)) - torch.testing.assert_close(reporter.intercluster_cov, expected_var) + assert reporter.class_means is not None + torch.testing.assert_close(reporter.class_means[0], neg_mu) + torch.testing.assert_close(reporter.class_means[1], pos_mu) - # Check that the streaming invariance (intra-cluster variance) is correct - expected_invariance = 0.5 * (cov_mean_fused(x_pos) + cov_mean_fused(x_neg)) - torch.testing.assert_close(reporter.intracluster_cov, expected_invariance) + # Check that the streaming covariance is correct + neg_centroids, pos_centroids = x_neg.mean(dim=1), x_pos.mean(dim=1) + true_cov = 0.5 * (batch_cov(neg_centroids) + batch_cov(pos_centroids)) + torch.testing.assert_close(reporter.intercluster_cov, true_cov) + + # Check that the streaming negative covariance is correct + true_xcov = (neg_centroids - neg_mu).mT @ (pos_centroids - pos_mu) / N + true_xcov = 0.5 * (true_xcov + true_xcov.mT) + torch.testing.assert_close(reporter.contrastive_xcov, true_xcov) + else: + assert reporter.class_means is None - # Check that the streaming negative covariance is correct - cross_cov = (pos_centroids - pos_mu).mT @ (neg_centroids - neg_mu) / num_clusters - cross_cov = 0.5 * (cross_cov + cross_cov.mT) - torch.testing.assert_close(reporter.contrastive_xcov, cross_cov) + # Check that the covariance matrices are correct. When we don't track class + # means, we expect intercluster_cov and contrastive_xcov to simply be averaged + # over each batch passed to update(). 
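+        # Two batches were passed to update() above, so intercluster_cov should just
+        # be the average of the two per-batch covariances computed below (hence the
+        # division by 2 in the assertion), rather than a Welford running estimate.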
+ true_xcov = 0.0 + true_cov = 0.0 + for x_i in (x1, x2): + x_neg_i, x_pos_i = x_i.unbind(2) + neg_centroids, pos_centroids = x_neg_i.mean(dim=1), x_pos_i.mean(dim=1) + true_cov += 0.5 * (batch_cov(neg_centroids) + batch_cov(pos_centroids)) + + neg_mu_i, pos_mu_i = x_neg_i.mean(dim=(0, 1)), x_pos_i.mean(dim=(0, 1)) + xcov_asym = (neg_centroids - neg_mu_i).mT @ (pos_centroids - pos_mu_i) + true_xcov += 0.5 * (xcov_asym + xcov_asym.mT) + + torch.testing.assert_close(reporter.intercluster_cov, true_cov / 2) + torch.testing.assert_close(reporter.contrastive_xcov, true_xcov / N) + + # Check that the streaming invariance (intra-cluster variance) is correct. + # This is actually the same whether or not we track class means. + expected_invariance = 0.5 * (cov_mean_fused(x_neg) + cov_mean_fused(x_pos)) + torch.testing.assert_close(reporter.intracluster_cov, expected_invariance) - assert reporter.n == num_clusters + assert reporter.n == N From 044774ef46aa60d0748fb0437f76bfd7d1dcae3c Mon Sep 17 00:00:00 2001 From: Nora Belrose Date: Sat, 15 Apr 2023 22:01:10 +0000 Subject: [PATCH 40/43] Efficient bootstrap CIs for AUROCs --- .pre-commit-config.yaml | 2 +- elk/evaluation/evaluate.py | 7 +- elk/metrics.py | 122 +++++++++++++++++++++++++++++++++++ elk/run.py | 4 +- elk/training/ccs_reporter.py | 4 +- elk/training/reporter.py | 15 +++-- elk/training/supervised.py | 13 ++-- elk/training/train.py | 38 ++++++----- pyproject.toml | 5 +- tests/test_roc_auc.py | 53 +++++++++++++++ 10 files changed, 219 insertions(+), 44 deletions(-) create mode 100644 tests/test_roc_auc.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 38fff767..42e9a5f4 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -24,4 +24,4 @@ repos: hooks: - id: codespell # The promptsource templates spuriously get flagged without this - args: ["--skip=*.yaml"] + args: ["-L fpr", "--skip=*.yaml"] diff --git a/elk/evaluation/evaluate.py b/elk/evaluation/evaluate.py index d819c3fe..7ccb156d 100644 --- a/elk/evaluation/evaluate.py +++ b/elk/evaluation/evaluate.py @@ -85,9 +85,10 @@ def evaluate_reporter( with open(lr_dir / f"layer_{layer}.pt", "rb") as f: lr_model = torch.load(f, map_location=device).eval() - lr_auroc, lr_acc = evaluate_supervised(lr_model, val_h, val_gt) - - stats_row["lr_auroc"] = lr_auroc + lr_auroc_res, lr_acc = evaluate_supervised(lr_model, val_h, val_gt) + stats_row["lr_auroc"] = lr_auroc_res.estimate + stats_row["lr_auroc_lower"] = lr_auroc_res.lower + stats_row["lr_auroc_upper"] = lr_auroc_res.upper stats_row["lr_acc"] = lr_acc row_buf.append(stats_row) diff --git a/elk/metrics.py b/elk/metrics.py index 46a6113a..0150f02f 100644 --- a/elk/metrics.py +++ b/elk/metrics.py @@ -1,3 +1,6 @@ +from typing import NamedTuple + +import torch from torch import Tensor @@ -34,3 +37,122 @@ def accuracy(y_true: Tensor, y_pred: Tensor) -> float: hard_preds = y_pred.argmax(-1) return hard_preds.cpu().eq(y_true.cpu()).float().mean().item() + + +class RocAucResult(NamedTuple): + """Named tuple for storing ROC AUC results.""" + + estimate: float + """Point estimate of the ROC AUC computed on this sample.""" + lower: float + """Lower bound of the bootstrap confidence interval.""" + upper: float + """Upper bound of the bootstrap confidence interval.""" + + +def roc_auc(y_true: Tensor, y_pred: Tensor) -> Tensor: + """Area under the receiver operating characteristic curve (ROC AUC). 
+ + Unlike scikit-learn's implementation, this function supports batched inputs of + shape `(N, n)` where `N` is the number of datasets and `n` is the number of samples + within each dataset. This is primarily useful for efficiently computing bootstrap + confidence intervals. + + Args: + y_true: Ground truth tensor of shape `(N,)` or `(N, n)`. + y_pred: Predicted class tensor of shape `(N,)` or `(N, n)`. + + Returns: + Tensor: If the inputs are 1D, a scalar containing the ROC AUC. If they're 2D, + a tensor of shape (N,) containing the ROC AUC for each dataset. + """ + if y_true.shape != y_pred.shape: + raise ValueError( + f"y_true and y_pred should have the same shape; " + f"got {y_true.shape} and {y_pred.shape}" + ) + if y_true.dim() not in (1, 2): + raise ValueError("y_true and y_pred should be 1D or 2D tensors") + + # Sort y_pred in descending order and get indices + indices = y_pred.argsort(descending=True, dim=-1) + + # Reorder y_true based on sorted y_pred indices + y_true_sorted = y_true.gather(-1, indices) + + # Calculate number of positive and negative samples + num_positives = y_true.sum(dim=-1) + num_negatives = y_true.shape[-1] - num_positives + + # Calculate cumulative sum of true positive counts (TPs) + tps = torch.cumsum(y_true_sorted, dim=-1) + + # Calculate cumulative sum of false positive counts (FPs) + fps = torch.cumsum(1 - y_true_sorted, dim=-1) + + # Calculate true positive rate (TPR) and false positive rate (FPR) + tpr = tps / num_positives.view(-1, 1) + fpr = fps / num_negatives.view(-1, 1) + + # Calculate differences between consecutive FPR values (widths of trapezoids) + fpr_diffs = torch.cat( + [fpr[..., 1:] - fpr[..., :-1], torch.zeros_like(fpr[..., :1])], dim=-1 + ) + + # Calculate area under the ROC curve for each dataset using trapezoidal rule + return torch.sum(tpr * fpr_diffs, dim=-1).squeeze() + + +def roc_auc_ci( + y_true: Tensor, + y_pred: Tensor, + *, + num_samples: int = 1000, + level: float = 0.95, + seed: int = 42, +) -> RocAucResult: + """Bootstrap confidence interval for the ROC AUC. + + Args: + y_true: Ground truth tensor of shape `(N,)`. + y_pred: Predicted class tensor of shape `(N,)`. + num_samples (int): Number of bootstrap samples to use. + level (float): Confidence level of the confidence interval. + seed (int): Random seed for reproducibility. + + Returns: + RocAucResult: Named tuple containing the lower and upper bounds of the + confidence interval, along with the point estimate. + """ + if y_true.shape != y_pred.shape: + raise ValueError( + f"y_true and y_pred should have the same shape; " + f"got {y_true.shape} and {y_pred.shape}" + ) + if y_true.dim() != 1: + raise ValueError("y_true and y_pred should be 1D tensors") + + device = y_true.device + N = y_true.shape[0] + + # Generate random indices for bootstrap samples (shape: [num_bootstraps, N]) + rng = torch.Generator(device=device).manual_seed(seed) + indices = torch.randint(0, N, (num_samples, N), device=device, generator=rng) + + # Create bootstrap samples of true labels and predicted probabilities + y_true_bootstraps = y_true[indices] + y_pred_bootstraps = y_pred[indices] + + # Compute ROC AUC scores for bootstrap samples + bootstrap_aucs = roc_auc(y_true_bootstraps, y_pred_bootstraps) + + # Calculate the lower and upper bounds of the confidence interval. We use + # nanquantile instead of quantile because some bootstrap samples may have + # NaN values due to the fact that they have only one class. 
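+    # For example, level = 0.95 gives alpha = 0.025, so the bounds below are the
+    # 2.5th and 97.5th percentiles of the bootstrap AUROC distribution.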
+ alpha = (1 - level) / 2 + q = y_pred.new_tensor([alpha, 1 - alpha]) + lower, upper = bootstrap_aucs.nanquantile(q).tolist() + + # Compute the point estimate + estimate = roc_auc(y_true, y_pred).item() + return RocAucResult(estimate, lower, upper) diff --git a/elk/run.py b/elk/run.py index 70af211b..bc889baa 100644 --- a/elk/run.py +++ b/elk/run.py @@ -86,7 +86,7 @@ def get_device(self, devices, world_size: int) -> str: def prepare_data( self, device: str, layer: int, split_type: Literal["train", "val"] - ) -> dict[str, tuple[Tensor, Tensor, np.ndarray | None]]: + ) -> dict[str, tuple[Tensor, Tensor, Tensor | None]]: """Prepare data for the specified layer and split type.""" out = {} @@ -98,7 +98,7 @@ def prepare_data( labels = assert_type(Tensor, split["label"]) val_h = int16_to_float32(assert_type(Tensor, split[f"hidden_{layer}"])) - with split.formatted_as("numpy"): + with split.formatted_as("torch", device=device): has_preds = "model_preds" in split.features lm_preds = split["model_preds"] if has_preds else None diff --git a/elk/training/ccs_reporter.py b/elk/training/ccs_reporter.py index c4436aec..e655e195 100644 --- a/elk/training/ccs_reporter.py +++ b/elk/training/ccs_reporter.py @@ -7,10 +7,10 @@ import torch import torch.nn as nn -from sklearn.metrics import roc_auc_score from torch import Tensor from torch.nn.functional import binary_cross_entropy as bce +from ..metrics import roc_auc from ..parsing import parse_loss from ..utils.typing import assert_type from .classifier import Classifier @@ -176,7 +176,7 @@ def check_separability( # b v d -> (b v) d torch.cat([val_x0, val_x1]).flatten(0, 1) ) - return float(roc_auc_score(pseudo_val_labels.cpu(), pseudo_preds.cpu())) + return roc_auc(pseudo_val_labels, pseudo_preds).item() def unsupervised_loss(self, logit0: Tensor, logit1: Tensor) -> Tensor: loss = sum( diff --git a/elk/training/reporter.py b/elk/training/reporter.py index 0b47b90b..39a1deda 100644 --- a/elk/training/reporter.py +++ b/elk/training/reporter.py @@ -9,11 +9,10 @@ import torch.nn as nn from einops import rearrange, repeat from simple_parsing.helpers import Serializable -from sklearn.metrics import roc_auc_score from torch import Tensor from ..calibration import CalibrationError -from ..metrics import accuracy, to_one_hot +from ..metrics import accuracy, roc_auc_ci, to_one_hot class EvalResult(NamedTuple): @@ -23,9 +22,12 @@ class EvalResult(NamedTuple): which contains the loss, accuracy, calibrated accuracy, and AUROC. 
""" + auroc: float + auroc_lower: float + auroc_upper: float + acc: float cal_acc: float - auroc: float ece: float @@ -117,12 +119,13 @@ def score(self, labels: Tensor, hiddens: Tensor) -> EvalResult: raw_preds = to_one_hot(logits.argmax(dim=-1), c).long() Y = to_one_hot(Y, c).long().flatten() - auroc = roc_auc_score(Y.cpu(), logits.cpu().flatten()) raw_acc = accuracy(Y, raw_preds.flatten()) - + auroc_result = roc_auc_ci(Y, logits.flatten()) return EvalResult( + auroc=auroc_result.estimate, + auroc_lower=auroc_result.lower, + auroc_upper=auroc_result.upper, acc=float(raw_acc), cal_acc=cal_acc, - auroc=float(auroc), ece=cal_err, ) diff --git a/elk/training/supervised.py b/elk/training/supervised.py index c7d2ca6e..300d6f0e 100644 --- a/elk/training/supervised.py +++ b/elk/training/supervised.py @@ -1,29 +1,28 @@ import torch from einops import rearrange, repeat -from sklearn.metrics import roc_auc_score from torch import Tensor -from ..metrics import accuracy, to_one_hot +from ..metrics import RocAucResult, accuracy, roc_auc_ci, to_one_hot from ..utils import assert_type from .classifier import Classifier def evaluate_supervised( lr_model: Classifier, val_h: Tensor, val_labels: Tensor -) -> tuple[float, float]: - (n, v, k, d) = val_h.shape +) -> tuple[RocAucResult, float]: + (_, v, k, _) = val_h.shape with torch.no_grad(): - logits = rearrange(lr_model(val_h).cpu().squeeze(), "n v k -> (n v) k") + logits = rearrange(lr_model(val_h).squeeze(), "n v k -> (n v) k") raw_preds = to_one_hot(logits.argmax(dim=-1), k).long() labels = repeat(val_labels, "n -> (n v)", v=v) labels = to_one_hot(labels, k).flatten() lr_acc = accuracy(labels, raw_preds.flatten()) - lr_auroc = roc_auc_score(labels.cpu(), logits.cpu().flatten()) + lr_auroc = roc_auc_ci(labels, logits.flatten()) - return assert_type(float, lr_auroc), assert_type(float, lr_acc) + return lr_auroc, assert_type(float, lr_acc) def train_supervised(data: dict[str, tuple], device: str, cv: bool) -> Classifier: diff --git a/elk/training/train.py b/elk/training/train.py index 510f3cc4..cd9dcb7c 100644 --- a/elk/training/train.py +++ b/elk/training/train.py @@ -9,10 +9,9 @@ import torch from einops import rearrange, repeat from simple_parsing import Serializable, field, subgroups -from sklearn.metrics import roc_auc_score from ..extraction.extraction import Extract -from ..metrics import accuracy, to_one_hot +from ..metrics import accuracy, roc_auc_ci, to_one_hot from ..run import Run from ..training.supervised import evaluate_supervised, train_supervised from ..utils import select_usable_devices @@ -148,21 +147,6 @@ def train_reporter( row_buf = [] for ds_name, (val_h, val_gt, val_lm_preds) in val_dict.items(): val_result = reporter.score(val_gt, val_h) - - if val_lm_preds is not None: - (_, v, k, _) = val_h.shape - - val_gt_cpu = repeat(val_gt, "n -> (n v)", v=v).cpu() - val_lm_preds = rearrange(val_lm_preds, "n v ... -> (n v) ...") - val_lm_auroc = roc_auc_score( - to_one_hot(val_gt_cpu, k).long().flatten(), val_lm_preds.flatten() - ) - - val_lm_acc = accuracy(val_gt_cpu, torch.from_numpy(val_lm_preds)) - else: - val_lm_auroc = None - val_lm_acc = None - row = pd.Series( { "dataset": ds_name, @@ -170,15 +154,29 @@ def train_reporter( "pseudo_auroc": pseudo_auroc, "train_loss": train_loss, **val_result._asdict(), - "lm_auroc": val_lm_auroc, - "lm_acc": val_lm_acc, } ) + if val_lm_preds is not None: + (_, v, k, _) = val_h.shape + + val_gt_rep = repeat(val_gt, "n -> (n v)", v=v) + val_lm_preds = rearrange(val_lm_preds, "n v ... 
-> (n v) ...") + val_lm_auroc_res = roc_auc_ci( + to_one_hot(val_gt_rep, k).long().flatten(), val_lm_preds.flatten() + ) + row["lm_auroc"] = val_lm_auroc_res.estimate + row["lm_auroc_lower"] = val_lm_auroc_res.lower + row["lm_auroc_upper"] = val_lm_auroc_res.upper + row["lm_acc"] = accuracy(val_gt_rep, val_lm_preds) + if lr_model is not None: - row["lr_auroc"], row["lr_acc"] = evaluate_supervised( + lr_auroc_res, row["lr_acc"] = evaluate_supervised( lr_model, val_h, val_gt ) + row["lr_auroc"] = lr_auroc_res.estimate + row["lr_auroc_lower"] = lr_auroc_res.lower + row["lr_auroc_upper"] = lr_auroc_res.upper row_buf.append(row) diff --git a/pyproject.toml b/pyproject.toml index f688416d..6575e57a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,8 +21,6 @@ dependencies = [ "pandas", # Basically any version should work as long as it supports the user's CUDA version "pynvml", - # Doesn't really matter but before 1.0.0 there might be weird breaking changes - "scikit-learn>=1.0.0", # Needed for certain HF tokenizers "sentencepiece==0.1.97", # We upstreamed bugfixes for Literal types in 0.1.1 @@ -43,7 +41,8 @@ dev = [ "hypothesis", "pre-commit", "pytest", - "pyright" + "pyright", + "scikit-learn", ] [project.scripts] diff --git a/tests/test_roc_auc.py b/tests/test_roc_auc.py new file mode 100644 index 00000000..244bdb88 --- /dev/null +++ b/tests/test_roc_auc.py @@ -0,0 +1,53 @@ +import numpy as np +import torch +from sklearn.datasets import make_classification +from sklearn.linear_model import LogisticRegression +from sklearn.metrics import roc_auc_score + +from elk.metrics import roc_auc + + +def test_roc_auc_score(): + # Generate 1D binary classification dataset + X_1d, y_true_1d = make_classification(n_samples=1000, random_state=42) + + # Generate 2D matrix of binary classification datasets + X_2d_1, y_true_2d_1 = make_classification(n_samples=1000, random_state=43) + X_2d_2, y_true_2d_2 = make_classification(n_samples=1000, random_state=44) + + # Fit LR models and get predicted probabilities for 1D and 2D cases + lr_1d = LogisticRegression(random_state=42).fit(X_1d, y_true_1d) + y_scores_1d = lr_1d.predict_proba(X_1d)[:, 1] + + lr_2d_1 = LogisticRegression(random_state=42).fit(X_2d_1, y_true_2d_1) + y_scores_2d_1 = lr_2d_1.predict_proba(X_2d_1)[:, 1] + + lr_2d_2 = LogisticRegression(random_state=42).fit(X_2d_2, y_true_2d_2) + y_scores_2d_2 = lr_2d_2.predict_proba(X_2d_2)[:, 1] + + # Stack the datasets into 2D matrices + y_true_2d = np.vstack((y_true_2d_1, y_true_2d_2)) + y_scores_2d = np.vstack((y_scores_2d_1, y_scores_2d_2)) + + # Convert to PyTorch tensors + y_true_1d_torch = torch.tensor(y_true_1d) + y_scores_1d_torch = torch.tensor(y_scores_1d) + y_true_2d_torch = torch.tensor(y_true_2d) + y_scores_2d_torch = torch.tensor(y_scores_2d) + + # Calculate ROC AUC score using batch_roc_auc_score function for 1D and 2D cases + roc_auc_1d_torch = roc_auc(y_true_1d_torch, y_scores_1d_torch).item() + roc_auc_2d_torch = roc_auc(y_true_2d_torch, y_scores_2d_torch).numpy() + + # Calculate ROC AUC score with sklearn's roc_auc_score function for 1D and 2D cases + roc_auc_1d_sklearn = roc_auc_score(y_true_1d, y_scores_1d) + roc_auc_2d_sklearn = np.array( + [ + roc_auc_score(y_true_2d_1, y_scores_2d_1), + roc_auc_score(y_true_2d_2, y_scores_2d_2), + ] + ) + + # Assert that the results from the two implementations are almost equal + np.testing.assert_almost_equal(roc_auc_1d_torch, roc_auc_1d_sklearn) + np.testing.assert_almost_equal(roc_auc_2d_torch, roc_auc_2d_sklearn) From 
a7f1ea044d5b16c5fbd9f68575707c7f0f524de5 Mon Sep 17 00:00:00 2001 From: Nora Belrose Date: Sat, 15 Apr 2023 22:15:00 +0000 Subject: [PATCH 41/43] Fix CCS smoke test failure --- elk/training/ccs_reporter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/elk/training/ccs_reporter.py b/elk/training/ccs_reporter.py index e655e195..b839e3fd 100644 --- a/elk/training/ccs_reporter.py +++ b/elk/training/ccs_reporter.py @@ -175,7 +175,7 @@ def check_separability( pseudo_preds = pseudo_clf( # b v d -> (b v) d torch.cat([val_x0, val_x1]).flatten(0, 1) - ) + ).squeeze(-1) return roc_auc(pseudo_val_labels, pseudo_preds).item() def unsupervised_loss(self, logit0: Tensor, logit1: Tensor) -> Tensor: From 3abeb6023f7ed187173234bdc82d8dc851ab65e5 Mon Sep 17 00:00:00 2001 From: Walter Laurito Date: Sun, 16 Apr 2023 10:56:58 -0700 Subject: [PATCH 42/43] Update extraction.py remove typo --- elk/extraction/extraction.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/elk/extraction/extraction.py b/elk/extraction/extraction.py index 17bacb87..44c72ccd 100644 --- a/elk/extraction/extraction.py +++ b/elk/extraction/extraction.py @@ -113,7 +113,7 @@ def extract_hiddens( stream=p_cfg.stream, rank=rank, world_size=world_size, - ) # this dataset is already sharded, buqt hasn't been truncated to max_examples + ) # this dataset is already sharded, bug hasn't been truncated to max_examples model = instantiate_model( cfg.model, torch_dtype="auto" if device != "cpu" else torch.float32 From 4c60061ca5a1eaa6d1d0e553e5bdffc99276203b Mon Sep 17 00:00:00 2001 From: Walter Laurito Date: Sun, 16 Apr 2023 14:34:01 -0700 Subject: [PATCH 43/43] Update extraction.py fix typo --- elk/extraction/extraction.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/elk/extraction/extraction.py b/elk/extraction/extraction.py index 44c72ccd..2f9bda09 100644 --- a/elk/extraction/extraction.py +++ b/elk/extraction/extraction.py @@ -113,7 +113,7 @@ def extract_hiddens( stream=p_cfg.stream, rank=rank, world_size=world_size, - ) # this dataset is already sharded, bug hasn't been truncated to max_examples + ) # this dataset is already sharded, but hasn't been truncated to max_examples model = instantiate_model( cfg.model, torch_dtype="auto" if device != "cpu" else torch.float32