diff --git a/elk/__init__.py b/elk/__init__.py index ce69da0dc..f6b132016 100644 --- a/elk/__init__.py +++ b/elk/__init__.py @@ -1,11 +1,9 @@ -from .extraction import Extract, extract_hiddens -from .training import EigenFitter, EigenFitterConfig -from .truncated_eigh import truncated_eigh +from .evaluation import Eval +from .extraction import Extract +from .training.train import Elicit __all__ = [ - "EigenFitter", - "EigenFitterConfig", - "extract_hiddens", "Extract", - "truncated_eigh", + "Elicit", + "Eval", ] diff --git a/elk/debug_logging.py b/elk/debug_logging.py index 59bea62fd..2592e129b 100644 --- a/elk/debug_logging.py +++ b/elk/debug_logging.py @@ -31,26 +31,24 @@ def save_debug_log(datasets: list[DatasetDictWithName], out_dir: Path) -> None: else: train_split, val_split = select_train_val_splits(ds) - text_questions = ds[val_split][0]["text_questions"] + if len(ds[val_split]) == 0: + logging.warning(f"Val split '{val_split}' is empty!") + continue + + texts = ds[val_split][0]["texts"] template_ids = ds[val_split][0]["variant_ids"] - label = ds[val_split][0]["label"] + ds[val_split][0]["label"] # log the train size and val size if train_split is not None: logging.info(f"Train size: {len(ds[train_split])}") logging.info(f"Val size: {len(ds[val_split])}") - templates_text = f"{len(text_questions)} templates used:\n" + templates_text = f"{len(texts)} templates used:\n" trailing_whitespace = False - for (text0, text1), id in zip(text_questions, template_ids): - templates_text += ( - f'***---TEMPLATE "{id}"---***\n' - f"{'false' if label else 'true'}:\n" - f'"""{text0}"""\n' - f"{'true' if label else 'false'}:\n" - f'"""{text1}"""\n\n\n' - ) - if text0[-1].isspace() or text1[-1].isspace(): + for text, id in zip(texts, template_ids): + templates_text += f'***---TEMPLATE "{id}"---***\n' f'"""{text}"""\n' + if text[-1].isspace(): trailing_whitespace = True if trailing_whitespace: logging.warning( diff --git a/elk/evaluation/evaluate.py b/elk/evaluation/evaluate.py index 4e6ec59f3..cfc422d4d 100644 --- a/elk/evaluation/evaluate.py +++ b/elk/evaluation/evaluate.py @@ -7,7 +7,7 @@ from simple_parsing.helpers import field from ..files import elk_reporter_dir -from ..metrics import evaluate_preds +from ..metrics import evaluate_preds, get_logprobs from ..run import Run from ..utils import Color @@ -17,7 +17,6 @@ class Eval(Run): """Full specification of a reporter evaluation run.""" source: Path = field(positional=True) - skip_supervised: bool = False def __post_init__(self): # Set our output directory before super().execute() does @@ -31,55 +30,68 @@ def execute(self, highlight_color: Color = "cyan"): @torch.inference_mode() def apply_to_layer( self, layer: int, devices: list[str], world_size: int - ) -> dict[str, pd.DataFrame]: + ) -> tuple[dict[str, pd.DataFrame], dict]: """Evaluate a single reporter on a single layer.""" device = self.get_device(devices, world_size) val_output = self.prepare_data(device, layer, "val") experiment_dir = elk_reporter_dir() / self.source - reporter_path = experiment_dir / "reporters" / f"layer_{layer}.pt" - reporter = torch.load(reporter_path, map_location=device) + lr_dir = experiment_dir / "lr_models" + with open(lr_dir / f"layer_{layer}.pt", "rb") as f: + lr_models = torch.load(f, map_location=device) + if not isinstance(lr_models, list): # backward compatibility + lr_models = [lr_models] + out_logprobs = defaultdict(dict) row_bufs = defaultdict(list) - for ds_name, (val_h, val_gt, val_lm_preds) in val_output.items(): + for ds_name, val_data in 
val_output.items(): meta = {"dataset": ds_name, "layer": layer} - val_credences = reporter(val_h) - for mode in ("none", "partial", "full"): - row_bufs["eval"].append( - { - **meta, - "ensembling": mode, - **evaluate_preds(val_gt, val_credences, mode).to_dict(), - } + if self.save_logprobs: + out_logprobs[ds_name] = dict( + row_ids=val_data.row_ids.cpu(), + variant_ids=val_data.variant_ids, + texts=val_data.texts, + labels=val_data.labels.cpu(), + lm=dict(), + lr=dict(), ) - - if val_lm_preds is not None: + for mode in ("none", "full"): + if val_data.lm_log_odds is not None: + if self.save_logprobs: + out_logprobs[ds_name]["lm"][mode] = get_logprobs( + val_data.lm_log_odds, mode + ).cpu() row_bufs["lm_eval"].append( { - **meta, "ensembling": mode, - **evaluate_preds(val_gt, val_lm_preds, mode).to_dict(), + **meta, + **evaluate_preds( + val_data.labels, val_data.lm_log_odds, mode + ).to_dict(), } ) - lr_dir = experiment_dir / "lr_models" - if not self.skip_supervised and lr_dir.exists(): - with open(lr_dir / f"layer_{layer}.pt", "rb") as f: - lr_models = torch.load(f, map_location=device) - if not isinstance(lr_models, list): # backward compatibility - lr_models = [lr_models] - - for i, model in enumerate(lr_models): - model.eval() - row_bufs["lr_eval"].append( - { - "ensembling": mode, - "inlp_iter": i, - **meta, - **evaluate_preds(val_gt, model(val_h), mode).to_dict(), - } - ) + if self.save_logprobs: + out_logprobs[ds_name]["lr"][mode] = dict() + + for i, model in enumerate(lr_models): + model.eval() + val_log_odds = model(val_data.hiddens) + if self.save_logprobs: + out_logprobs[ds_name]["lr"][mode][i] = get_logprobs( + val_log_odds, mode + ).cpu() + row_bufs["lr_eval"].append( + { + "ensembling": mode, + "inlp_iter": i, + **meta, + **evaluate_preds( + val_data.labels, val_log_odds, mode + ).to_dict(), + } + ) - return {k: pd.DataFrame(v) for k, v in row_bufs.items()} + return {k: pd.DataFrame(v) for k, v in row_bufs.items()}, out_logprobs diff --git a/elk/extraction/__init__.py b/elk/extraction/__init__.py index 3fab2e320..e48c77816 100644 --- a/elk/extraction/__init__.py +++ b/elk/extraction/__init__.py @@ -1,15 +1,18 @@ from .balanced_sampler import BalancedSampler, FewShotSampler -from .extraction import Extract, extract, extract_hiddens +from .extraction import Extract, extract, tokenize_dataset from .generator import _GeneratorBuilder, _GeneratorConfig -from .prompt_loading import load_prompts +from .inference_server import InferenceServer +from .prompt_loading import get_prompter, load_prompts __all__ = [ "BalancedSampler", "FewShotSampler", "Extract", - "extract_hiddens", + "InferenceServer", "extract", "_GeneratorConfig", "_GeneratorBuilder", "load_prompts", + "get_prompter", + "tokenize_dataset", ] diff --git a/elk/extraction/extraction.py b/elk/extraction/extraction.py index 2975823d8..f4e83fa1b 100644 --- a/elk/extraction/extraction.py +++ b/elk/extraction/extraction.py @@ -1,16 +1,14 @@ """Functions for extracting the hidden states of a model.""" -import logging import os -from contextlib import nullcontext, redirect_stdout +from collections import defaultdict from dataclasses import InitVar, dataclass, replace from itertools import zip_longest from typing import Any, Iterable, Literal -from warnings import filterwarnings import torch from datasets import ( Array2D, - Array3D, + Dataset, DatasetDict, DatasetInfo, DownloadMode, @@ -23,30 +21,26 @@ ) from simple_parsing import Serializable, field from torch import Tensor -from transformers import AutoConfig, PreTrainedModel 
+from transformers import AutoConfig -from ..promptsource import DatasetTemplates from ..utils import ( Color, assert_type, colorize, float_to_int16, - infer_label_column, - infer_num_classes, - instantiate_model, instantiate_tokenizer, - is_autoregressive, prevent_name_conflicts, select_split, select_train_val_splits, - select_usable_devices, ) +from ..utils.hf_utils import is_autoregressive from .dataset_name import ( DatasetDictWithName, parse_dataset_string, ) from .generator import _GeneratorBuilder -from .prompt_loading import load_prompts +from .inference_server import InferenceServer +from .prompt_loading import get_prompter, load_prompts @dataclass @@ -62,12 +56,15 @@ class Extract(Serializable): data_dirs: tuple[str, ...] = () """Directory to use for caching the hiddens. Defaults to `HF_DATASETS_CACHE`.""" - binarize: bool = False - """Whether to binarize the dataset labels for multi-class datasets.""" + get_lm_preds: bool = True + """Whether to extract the LM predictions.""" int8: bool = False """Whether to perform inference in mixed int8 precision with `bitsandbytes`.""" + fsdp: bool = False + """Whether to use FullyShardedDataParallel for inference.""" + max_examples: tuple[int, int] = (1000, 1000) """Maximum number of examples to use from each split of the dataset.""" @@ -78,6 +75,13 @@ class Extract(Serializable): """The number of prompt templates to use for each example. If -1, all available templates are used.""" + balance: bool = True + """Whether to balance the number of examples per class.""" + + statement_column: str | None = None + """Name of the column containing the model input strings when using a built-in + prompt template. If None, we use the "statement" column.""" + layers: tuple[int, ...] = () """Indices of layers to extract hidden states from. We follow the HF convention, so 0 is the embedding, and 1 is the output of the first transformer layer.""" @@ -91,13 +95,9 @@ class Extract(Serializable): template_path: str | None = None """Path to pass into `DatasetTemplates`. By default we use the dataset name.""" - token_loc: Literal["first", "last", "mean"] = "last" + token_loc: Literal["first", "last", "penultimate", "mean"] = "last" """The location of the token to extract hidden states from.""" - use_encoder_states: bool = False - """Whether to extract hidden states from the encoder instead of the decoder in the - case of encoder-decoder models.""" - def __post_init__(self, layer_stride: int): if self.num_variants != -1: print("WARNING: num_variants is deprecated; use prompt_indices instead.") @@ -146,200 +146,114 @@ def explode(self) -> list["Extract"]: ] -@torch.inference_mode() -def extract_hiddens( +def tokenize_dataset( cfg: "Extract", - *, - device: str | torch.device = "cpu", split_type: Literal["train", "val"] = "train", - rank: int = 0, - world_size: int = 1, -) -> Iterable[dict]: - """Run inference on a model with a set of prompts, yielding the hidden states.""" +) -> Dataset: + """Apply the prompt templates to the dataset and return the tokenized LM inputs. + Each dict contains the keys `input_ids`, `variant_id`, + `row_id`, `text`, and `label`. If lm_preds is True, we also include `answer_ids` + and `num_suffix_tokens`. + """ os.environ["TOKENIZERS_PARALLELISM"] = "false" - # Silence datasets logging messages from all but the first process - if rank != 0: - filterwarnings("ignore") - logging.disable(logging.CRITICAL) - ds_names = cfg.datasets assert len(ds_names) == 1, "Can only extract hiddens from one dataset at a time." 
- # We use contextlib.redirect_stdout to prevent `bitsandbytes` from printing its - # welcome message on every rank - with redirect_stdout(None) if rank != 0 else nullcontext(): - model = instantiate_model(cfg.model, device=device, load_in_8bit=cfg.int8) - tokenizer = instantiate_tokenizer( - cfg.model, truncation_side="left", verbose=rank == 0 - ) - - is_enc_dec = model.config.is_encoder_decoder - if is_enc_dec and cfg.use_encoder_states: - assert hasattr(model, "get_encoder") and callable(model.get_encoder) - model = assert_type(PreTrainedModel, model.get_encoder()) - is_enc_dec = False + tokenizer = instantiate_tokenizer(cfg.model, truncation_side="left") - has_lm_preds = is_autoregressive(model.config, not cfg.use_encoder_states) - if has_lm_preds and rank == 0: - print("Model has language model head, will store predictions.") + # TODO: support using the encoder only of an encoder-decoder model prompt_ds = load_prompts( ds_names[0], - binarize=cfg.binarize, num_shots=cfg.num_shots, split_type=split_type, template_path=cfg.template_path, - rank=rank, - world_size=world_size, + include_answers=cfg.get_lm_preds, + balance=cfg.balance, seed=cfg.seed, + statement_column=cfg.statement_column, ) - # Add one to the number of layers to account for the embedding layer - layer_indices = cfg.layers or tuple(range(model.config.num_hidden_layers + 1)) - - global_max_examples = cfg.max_examples[0 if split_type == "train" else 1] + max_examples = cfg.max_examples[0 if split_type == "train" else 1] - # break `max_examples` among the processes roughly equally - max_examples = global_max_examples // world_size max_length = assert_type(int, tokenizer.model_max_length) - # Keep track of the number of examples we've yielded so far. We can't do something - # clean like `islice` the dataset, because we skip examples that are too long, and - # we can't predict how many of those there will be. - num_yielded = 0 - - # the last process gets the remainder (which is usually small) - if rank == world_size - 1: - max_examples += global_max_examples % world_size - + out_records = [] for example in prompt_ds: + num_variants = len(example["template_names"]) + # Check if we've yielded enough examples - if num_yielded >= max_examples: + if len(out_records) >= max_examples * num_variants: break - num_variants = len(example["prompts"]) - num_choices = len(example["prompts"][0]) + # Throw out all variants if any of them are too long + any_too_long = False + record_variants = [] - hidden_dict = { - f"hidden_{layer_idx}": torch.empty( - num_variants, - num_choices, - model.config.hidden_size, - device=device, - dtype=torch.int16, + # Iterate over variants + for i, statement in enumerate(example["statements"]): + if cfg.get_lm_preds: + suffix = example["suffixes"][i] + answer_choices = example["answer_choices"][i] + assert len(answer_choices) == 2 + answer_ids = [] + for choice in answer_choices: + a_id = tokenizer.encode(choice, add_special_tokens=False) + if len(a_id) > 1: + print( + f"WARNING: answer choice '{choice}' is more than one " + "token, LM probabilities will be calculated using the " + "first token only." 
+ ) + answer_ids.append(a_id[0]) + else: + suffix = "" + + suffix_tokens = torch.tensor( + tokenizer.encode(suffix, add_special_tokens=False), + dtype=torch.long, ) - for layer_idx in layer_indices - } - lm_logits = torch.empty( - num_variants, - num_choices, - device=device, - dtype=torch.float32, - ) - text_questions = [] - # Iterate over variants - for i, record in enumerate(example["prompts"]): - variant_questions = [] - - # Iterate over answers - for j, choice in enumerate(record): - text = choice["question"] - - # Only feed question, not the answer, to the encoder for enc-dec models - target = choice["answer"] if is_enc_dec else None - encoding = tokenizer( - text, - # Keep [CLS] and [SEP] for BERT-style models - add_special_tokens=True, - return_tensors="pt", - text_target=target, # type: ignore[arg-type] - ).to(device) - - ids = assert_type(Tensor, encoding.input_ids) - if is_enc_dec: - answer = labels = assert_type(Tensor, encoding.labels) - else: - encoding2 = tokenizer( - choice["answer"], - # Don't include [CLS] and [SEP] in the answer - add_special_tokens=False, - return_tensors="pt", - ).to(device) - - answer = assert_type(Tensor, encoding2.input_ids) - labels = ( - # -100 is the mask token - torch.cat([torch.full_like(ids, -100), answer], dim=-1) - if has_lm_preds - else None - ) - ids = torch.cat([ids, answer], -1) - - # If this input is too long, skip it - if ids.shape[-1] > max_length: - break - else: - # Record the EXACT question we fed to the model - variant_questions.append(text) - - inputs: dict[str, Tensor | None] = dict(input_ids=ids.long()) - if is_enc_dec or has_lm_preds: - inputs["labels"] = labels - outputs = model(**inputs, output_hidden_states=True) - - # Compute the log probability of the answer tokens if available - if has_lm_preds: - lm_logits[i, j] = -assert_type(Tensor, outputs.loss) - - hiddens = ( - outputs.get("decoder_hidden_states") or outputs["hidden_states"] - ) - # Throw out layers we don't care about - hiddens = [hiddens[i] for i in layer_indices] - - # Current shape of each element: (batch_size, seq_len, hidden_size) - if cfg.token_loc == "first": - hiddens = [h[..., 0, :] for h in hiddens] - elif cfg.token_loc == "last": - hiddens = [h[..., -1, :] for h in hiddens] - elif cfg.token_loc == "mean": - hiddens = [h.mean(dim=-2) for h in hiddens] - else: - raise ValueError(f"Invalid token_loc: {cfg.token_loc}") - - for layer_idx, hidden in zip(layer_indices, hiddens): - hidden_dict[f"hidden_{layer_idx}"][i, j] = float_to_int16(hidden) - - # We skipped a pseudolabel because it was too long; break out of this whole - # example and move on to the next one - if len(variant_questions) != num_choices: - break + encoding = tokenizer( + statement, + # Keep [CLS] and [SEP] for BERT-style models + add_special_tokens=True, + return_tensors="pt", + ) - # Usual case: we have the expected number of pseudolabels - text_questions.append(variant_questions) + # suffix comes right after the last statement token, before the answer + ids = torch.cat([encoding.input_ids, suffix_tokens.unsqueeze(0)], dim=-1) - # We skipped a variant because it was too long; move on to the next example - if len(text_questions) != num_variants: - continue + # If this input is too long, skip it + if ids.shape[-1] > max_length: + any_too_long = True + break - out_record: dict[str, Any] = dict( - label=example["label"], - variant_ids=example["template_names"], - text_questions=text_questions, - **hidden_dict, - ) - if has_lm_preds: - out_record["model_logits"] = lm_logits + out_record: 
dict[str, Any] = dict( + row_id=example["row_id"], + variant_id=example["template_names"][i], + label=example["label"], + text=statement + suffix, + input_ids=ids.long(), + ) + if cfg.get_lm_preds: + out_record["answer_ids"] = answer_ids # type: ignore + # keep track of where to extract hiddens from + out_record["num_suffix_tokens"] = len(suffix_tokens) + record_variants.append(out_record) - num_yielded += 1 - yield out_record + if any_too_long: + continue + # print an example text to stdout + if len(out_records) == 0: + print(f"Example text: {record_variants[0]['text']}") + out_records.extend(record_variants) -# Dataset.from_generator wraps all the arguments in lists, so we unpack them here -def _extraction_worker(**kwargs): - yield from extract_hiddens(**{k: v[0] for k, v in kwargs.items()}) + # transpose the list of dicts into a dict of lists + out_records = {k: [d[k] for d in out_records] for k in out_records[0]} + return Dataset.from_dict(out_records) def hidden_features(cfg: Extract) -> tuple[DatasetInfo, Features]: @@ -350,51 +264,39 @@ def hidden_features(cfg: Extract) -> tuple[DatasetInfo, Features]: ds_name, config_name = parse_dataset_string(dataset_config_str=cfg.datasets[0]) info = get_dataset_config_info(ds_name, config_name or None) - if not cfg.template_path: - prompter = DatasetTemplates(ds_name, config_name) - else: - prompter = DatasetTemplates(cfg.template_path) - - ds_features = assert_type(Features, info.features) - label_col = prompter.label_column or infer_label_column(ds_features) - num_classes = ( - 2 - if cfg.binarize or prompter.binarize - else infer_num_classes(ds_features[label_col]) - ) + assert_type(Features, info.features) - num_dropped = prompter.drop_non_mc_templates() + prompter, _ = get_prompter(ds_name, config_name, cfg.template_path) + + # num_dropped = prompter.drop_non_mc_templates() num_variants = len(prompter.templates) - if num_dropped: - print(f"Dropping {num_dropped} non-multiple choice templates") + # if num_dropped: + # print(f"Dropping {num_dropped} non-multiple choice templates") layer_cols = { - f"hidden_{layer}": Array3D( + f"hidden_{layer}": Array2D( dtype="int16", - shape=(num_variants, num_classes, model_cfg.hidden_size), + shape=(num_variants, model_cfg.hidden_size), ) # Add 1 to include the embedding layer for layer in cfg.layers or range(model_cfg.num_hidden_layers + 1) } other_cols = { + "row_id": Value(dtype="int64"), "variant_ids": Sequence( Value(dtype="string"), length=num_variants, ), "label": Value(dtype="int64"), - "text_questions": Sequence( - Sequence( - Value(dtype="string"), - ), + "texts": Sequence( + Value(dtype="string"), length=num_variants, ), } - - # Only add model_logits if the model is an autoregressive model - if is_autoregressive(model_cfg, not cfg.use_encoder_states): - other_cols["model_logits"] = Array2D( - shape=(num_variants, num_classes), - dtype="float32", + if cfg.get_lm_preds: + other_cols["lm_log_odds"] = Sequence( + Value(dtype="float32"), + length=num_variants, ) return info, Features({**layer_cols, **other_cols}) @@ -406,13 +308,15 @@ def extract( disable_cache: bool = False, highlight_color: Color = "cyan", num_gpus: int = -1, - min_gpu_mem: int | None = None, split_type: Literal["train", "val", None] = None, ) -> DatasetDictWithName: """Extract hidden states from a model and return a `DatasetDict` containing them.""" + info, features = hidden_features(cfg) - devices = select_usable_devices(num_gpus, min_memory=min_gpu_mem) + model_config = AutoConfig.from_pretrained(cfg.model) + if not 
is_autoregressive(model_config, include_enc_dec=True) and cfg.get_lm_preds:
+        raise ValueError("Can only extract LM predictions from autoregressive models.")
     limits = cfg.max_examples
     splits = assert_type(SplitDict, info.splits)
@@ -435,6 +339,94 @@ def extract(
     else:
         print(f"{pretty_name} using '{split_name}' for validation")
+    def select_hiddens(outputs: Any, **kwargs: Any) -> tuple[dict[str, Tensor], Tensor]:
+        tok_loc_offset = kwargs.get("num_suffix_tokens", 0)
+        # Add one to the number of layers to account for the embedding layer
+        layer_indices = cfg.layers or tuple(range(model_config.num_hidden_layers + 1))
+
+        hiddens = outputs.get("decoder_hidden_states") or outputs["hidden_states"]
+        # Throw out layers we don't care about
+        hiddens = [hiddens[i] for i in layer_indices]
+
+        # Current shape of each element: (batch_size, seq_len, hidden_size)
+        if cfg.token_loc == "first":
+            hiddens = [h[..., 0, :] for h in hiddens]
+        elif cfg.token_loc == "last":
+            hiddens = [h[..., h.shape[-2] - tok_loc_offset - 1, :] for h in hiddens]
+        elif cfg.token_loc == "penultimate":
+            hiddens = [h[..., h.shape[-2] - tok_loc_offset - 2, :] for h in hiddens]
+        elif cfg.token_loc == "mean":
+            # Use an explicit end index so that an empty suffix (offset 0) still
+            # averages over the whole sequence instead of an empty slice
+            hiddens = [
+                h[..., : h.shape[-2] - tok_loc_offset, :].mean(dim=-2) for h in hiddens
+            ]
+        else:
+            raise ValueError(f"Invalid token_loc: {cfg.token_loc}")
+
+        hidden_dict = dict()
+        for layer_idx, hidden in zip(layer_indices, hiddens):
+            hidden_dict[f"hidden_{layer_idx}"] = float_to_int16(hidden.flatten()).cpu()
+
+        if (answer_ids := kwargs.get("answer_ids")) is not None:
+            # log_odds = log(p(yes)/p(no)) = log(p(yes)) - log(p(no))
+            logits = outputs["logits"][0, -1, answer_ids]
+            logprobs = logits.log_softmax(dim=-1)
+            lm_log_odds = logprobs[1] - logprobs[0]
+        else:
+            lm_log_odds = torch.Tensor([torch.nan])
+
+        return hidden_dict, lm_log_odds
+
+    def extract_hiddens(
+        cfg: Extract,
+        split_type: Literal["train", "val"],
+        server: InferenceServer,
+    ) -> Iterable[dict]:
+        encodings = tokenize_dataset(cfg, split_type=split_type)
+        num_variants = len(encodings.unique("variant_id"))
+
+        if not server.running:
+            server.start()
+        encodings = encodings.add_column("id", range(len(encodings)))  # type: ignore
+
+        buffer = defaultdict(list)  # row_id -> list of dicts
+        for idx, (hidden_dict, lm_log_odds) in server.imap(
+            select_hiddens,
+            encodings,
+            use_tqdm=False,
+            model_kwargs=dict(output_hidden_states=True),
+        ):
+            encoding = encodings[idx]
+            row_id = encoding["row_id"]
+            buffer[row_id].append(
+                dict(lm_log_odds=lm_log_odds, **encoding, **hidden_dict)
+            )
+            if len(buffer[row_id]) == num_variants:
+                # we have a complete example
+                ex = buffer[row_id]
+                ex = sorted(ex, key=lambda d: d["variant_id"])
+                assert all(d["label"] == ex[0]["label"] for d in ex)
+                assert len(set(d["variant_id"] for d in ex)) == num_variants
+                out_record: dict[str, Any] = dict(
+                    variant_ids=[d["variant_id"] for d in ex],
+                    label=ex[0]["label"],
+                    row_id=ex[0]["row_id"],
+                    texts=[d["text"] for d in ex],
+                    **{k: torch.stack([d[k] for d in ex]) for k in hidden_dict},
+                )
+                if cfg.get_lm_preds:
+                    out_record["lm_log_odds"] = torch.stack(
+                        [d["lm_log_odds"] for d in ex]  # type: ignore
+                    )
+                del buffer[row_id]
+                yield out_record
+
+    # HF wraps all the arguments in lists, so we unpack them here
+    def _extraction_worker(**kwargs):
+        yield from extract_hiddens(**{k: v[0] for k, v in kwargs.items()})
+
+    # TODO: support int8
+    server = InferenceServer(
+        model_str=cfg.model, num_workers=num_gpus, cpu_offload=True, fsdp=cfg.fsdp
+    )
+
     builders = {
         split_name: _GeneratorBuilder(
cache_dir=None, @@ -447,28 +439,27 @@ def extract( dataset_name=v.dataset_name, ), gen_kwargs=dict( - cfg=[cfg] * len(devices), - device=devices, - rank=list(range(len(devices))), - split_type=[ty] * len(devices), - world_size=[len(devices)] * len(devices), + cfg=[cfg], + split_type=[ty], + server=[server], ), ) for limit, (split_name, v), ty in zip(limits, splits.items(), split_types) } - import multiprocess as mp - - mp.set_start_method("spawn", force=True) # type: ignore[attr-defined] ds = dict() for split, builder in builders.items(): builder.download_and_prepare( download_mode=DownloadMode.FORCE_REDOWNLOAD if disable_cache else None, - num_proc=len(devices), + num_proc=None, ) - ds[split] = builder.as_dataset(split=split) + ds[split] = builder.as_dataset(split=split) # type: ignore[assignment] + + if server.running: + server.shutdown() dataset_dict = DatasetDict(ds) + return DatasetDictWithName( name=cfg.datasets[0], dataset=dataset_dict, diff --git a/elk/extraction/generator.py b/elk/extraction/generator.py index 84818c832..6f6754e36 100644 --- a/elk/extraction/generator.py +++ b/elk/extraction/generator.py @@ -30,7 +30,7 @@ def create_config_id( config_kwargs["gen_kwargs"] = { k: v[0] for k, v in config_kwargs.get("gen_kwargs", {}).items() - if k not in ("device", "rank", "world_size") + if k not in ("device", "server", "fsdp") } return super().create_config_id(config_kwargs, custom_features) diff --git a/elk/extraction/inference_server.py b/elk/extraction/inference_server.py new file mode 100644 index 000000000..9c39ec411 --- /dev/null +++ b/elk/extraction/inference_server.py @@ -0,0 +1,380 @@ +import inspect +import logging +import multiprocessing as std_mp +import os +import socket +import warnings +from dataclasses import dataclass +from functools import partial +from itertools import cycle +from typing import Any, Callable, Iterable, Type, cast + +import dill +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +from datasets import Dataset +from torch.distributed.fsdp import CPUOffload +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy +from tqdm import tqdm +from transformers import PreTrainedModel +from transformers.modeling_outputs import ModelOutput + +from elk.utils import instantiate_model, pytree_map, select_usable_devices + + +@dataclass(frozen=True) +class _Sentinel: + """Sentinel value used to indicate that a worker is done.""" + + +SENTINEL = _Sentinel() + + +@dataclass +class InferenceServer: + """High-level interface for running inference on a model on multiple GPUs. + + This is basically a glorified `multiprocessing.Pool`. The only difference is that + each worker maintains a copy of the model on a dedicated GPU. + """ + + model_str: str + num_workers: int = -1 + cpu_offload: bool = False + fsdp: bool = False + + def __post_init__(self): + self._current_id = 0 + self._process_ctx: mp.ProcessContext | None = None + + self._result_queues = [] + self._task_queues = [] + + @property + def running(self) -> bool: + """Whether the server is running.""" + return self._process_ctx is not None + + def start(self) -> None: + """Spin up the workers.""" + if self._process_ctx is not None: + raise RuntimeError("The server is already running") + + # Load the model on the main process, then zero-copy share it with the workers. 
+ # This ensures that we don't copy the model num_workers times on the CPU and + # run out of RAM for large models + print("Loading model...") + model = instantiate_model(self.model_str, torch_dtype="auto") + model.share_memory() + model_size = sum(p.numel() * p.element_size() for p in model.parameters()) + + # Determine which GPUs we can use + devices = select_usable_devices( + self.num_workers, min_memory=model_size if not self.fsdp else None + ) + self.num_workers = len(devices) # This may have been -1 before + + fsdp_port, wrap_policy = None, None + if self.fsdp: + fsdp_port = find_available_port() + msg = f"Fully Sharded Data Parallel running on port {fsdp_port}" + + layer_cls = get_transformer_layer_cls(model) + if layer_cls is not None: + msg += f" with '{layer_cls.__name__}' wrapping policy" + wrap_policy = partial( + transformer_auto_wrap_policy, transformer_layer_cls={layer_cls} + ) + + print(msg) + + self._manager = mp.Manager() + self._result_queues = [self._manager.Queue() for _ in range(self.num_workers)] + self._task_queues = [self._manager.Queue() for _ in range(self.num_workers)] + self._process_ctx = mp.spawn( + _worker_wrapper, + args=( + devices, + model, + self._task_queues, + self._result_queues, + self.cpu_offload, + fsdp_port, + wrap_policy, + ), + join=False, + nprocs=self.num_workers, + ) + + def shutdown(self) -> bool: + """Shut down all the workers, returning `True` if successful.""" + if self._process_ctx is None: + raise RuntimeError("Can't shut down a server that isn't running") + + # Let the workers know that they should shut down + for q in self._task_queues: + try: + q.put_nowait(None) + except std_mp.queues.Empty: # type: ignore[attr-defined] + pass + + self._manager.shutdown() + return self._process_ctx.join() + + # Support use as a context manager, just like mp.Pool + def __enter__(self): + self.start() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.shutdown() + + def map_forward( + self, + dataset: Dataset, + model_kwargs: dict[str, Any] | None = None, + use_tqdm: bool = False, + ) -> list: + """Maps the model's `forward` method over the given dataset, without + running a closure on the outputs.""" + return self.map( + lambda x: x, dataset, model_kwargs=model_kwargs, use_tqdm=use_tqdm + ) + + def imap_forward( + self, + dataset: Dataset, + model_kwargs: dict[str, Any] | None = None, + use_tqdm: bool = False, + ) -> Iterable: + """Maps the model's `forward` method over the given dataset, without + running a closure on the outputs.""" + yield from self.imap( + lambda x: x, dataset, model_kwargs=model_kwargs, use_tqdm=use_tqdm + ) + + def map( + self, + closure: Callable[[ModelOutput], Any], + dataset: Dataset, + model_kwargs: dict[str, Any] | None = None, + use_tqdm: bool = False, + ) -> list: + """Run inference on the given inputs, running a closure on the outputs. 
+ Dataset must contain an `input_ids` column, and optionally other arguments + that the model expects.""" + # add id column to dataset if not present to keep track of order + if "id" not in dataset.column_names: + dataset = dataset.add_column("id", range(len(dataset))) # type: ignore + ids = dataset["id"] + output_tuples = list(self.imap(closure, dataset, model_kwargs, use_tqdm)) + outputs = dict(output_tuples) + return [outputs[id] for id in ids] + + def imap( + self, + closure: Callable[[ModelOutput], Any] | None, + dataset: Dataset, + model_kwargs: dict[str, Any] | None = None, + use_tqdm: bool = False, + ) -> Iterable: + """Run inference on the given inputs, running a closure on the outputs. + Dataset must contain an `input_ids` column, and optionally other arguments + that the model expects. `dataset` is also required to have an `id` column, + because the outputs are not guaranteed to be returned in the same order as + the inputs. + + yields: (id, outputs)""" + if self._process_ctx is None: + raise RuntimeError("Can't run inference on a server that isn't running") + + assert "id" in dataset.column_names, "Dataset must contain an 'id' column" + if len(dataset) % self.num_workers != 0: + # server requires that the dataset's length is a multiple of the world size + assert self.num_workers != -1 + + # duplicate some rows + num_rows = len(dataset) + num_needed = self.num_workers - (num_rows % self.num_workers) + dummy = dataset[0] + dummy_id = dummy["id"] + for _ in range(num_needed): + dataset = dataset.add_item(dummy) # type: ignore + else: + dummy_id = -1 + + # We need PyTorch tensors + dataset = dataset.with_format("torch") + + # Pickle the closure and send it to the workers + closure_pkl = dill.dumps(closure) + model_kwargs_pkl = dill.dumps(model_kwargs or {}) + shards = [dataset.shard(self.num_workers, i) for i in range(self.num_workers)] + for q, shard in zip(self._task_queues, shards): + q.put((closure_pkl, model_kwargs_pkl, shard)) + + generator = round_robin(self._result_queues) # type: ignore[arg-type] + seen_dummy = False + for out in tqdm(generator, total=len(dataset), disable=not use_tqdm): + if out[0] == dummy_id: + if seen_dummy: + continue # ignore any extra dummy rows + else: + seen_dummy = True + yield out + + +def get_transformer_layer_cls(model: torch.nn.Module) -> Type[torch.nn.Module] | None: + """Get the class of the transformer layer used by the given model.""" + total_params = sum(p.numel() for p in model.parameters()) + for module in model.modules(): + if isinstance(module, torch.nn.ModuleList): + module_params = sum(p.numel() for p in module.parameters()) + if module_params > total_params / 2: + return type(module[0]) + + return None + + +def get_socket_with_port() -> socket.socket: + addrs = socket.getaddrinfo( + host="localhost", port=None, family=socket.AF_UNSPEC, type=socket.SOCK_STREAM + ) + for addr in addrs: + family, type, proto, _, _ = addr + s = socket.socket(family, type, proto) + try: + s.bind(("localhost", 0)) + s.listen(0) + return s + except OSError: + s.close() + + raise RuntimeError("Failed to create a socket") + + +def find_available_port() -> int: + s = get_socket_with_port() + _, port, *_ = s.getsockname() + s.close() + + return port + + +def round_robin(queues: list[mp.Queue]) -> Iterable[Any]: + """Yield items from the given queues in round-robin order.""" + exhausted = set() + + for idx, q in cycle(enumerate(queues)): + if len(exhausted) == len(queues): + break + if idx in exhausted: + continue + + try: + item = q.get(timeout=0.01) + 
except std_mp.queues.Empty: # type: ignore[attr-defined] + pass + else: + if item == SENTINEL: + exhausted.add(idx) + else: + yield item + + +@torch.inference_mode() +def _worker( + rank: int, + devices: list[str], + model: PreTrainedModel, + qs: list[mp.Queue], + out_qs: list[mp.Queue], + cpu_offload: bool = False, + fsdp_port: int | None = None, + wrap_policy: partial[bool] | None = None, +): + """Worker process that maintains a copy of the model on a dedicated GPU.""" + # Prevent duplicate logging messages + if rank != 0: + logging.disable(logging.CRITICAL) + warnings.filterwarnings("ignore") + + closure: Callable[[ModelOutput], Any] | None = None + dataset: Dataset | None = None + device = devices[rank] + + # Fully Sharded Data Parallel for large models + if fsdp_port is not None: + os.environ["MASTER_ADDR"] = "127.0.0.1" + os.environ["MASTER_PORT"] = str(fsdp_port) + dist.init_process_group("nccl", rank=rank, world_size=len(devices)) + torch.cuda.set_device(device) + + wrapped = FSDP( + model, + auto_wrap_policy=wrap_policy, + cpu_offload=CPUOffload(offload_params=cpu_offload), + device_id=torch.device(device), + forward_prefetch=True, + ) + model = cast(PreTrainedModel, wrapped) + model_forward = model.module.forward # type: ignore[union-attr] + else: + model.to(device) # type: ignore[union-attr] + model_forward = model.forward + + param_names = set(inspect.signature(model_forward).parameters.keys()) + + # Breaks when x is the sentinel value indicating we should shut down + in_queue = qs[rank] + out_queue = out_qs[rank] + + while msg := in_queue.get(): + # Someone called map() giving us a new closure and dataset to use + assert isinstance(msg, tuple) and len(msg) == 3 + closure_pkl, model_kwargs_pkl, dataset = msg + closure = dill.loads(closure_pkl) + model_kwargs = dill.loads(model_kwargs_pkl) + + assert dataset is not None + for record in dataset: + assert isinstance(record, dict) + id = record["id"].item() + assert "input_ids" in record, "Dataset must contain an 'input_ids' column" + # Only pass the arguments that the model expects + input_record = {k: v for k, v in record.items() if k in param_names} + + def maybe_unsqueeze(v): + return v.unsqueeze(0) if v.ndim == 1 else v + + inputs_cuda = pytree_map( + lambda v: maybe_unsqueeze(v.to(device)), input_record + ) + # TODO: have model kwargs so we don't have to duplicate kwargs at each row + outputs = model(**inputs_cuda, **model_kwargs) + + if callable(closure): + outputs = closure(outputs, **record) + if outputs is not None: + # Move the outputs back to the CPU + outputs = pytree_map(lambda x: x.cpu().share_memory_(), outputs) + + # Send the outputs back to the main process + out_queue.put((id, outputs)) + + # Indicate we're done with this dataset + out_queue.put(SENTINEL) + + # Clean up the FSDP process group + if fsdp_port is not None: + dist.destroy_process_group() + + +def _worker_wrapper(rank: int, *args): + try: + return _worker(rank, *args) + except Exception as e: + print(f"Exception in worker {rank}: {e}") + raise e diff --git a/elk/extraction/prompt_loading.py b/elk/extraction/prompt_loading.py index cb42d2331..ce485f825 100644 --- a/elk/extraction/prompt_loading.py +++ b/elk/extraction/prompt_loading.py @@ -16,13 +16,13 @@ def load_prompts( ds_string: str, *, - binarize: bool = False, num_shots: int = 0, seed: int = 42, split_type: Literal["train", "val"] = "train", template_path: str | None = None, - rank: int = 0, - world_size: int = 1, + include_answers: bool = False, + balance: bool = True, + statement_column: str | 
None = None, ) -> Iterator[dict]: """Load a dataset full of prompts generated from the specified dataset. @@ -34,8 +34,7 @@ def load_prompts( seed: The seed to use for prompt randomization. split_type: Whether to use the train or val split of the dataset. template_path: Path to feed into `DatasetTemplates` for loading templates. - rank: The rank of the current process. Defaults to 0. - world_size: The number of processes. Defaults to 1. + statement_column: Name of the column to use for the statement text. Returns: An iterable of prompt dictionaries. @@ -45,23 +44,31 @@ def load_prompts( ds_dict = assert_type(dict, load_dataset(ds_name, config_name or None)) split_name = select_split(ds_dict, split_type) - ds = assert_type(Dataset, ds_dict[split_name].shuffle(seed=seed)) - if world_size > 1: - ds = ds.shard(world_size, rank) - - if template_path is None: - prompter = DatasetTemplates(ds_name, config_name) - else: - prompter = DatasetTemplates(template_path) + ds = assert_type(Dataset, ds_dict[split_name]) + if "row_id" not in ds.column_names: + ds = ds.add_column("row_id", range(len(ds))) # type: ignore + ds = ds.shuffle(seed=seed) + + prompter, using_blank = get_prompter(ds_name, config_name, template_path) + if using_blank: + print('Using blank template "{{ statement }}".') + statement_column = statement_column or "statement" + if statement_column not in ds.column_names: + raise ValueError( + f'Could not find statement column "{statement_column}".' + f" Please include the column or specify a different one with the" + f" `statement_column` argument." + ) + if statement_column != "statement": + ds = ds.rename_column(statement_column, "statement") - # If the prompt template says to binarize, we should - binarize = binarize or prompter.binarize - prompter.drop_non_mc_templates() + # TODO: allow for optionally using contrast pair templates so people + # don't have to rewrite them num_templates = len(prompter.templates) assert num_templates > 0 - if rank == 0: - print(f"Extracting {num_templates} variants of each prompt") + + print(f"Extracting {num_templates} variants of each prompt") label_column = prompter.label_column or infer_label_column(ds.features) @@ -74,8 +81,7 @@ def load_prompts( # Which classes are actually present in this split of the dataset? # This is shockingly fast since it uses an optimized Apache Arrow primitive. 
label_choices = sorted(ds.unique(label_column)) - if rank == 0: - print(f"Using the following pseudo-labels: {label_choices}") + print(f"Using the following pseudo-labels: {label_choices}") rng = Random(seed) if num_shots > 0: @@ -89,25 +95,24 @@ def load_prompts( else: fewshot_iter = None - if label_column in ds.features: + if label_column in ds.features and balance: ds = BalancedSampler( ds.to_iterable_dataset(), set(label_choices), label_col=label_column, ) else: - if rank == 0: + if balance: print("No label column found, not balancing") ds = ds.to_iterable_dataset() for example in ds: yield _convert_to_prompts( example, - binarize=binarize, label_column=label_column, label_choices=label_choices, # type: ignore[arg-type] prompter=prompter, - rng=rng, + include_answers=include_answers, fewshot_iter=fewshot_iter, ) @@ -115,63 +120,30 @@ def load_prompts( def _convert_to_prompts( example: dict[str, Any], prompter: DatasetTemplates, - binarize: bool, label_column: str, label_choices: list[bool | int | str], - rng: Random, + include_answers: bool = False, fewshot_iter: Iterator[list[dict]] | None = None, ) -> dict[str, Any]: """Prompt-generating function to pass to `IterableDataset.map`.""" - prompts = [] + statements = [] templates = list(prompter.templates.values()) - def qa_cat(q: str, a: str) -> str: - # if the jinja template already adds whitespace, don't add more - sep = "" if not q or q[-1].isspace() or not a or a[0].isspace() else " " - return f"{q}{sep}{a}" if a and not a.isspace() else q - # For sanity checking that prompts are unique prompt_counter = Counter() label = example[label_column] - if binarize: - # Replace the full list of possibilities with a randomly sampled false label - # and the correct label, as done in the DLK paper. Note that this does add some - # "supervision" by stacking the deck in favor of the correct answer. - label_choices = [ - rng.choice([c for c in label_choices if c != label]), - label, - ] - rng.shuffle(label_choices) - for template in templates: - choices = [] - - for pseudo_label in label_choices: - fake_example = example.copy() - fake_example[label_column] = pseudo_label - - q, a = template.apply(fake_example) - prompt_counter[(q, a)] += 1 - - if fewshot_iter is not None: - # Infinite iterator so we don't need to worry about StopIteration - fewshot_examples = next(fewshot_iter) - fewshot_texts = [ - qa_cat(q, a) for q, a in map(template.apply, fewshot_examples) - ] - q = "\n\n".join(fewshot_texts) + "\n\n" + q - - choices.append( - dict( - # Strip whitespace from the answer to make it easier to - # compare with the model's output - answer=a.strip(), - question=q, - ) - ) + statement = template.apply(example) + prompt_counter[statement] += 1 - prompts.append(choices) + if fewshot_iter is not None: + # Infinite iterator so we don't need to worry about StopIteration + fewshot_examples = next(fewshot_iter) + fewshot_texts = list(map(template.apply, fewshot_examples)) + statement = "\n\n".join(fewshot_texts) + "\n\n" + statement + + statements.append(statement) # Sanity check: variants should be unique ((maybe_dup, dup_count),) = prompt_counter.most_common(1) @@ -181,8 +153,28 @@ def qa_cat(q: str, a: str) -> str: # Our reporter training and evaluation code assumes that the labels are integers. # If they're not, we need to convert them with index(). label_choices is guaranteed # to be sorted (see above). 
-    return dict(
+    out_dict = dict(
+        row_id=example["row_id"],
         label=label_choices.index(label),
-        prompts=prompts,
+        statements=statements,
         template_names=[template.name for template in templates],
     )
+    if include_answers:
+        out_dict.update(
+            answer_choices=[
+                template.get_fixed_answer_choices_list() for template in templates
+            ],
+            suffixes=[template.suffix for template in templates],
+        )
+    return out_dict
+
+
+def get_prompter(
+    ds_name: str, config_name: str | None, template_path: str | None = None
+) -> tuple[DatasetTemplates, bool]:
+    if template_path is None:
+        try:
+            return DatasetTemplates(ds_name, config_name), False
+        except ValueError:
+            return DatasetTemplates("_default"), True
+    return DatasetTemplates(template_path), template_path == "_default"
diff --git a/elk/metrics/__init__.py b/elk/metrics/__init__.py
index 7fb214501..25ed1b2a0 100644
--- a/elk/metrics/__init__.py
+++ b/elk/metrics/__init__.py
@@ -1,6 +1,6 @@
 from .accuracy import accuracy_ci
 from .calibration import CalibrationError, CalibrationEstimate
-from .eval import EvalResult, evaluate_preds, to_one_hot
+from .eval import EvalResult, evaluate_preds, get_logprobs
 from .roc_auc import RocAucResult, roc_auc, roc_auc_ci

 __all__ = [
@@ -9,8 +9,8 @@
     "CalibrationEstimate",
     "EvalResult",
     "evaluate_preds",
+    "get_logprobs",
     "roc_auc",
     "roc_auc_ci",
-    "to_one_hot",
     "RocAucResult",
 ]
diff --git a/elk/metrics/eval.py b/elk/metrics/eval.py
index 653beae55..8c837e8f9 100644
--- a/elk/metrics/eval.py
+++ b/elk/metrics/eval.py
@@ -2,6 +2,7 @@
 from typing import Literal

 import torch
+import torch.nn.functional as F
 from einops import repeat
 from torch import Tensor

@@ -41,22 +42,38 @@ def to_dict(self, prefix: str = "") -> dict[str, float]:
         return {**auroc_dict, **cal_acc_dict, **acc_dict, **cal_dict}


+def get_logprobs(
+    y_logits: Tensor, ensembling: Literal["none", "full"] = "none"
+) -> Tensor:
+    """
+    Get the log probabilities of the positive class from a tensor of log-odds.
+    Args:
+        y_logits: Predicted log-odds of the positive class, tensor of shape (n, v).
+        ensembling: If "full", average the log-odds over variants before taking the
+            log-sigmoid; if "none", keep one value per variant.
+    Returns:
+        Tensor of logprobs: If ensembling is "none", a tensor of shape (n, v).
+        If ensembling is "full", a tensor of shape (n,).
+    """
+    if ensembling == "full":
+        y_logits = y_logits.mean(dim=1)
+    return F.logsigmoid(y_logits)
+
+
 def evaluate_preds(
     y_true: Tensor,
     y_logits: Tensor,
-    ensembling: Literal["none", "partial", "full"] = "none",
+    ensembling: Literal["none", "full"] = "none",
 ) -> EvalResult:
     """
     Evaluate the performance of a classification model.

     Args:
         y_true: Ground truth tensor of shape (N,).
-        y_logits: Predicted class tensor of shape (N, variants, n_classes).
+        y_logits: Predicted log-odds of the positive class, tensor of shape
+            (N, variants).

     Returns:
         dict: A dictionary containing the accuracy, AUROC, and ECE.
""" - (n, v, c) = y_logits.shape + (n, v) = y_logits.shape assert y_true.shape == (n,) if ensembling == "full": @@ -64,46 +81,19 @@ def evaluate_preds( else: y_true = repeat(y_true, "n -> n v", v=v) - y_pred = y_logits.argmax(dim=-1) - if ensembling == "none": - auroc = roc_auc_ci(to_one_hot(y_true, c).long().flatten(1), y_logits.flatten(1)) - elif ensembling in ("partial", "full"): - # Pool together the negative and positive class logits - if c == 2: - auroc = roc_auc_ci(y_true, y_logits[..., 1] - y_logits[..., 0]) - else: - auroc = roc_auc_ci(to_one_hot(y_true, c).long(), y_logits) - else: - raise ValueError(f"Unknown mode: {ensembling}") + y_pred = y_logits > 0 + auroc = roc_auc_ci(y_true.long(), y_logits) acc = accuracy_ci(y_true, y_pred) - cal_acc = None - cal_err = None - if c == 2: - pos_probs = torch.sigmoid(y_logits[..., 1] - y_logits[..., 0]) + pos_probs = torch.sigmoid(y_logits) - # Calibrated accuracy - cal_thresh = pos_probs.float().quantile(y_true.float().mean()) - cal_preds = pos_probs.gt(cal_thresh).to(torch.int) - cal_acc = accuracy_ci(y_true, cal_preds) + # Calibrated accuracy + cal_thresh = pos_probs.float().quantile(y_true.float().mean()) + cal_preds = pos_probs.gt(cal_thresh).to(torch.int) + cal_acc = accuracy_ci(y_true, cal_preds) - cal = CalibrationError().update(y_true.flatten(), pos_probs.flatten()) - cal_err = cal.compute() + cal = CalibrationError().update(y_true.flatten(), pos_probs.flatten()) + cal_err = cal.compute() return EvalResult(acc, cal_acc, cal_err, auroc) - - -def to_one_hot(labels: Tensor, n_classes: int) -> Tensor: - """ - Convert a tensor of class labels to a one-hot representation. - - Args: - labels (Tensor): A tensor of class labels of shape (N,). - n_classes (int): The total number of unique classes. - - Returns: - Tensor: A one-hot representation tensor of shape (N, n_classes). - """ - one_hot_labels = labels.new_zeros(*labels.shape, n_classes) - return one_hot_labels.scatter_(-1, labels.unsqueeze(-1).long(), 1) diff --git a/elk/parsing.py b/elk/parsing.py deleted file mode 100644 index 1daded781..000000000 --- a/elk/parsing.py +++ /dev/null @@ -1,29 +0,0 @@ -import re - -from .training.losses import LOSSES - - -def parse_loss(terms: list[str]) -> dict[str, float]: - """Parse the loss command line argument list into a dictionary.""" - if len(terms) == 0: - return {"ccs_prompt_var": 1.0} - loss_dict = dict() - for term in terms: - if term in loss_dict: - raise ValueError(f"Duplicate loss term: {term}") - # check if the term is of the form "coef*name" - if re.match(r"^\d+(\.)?\d*\*\w+$", term): - coef, name = term.split("*") - coef = float(coef) - # check if the term is of the form "name" - elif re.match(r"^\w+$", term): - name = term - coef = 1.0 - else: - raise ValueError( - f"Invalid loss term: {term}. " - "Loss terms should be of the form 'coef*name' or 'name'." - ) - assert name in LOSSES, f"Unknown loss term: {name}" - loss_dict[name] = coef - return loss_dict diff --git a/elk/plotting/visualize.py b/elk/plotting/visualize.py index a82fd7d75..85eedd43d 100644 --- a/elk/plotting/visualize.py +++ b/elk/plotting/visualize.py @@ -20,7 +20,7 @@ def render( self, sweep: "SweepVisualization", with_transfer=False, - ensembles=["full", "partial", "none"], + ensembles=["full", "none"], write=False, ) -> go.Figure: """Render the multiplot visualization. 
@@ -219,7 +219,7 @@ def collect(cls, model_path: Path) -> "ModelVisualization": def get_train_dirs(model_path): # toplevel is either repo/dataset or dataset for toplevel in model_path.iterdir(): - if (toplevel / "eval.csv").exists(): + if (toplevel / "cfg.yaml").exists(): yield toplevel else: for train_dir in toplevel.iterdir(): @@ -272,7 +272,7 @@ def render_and_save( @staticmethod def _read_eval_csv(path, eval_dataset, train_dataset): - file = path / "eval.csv" + file = path / "lr_eval.csv" eval_df = pd.read_csv(file) eval_df["eval_dataset"] = eval_dataset eval_df["train_dataset"] = train_dataset @@ -382,7 +382,7 @@ def render_table( Returns: The generated score table as a pandas DataFrame. """ - df = self.df[self.df["ensembling"] == "partial"] + df = self.df[self.df["ensembling"] == "full"] # For each model, we use the layer whose mean AUROC is the highest best_layers, model_dfs = [], [] diff --git a/elk/promptsource/templates.py b/elk/promptsource/templates.py index 4d549abf9..4f94f698c 100644 --- a/elk/promptsource/templates.py +++ b/elk/promptsource/templates.py @@ -67,7 +67,9 @@ class Template(yaml.YAMLObject): yaml_tag = "!Template" - def __init__(self, name, jinja, reference, metadata=None, answer_choices=None): + def __init__( + self, name, jinja, reference, metadata=None, answer_choices=None, suffix="" + ): """ Creates a prompt template. @@ -88,13 +90,15 @@ def __init__(self, name, jinja, reference, metadata=None, answer_choices=None): be evaluated as ranked completions. If None, then the template is open-ended. This list is accessible from within Jinja as the variable `answer_choices`. + :param suffix: string to append to the end of the statement before the answer """ self.id = str(uuid.uuid4()) self.name = name self.jinja = jinja self.reference = reference - self.metadata = metadata if metadata is not None else Template.Metadata() + self.metadata = metadata or Template.Metadata() self.answer_choices = answer_choices + self.suffix = suffix def get_answer_choices_list(self, example): """ @@ -134,7 +138,7 @@ def get_fixed_answer_choices_list(self): else: return None - def apply(self, example, truncate=True, highlight_variables=False): + def apply(self, example, truncate=False, highlight_variables=False): """ Creates a prompt by applying this template to an example @@ -156,23 +160,28 @@ def apply(self, example, truncate=True, highlight_variables=False): jinja = jinja.replace("}}", " | highlight }}") rtemplate = env.from_string(jinja) + protected_example = self._escape_pipe(example) # Adds in answer_choices variable if "answer_choices" in protected_example: raise ValueError("Example contains the restricted key 'answer_choices'.") - protected_example["answer_choices"] = self.get_answer_choices_list(example) + try: + protected_example["answer_choices"] = self.get_answer_choices_list(example) + except AttributeError: + # there's no answer_choices field + pass # Renders the Jinja template rendered_example = rtemplate.render(**protected_example) # Splits on the separator, and then replaces back any occurrences of the # separator in the original example - return [ - Template._strip_spaces(self._unescape_pipe(part)) - for part in rendered_example.split("|||") - ] + statement_text, *_ = rendered_example.split("|||") + return ( + Template._strip_spaces(self._unescape_pipe(statement_text)) + " " + ) # THIS IS A HACK @staticmethod def _strip_spaces(string): @@ -260,6 +269,9 @@ def __init__(self, dataset_name: str, subset_name: str | None = None): # Required field; contains all the templates 
keyed by ID
         self.templates = yaml_dict["templates"]
+        for template in self.templates.values():
+            if not hasattr(template, "suffix"):
+                template.suffix = ""
         self.binarize = yaml_dict.get("binarize", False)
         self.label_column = yaml_dict.get("label_column")
diff --git a/elk/promptsource/templates/_default/templates.yaml b/elk/promptsource/templates/_default/templates.yaml
new file mode 100644
index 000000000..3ca195587
--- /dev/null
+++ b/elk/promptsource/templates/_default/templates.yaml
@@ -0,0 +1,15 @@
+dataset: None
+templates:
+  7eab7254-bd71-4b1d-9f8a-0fc7110f8371: !Template
+    answer_choices: False ||| True
+    id: 7eab7254-bd71-4b1d-9f8a-0fc7110f8371
+    jinja: "{{ statement }}"
+    metadata: !TemplateMetadata
+      choices_in_prompt: true
+      languages:
+      - en
+      metrics:
+      - Accuracy
+      original_task: true
+    name: _default
+    suffix: "\n\n\nQ: Is the above statement \"True\" or \"False\"?\n\nA:"
diff --git a/elk/promptsource/templates/_no_suffix/templates.yaml b/elk/promptsource/templates/_no_suffix/templates.yaml
new file mode 100644
index 000000000..7187f83c7
--- /dev/null
+++ b/elk/promptsource/templates/_no_suffix/templates.yaml
@@ -0,0 +1,15 @@
+dataset: None
+templates:
+  8eab7252-bd71-4b2d-9f8a-0fc7260f8371: !Template
+    answer_choices: False ||| True
+    id: 8eab7252-bd71-4b2d-9f8a-0fc7260f8371
+    jinja: "{{ statement }}"
+    metadata: !TemplateMetadata
+      choices_in_prompt: true
+      languages:
+      - en
+      metrics:
+      - Accuracy
+      original_task: true
+    name: _no_suffix
+    suffix: ""
diff --git a/elk/run.py b/elk/run.py
index fb8903ccf..a621b89d7 100644
--- a/elk/run.py
+++ b/elk/run.py
@@ -31,6 +31,16 @@
 )


+@dataclass
+class LayerData:
+    hiddens: Tensor
+    labels: Tensor
+    lm_log_odds: Tensor | None
+    texts: list[list[str]]  # (n, v)
+    row_ids: Tensor  # (n,)
+    variant_ids: list[list[str]]  # (n, v)
+
+
 @dataclass
 class Run(ABC, Serializable):
     data: Extract
@@ -46,11 +56,24 @@ class Run(ABC, Serializable):
     prompt_indices: tuple[int, ...] = ()
     """The indices of the prompt templates to use. If empty, all prompts are used."""

+    save_logprobs: bool = field(default=False, to_dict=False)
+    """Saves logprobs.pt containing
+    {<dataset_name>: {"row_ids": [n,], "variant_ids": [n, v],
+                      "labels": [n,], "texts": [n, v],
+                      "lm": {"none": [n, v], "full": [n,]},
+                      "lr": {<layer>: {
+                          "none": {<inlp_iter>: [n, v], ...},
+                          "full": {<inlp_iter>: [n,], ...}
+                          },
+                          ...
+ } + }} + """ + concatenated_layer_offset: int = 0 debug: bool = False - min_gpu_mem: int | None = None # in bytes num_gpus: int = -1 - out_dir: Path | None = None + min_gpu_mem: int = 0 disable_cache: bool = field(default=False, to_dict=False) def execute( @@ -64,7 +87,6 @@ def execute( disable_cache=self.disable_cache, highlight_color=highlight_color, num_gpus=self.num_gpus, - min_gpu_mem=self.min_gpu_mem, split_type=split_type, ) for cfg in self.data.explode() @@ -98,7 +120,7 @@ def execute( devices = select_usable_devices(self.num_gpus, min_memory=self.min_gpu_mem) num_devices = len(devices) - func: Callable[[int], dict[str, pd.DataFrame]] = partial( + func: Callable[[int], tuple[dict[str, pd.DataFrame], dict]] = partial( self.apply_to_layer, devices=devices, world_size=num_devices ) self.apply_to_layers(func=func, num_devices=num_devices) @@ -106,7 +128,7 @@ def execute( @abstractmethod def apply_to_layer( self, layer: int, devices: list[str], world_size: int - ) -> dict[str, pd.DataFrame]: + ) -> tuple[dict[str, pd.DataFrame], dict]: """Train or eval a reporter on a single layer.""" def make_reproducible(self, seed: int): @@ -125,7 +147,7 @@ def get_device(self, devices, world_size: int) -> str: def prepare_data( self, device: str, layer: int, split_type: Literal["train", "val"] - ) -> dict[str, tuple[Tensor, Tensor, Tensor | None]]: + ) -> dict[str, LayerData]: """Prepare data for the specified layer and split type.""" out = {} @@ -134,15 +156,25 @@ def prepare_data( split = ds[key].with_format("torch", device=device, dtype=torch.int16) labels = assert_type(Tensor, split["label"]) + # hiddens shape: (num_examples, num_variants, hidden_d) hiddens = int16_to_float32(assert_type(Tensor, split[f"hidden_{layer}"])) if self.prompt_indices: hiddens = hiddens[:, self.prompt_indices] - with split.formatted_as("torch", device=device): - has_preds = "model_logits" in split.features - lm_preds = split["model_logits"] if has_preds else None - - out[ds_name] = (hiddens, labels.to(hiddens.device), lm_preds) + if "lm_log_odds" in split.column_names: + with split.formatted_as("torch", device=device): + lm_preds = assert_type(Tensor, split["lm_log_odds"]) + else: + lm_preds = None + + out[ds_name] = LayerData( + hiddens=hiddens, + labels=labels, + lm_log_odds=lm_preds, + texts=split["texts"], + row_ids=assert_type(Tensor, split["row_id"]), + variant_ids=split["variant_ids"], + ) return out @@ -155,7 +187,7 @@ def concatenate(self, layers): def apply_to_layers( self, - func: Callable[[int], dict[str, pd.DataFrame]], + func: Callable[[int], tuple[dict[str, pd.DataFrame], dict]], num_devices: int, ): """Apply a function to each layer of the datasets in parallel @@ -178,11 +210,19 @@ def apply_to_layers( with ctx.Pool(num_devices) as pool: mapper = pool.imap_unordered if num_devices > 1 else map df_buffers = defaultdict(list) + logprobs_dicts = defaultdict(dict) try: - for df_dict in tqdm(mapper(func, layers), total=len(layers)): + for df_dict, logprobs_dict in tqdm( + mapper(func, layers), total=len(layers) + ): + # get arbitrary value + df_ = next(iter(df_dict.values())) + layer = df_["layer"].iloc[0] for k, v in df_dict.items(): df_buffers[k].append(v) + for k, v in logprobs_dict.items(): + logprobs_dicts[k][layer] = logprobs_dict[k] finally: # Make sure the CSVs are written even if we crash or get interrupted for name, dfs in df_buffers.items(): @@ -190,3 +230,20 @@ def apply_to_layers( df.round(4).to_csv(self.out_dir / f"{name}.csv", index=False) if self.debug: save_debug_log(self.datasets, 
self.out_dir) + if self.save_logprobs: + save_dict = defaultdict(dict) + for ds_name, logprobs_dict in logprobs_dicts.items(): + save_dict[ds_name]["row_ids"] = logprobs_dict[layers[0]][ + "row_ids" + ] + save_dict[ds_name]["texts"] = logprobs_dict[layers[0]]["texts"] + save_dict[ds_name]["labels"] = logprobs_dict[layers[0]][ + "labels" + ] + save_dict[ds_name]["lm"] = logprobs_dict[layers[0]]["lm"] + save_dict[ds_name]["lr"] = dict() + for layer, logprobs_dict_by_mode in logprobs_dict.items(): + save_dict[ds_name]["lr"][layer] = logprobs_dict_by_mode[ + "lr" + ] + torch.save(dict(save_dict), self.out_dir / "logprobs.pt") diff --git a/elk/training/__init__.py b/elk/training/__init__.py index 54a47b229..bbfb1cc92 100644 --- a/elk/training/__init__.py +++ b/elk/training/__init__.py @@ -1,15 +1,9 @@ -from .ccs_reporter import CcsConfig, CcsReporter from .classifier import Classifier from .common import FitterConfig -from .eigen_reporter import EigenFitter, EigenFitterConfig from .platt_scaling import PlattMixin __all__ = [ - "CcsReporter", - "CcsConfig", "Classifier", - "EigenFitter", - "EigenFitterConfig", "FitterConfig", "PlattMixin", ] diff --git a/elk/training/ccs_reporter.py b/elk/training/ccs_reporter.py deleted file mode 100644 index cd161dd9b..000000000 --- a/elk/training/ccs_reporter.py +++ /dev/null @@ -1,291 +0,0 @@ -"""An ELK reporter network.""" - -import math -from copy import deepcopy -from dataclasses import dataclass, field -from typing import Literal, Optional, cast - -import torch -import torch.nn as nn -from concept_erasure import LeaceFitter -from torch import Tensor - -from ..parsing import parse_loss -from ..utils.typing import assert_type -from .common import FitterConfig -from .losses import LOSSES -from .platt_scaling import PlattMixin - - -@dataclass -class CcsConfig(FitterConfig): - activation: Literal["gelu", "relu", "swish"] = "gelu" - """The activation function to use.""" - bias: bool = True - """Whether to use a bias term in the linear layers.""" - hidden_size: Optional[int] = None - """ - The number of hidden units in the MLP. Defaults to None. By default, use an MLP - expansion ratio of 4/3. This ratio is used by Tucker et al. (2022) - in their 3-layer MLP probes. We could also use - a ratio of 4, imitating transformer FFNs, but this seems to lead to excessively - large MLPs when num_layers > 2. - """ - init: Literal["default", "pca", "spherical", "zero"] = "default" - """The initialization scheme to use.""" - loss: list[str] = field(default_factory=lambda: ["ccs"]) - """ - The loss function to use. list of strings, each of the form "coef*name", where coef - is a float and name is one of the keys in `elk.training.losses.LOSSES`. - Example: `--loss 1.0*consistency_squared 0.5*prompt_var` corresponds to the loss - function 1.0*consistency_squared + 0.5*prompt_var. - """ - loss_dict: dict[str, float] = field(default_factory=dict, init=False) - num_layers: int = 1 - """The number of layers in the MLP.""" - pre_ln: bool = False - """Whether to include a LayerNorm module before the first linear layer.""" - supervised_weight: float = 0.0 - """The weight of the supervised loss.""" - - lr: float = 1e-2 - """The learning rate to use. 
Ignored when `optimizer` is `"lbfgs"`.""" - num_epochs: int = 1000 - """The number of epochs to train for.""" - num_tries: int = 10 - """The number of times to try training the reporter.""" - optimizer: Literal["adam", "lbfgs"] = "lbfgs" - """The optimizer to use.""" - weight_decay: float = 0.01 - """The weight decay or L2 penalty to use.""" - - def __post_init__(self): - self.loss_dict = parse_loss(self.loss) - - # standardize the loss field - self.loss = [f"{coef}*{name}" for name, coef in self.loss_dict.items()] - - -class CcsReporter(nn.Module, PlattMixin): - """CCS reporter network. - - Args: - in_features: The number of input features. - cfg: The reporter configuration. - """ - - config: CcsConfig - - def __init__( - self, - cfg: CcsConfig, - in_features: int, - *, - device: str | torch.device | None = None, - dtype: torch.dtype | None = None, - num_variants: int = 1, - ): - super().__init__() - self.config = cfg - self.in_features = in_features - self.num_variants = num_variants - - # Learnable Platt scaling parameters - self.bias = nn.Parameter(torch.zeros(1, device=device, dtype=dtype)) - self.scale = nn.Parameter(torch.ones(1, device=device, dtype=dtype)) - - hidden_size = cfg.hidden_size or 4 * in_features // 3 - - self.norm = None - self.probe = nn.Sequential( - nn.Linear( - in_features, - 1 if cfg.num_layers < 2 else hidden_size, - bias=cfg.bias, - device=device, - ), - ) - if cfg.pre_ln: - self.probe.insert(0, nn.LayerNorm(in_features, elementwise_affine=False)) - - act_cls = { - "gelu": nn.GELU, - "relu": nn.ReLU, - "swish": nn.SiLU, - }[cfg.activation] - - for i in range(1, cfg.num_layers): - self.probe.append(act_cls()) - self.probe.append( - nn.Linear( - hidden_size, - 1 if i == cfg.num_layers - 1 else hidden_size, - bias=cfg.bias, - device=device, - ) - ) - - def reset_parameters(self): - """Reset the parameters of the probe. - - If init is "spherical", use the spherical initialization scheme. - If init is "default", use the default PyTorch initialization scheme for - nn.Linear (Kaiming uniform). - If init is "zero", initialize all parameters to zero. - """ - if self.config.init == "spherical": - # Mathematically equivalent to the unusual initialization scheme used in - # the original paper. They sample a Gaussian vector of dim in_features + 1, - # normalize to the unit sphere, then add an extra all-ones dimension to the - # input and compute the inner product. Here, we use nn.Linear with an - # explicit bias term, but use the same initialization. 
- assert len(self.probe) == 1, "Only linear probes can use spherical init" - probe = cast(nn.Linear, self.probe[0]) # Pylance gets the type wrong here - - theta = torch.randn(1, probe.in_features + 1, device=probe.weight.device) - theta /= theta.norm() - probe.weight.data = theta[:, :-1] - probe.bias.data = theta[:, -1] - - elif self.config.init == "default": - for layer in self.probe: - if isinstance(layer, nn.Linear): - layer.reset_parameters() - - elif self.config.init == "zero": - for param in self.parameters(): - param.data.zero_() - elif self.config.init != "pca": - raise ValueError(f"Unknown init: {self.config.init}") - - def forward(self, x: Tensor) -> Tensor: - """Return the credence assigned to the hidden state `x`.""" - assert self.norm is not None, "Must call fit() before forward()" - - raw_scores = self.probe(self.norm(x)).squeeze(-1) - return raw_scores.mul(self.scale).add(self.bias).squeeze(-1) - - def loss(self, logit0: Tensor, logit1: Tensor) -> Tensor: - """Return the loss of the reporter on the contrast pair (x0, x1). - - Args: - logit0: The raw score output of the reporter on x0. - logit1: The raw score output of the reporter on x1. - - Returns: - loss: The loss of the reporter on the contrast pair (x0, x1). - """ - loss = sum( - LOSSES[name](logit0, logit1, coef) - for name, coef in self.config.loss_dict.items() - ) - return assert_type(Tensor, loss) - - def fit(self, hiddens: Tensor) -> float: - """Fit the probe to the contrast pair `hiddens`. - - Returns: - best_loss: The best loss obtained. - """ - x_neg, x_pos = hiddens.unbind(2) - - # One-hot indicators for each prompt template - n, v, d = x_neg.shape - prompt_ids = torch.eye(v, device=x_neg.device).expand(n, -1, -1) - - fitter = LeaceFitter(d, 2 * v, dtype=x_neg.dtype, device=x_neg.device) - fitter.update( - x=x_neg, - # Independent indicator for each (template, pseudo-label) pair - z=torch.cat([torch.zeros_like(prompt_ids), prompt_ids], dim=-1), - ) - fitter.update( - x=x_pos, - # Independent indicator for each (template, pseudo-label) pair - z=torch.cat([prompt_ids, torch.zeros_like(prompt_ids)], dim=-1), - ) - self.norm = fitter.eraser - - x_neg, x_pos = self.norm(x_neg), self.norm(x_pos) - - # Record the best acc, loss, and params found so far - best_loss = torch.inf - best_state: dict[str, Tensor] = {} # State dict of the best run - - for i in range(self.config.num_tries): - self.reset_parameters() - - # This is sort of inefficient but whatever - if self.config.init == "pca": - diffs = torch.flatten(x_pos - x_neg, 0, 1) - _, __, V = torch.pca_lowrank(diffs, q=i + 1) - self.probe[0].weight.data = V[:, -1, None].T - - if self.config.optimizer == "lbfgs": - loss = self.train_loop_lbfgs(x_neg, x_pos) - elif self.config.optimizer == "adam": - loss = self.train_loop_adam(x_neg, x_pos) - else: - raise ValueError(f"Optimizer {self.config.optimizer} is not supported") - - if loss < best_loss: - best_loss = loss - best_state = deepcopy(self.state_dict()) - - if not math.isfinite(best_loss): - raise RuntimeError("Got NaN/infinite loss during training") - - self.load_state_dict(best_state) - return best_loss - - def train_loop_adam(self, x_neg: Tensor, x_pos: Tensor) -> float: - """Adam train loop, returning the final loss. 
Modifies params in-place.""" - - optimizer = torch.optim.AdamW( - self.parameters(), lr=self.config.lr, weight_decay=self.config.weight_decay - ) - - loss = torch.inf - for _ in range(self.config.num_epochs): - optimizer.zero_grad() - - # We already normalized in fit() - loss = self.loss(self(x_neg), self(x_pos)) - loss.backward() - optimizer.step() - - return float(loss) - - def train_loop_lbfgs(self, x_neg: Tensor, x_pos: Tensor) -> float: - """LBFGS train loop, returning the final loss. Modifies params in-place.""" - - optimizer = torch.optim.LBFGS( - self.parameters(), - line_search_fn="strong_wolfe", - max_iter=self.config.num_epochs, - tolerance_change=torch.finfo(x_pos.dtype).eps, - tolerance_grad=torch.finfo(x_pos.dtype).eps, - ) - # Raw unsupervised loss, WITHOUT regularization - loss = torch.inf - - def closure(): - nonlocal loss - optimizer.zero_grad() - - # We already normalized in fit() - loss = self.loss(self(x_neg), self(x_pos)) - regularizer = 0.0 - - # We explicitly add L2 regularization to the loss, since LBFGS - # doesn't have a weight_decay parameter - for param in self.parameters(): - regularizer += self.config.weight_decay * param.norm() ** 2 / 2 - - regularized = loss + regularizer - regularized.backward() - - return float(regularized) - - optimizer.step(closure) - return float(loss) diff --git a/elk/training/classifier.py b/elk/training/classifier.py index 148da939f..7b9281ec0 100644 --- a/elk/training/classifier.py +++ b/elk/training/classifier.py @@ -1,6 +1,7 @@ from dataclasses import dataclass, field import torch +from concept_erasure import LeaceEraser from torch import Tensor from torch.nn.functional import ( binary_cross_entropy_with_logits as bce_with_logits, @@ -43,6 +44,7 @@ def __init__( self, input_dim: int, num_classes: int = 2, + eraser: LeaceEraser | None = None, device: str | torch.device | None = None, dtype: torch.dtype | None = None, ): @@ -53,8 +55,11 @@ def __init__( ) self.linear.bias.data.zero_() self.linear.weight.data.zero_() + self.eraser = eraser def forward(self, x: Tensor) -> Tensor: + if self.eraser is not None: + x = self.eraser(x) return self.linear(x).squeeze(-1) @torch.enable_grad() @@ -63,7 +68,7 @@ def fit( x: Tensor, y: Tensor, *, - l2_penalty: float = 0.0, + l2_penalty: float = 0.001, max_iter: int = 10_000, ) -> float: """Fits the model to the input data using L-BFGS with L2 regularization. @@ -185,7 +190,12 @@ def fit_cv( @classmethod def inlp( - cls, x: Tensor, y: Tensor, max_iter: int | None = None, tol: float = 0.01 + cls, + x: Tensor, + y: Tensor, + eraser: LeaceEraser | None = None, + max_iter: int | None = None, + tol: float = 0.01, ) -> InlpResult: """Iterative Nullspace Projection (INLP) . @@ -194,8 +204,9 @@ def inlp( the input dimension. y: Target tensor of shape (N,) for binary classification or (N, C) for multiclass classification, where C is the number of classes. - max_iter: Maximum number of iterations to run. If `None`, run for the full - dimension of the input. + eraser: Concept erasure function to use. If `None`, no erasure is performed. + max_iter: Maximum number of iterations to run. If `None`, run until the data + is linearly guarded (no linear classifier can extract information) tol: Tolerance for the loss function. The algorithm will stop when the loss is within `tol` of the entropy of the labels. 
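
Note (illustrative sketch, not part of the diff): the snippet below shows how the new `eraser` hook and the reworked `max_iter` behaviour are meant to compose, using only the APIs touched in this file and in supervised.py (`LeaceFitter`, `Classifier`, `Classifier.inlp`); the tensor sizes and random data are placeholders.

import torch
from concept_erasure import LeaceFitter

from elk.training.classifier import Classifier

n, v, d = 64, 3, 16                            # examples, paraphrases, hidden size
x = torch.randn(n * v, d)                      # flattened hidden states
y = torch.randint(0, 2, (n * v,)).float()      # binary labels

# Fit a LEACE eraser on one-hot paraphrase indicators, mirroring train_supervised.
fitter = LeaceFitter(d, v, dtype=x.dtype)
fitter.update(x, torch.eye(v).repeat(n, 1))    # (n * v, v) indicator matrix

# The eraser, if given, is applied inside Classifier.forward before the linear layer.
clf = Classifier(d, eraser=fitter.eraser)
clf.fit(x, y)

# INLP now runs for max_iter rounds, or (if None) until no linear signal remains.
result = Classifier.inlp(x, y, eraser=fitter.eraser, max_iter=4)
print(len(result.classifiers), result.losses[-1])
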
@@ -212,13 +223,13 @@ def inlp( p = y.float().mean() H = -p * torch.log(p) - (1 - p) * torch.log(1 - p) - if max_iter is not None: - d = min(d, max_iter) + max_iter = max_iter or d # Iterate until the loss is within epsilon of the entropy + # meaning LR is not able to find a useful classifier anymore result = InlpResult() - for _ in range(d): - clf = cls(d, device=x.device, dtype=x.dtype) + for _ in range(max_iter): + clf = cls(d, eraser=eraser, device=x.device, dtype=x.dtype) loss = clf.fit(x, y) result.classifiers.append(clf) result.losses.append(loss) diff --git a/elk/training/eigen_reporter.py b/elk/training/eigen_reporter.py deleted file mode 100644 index a3525b1d1..000000000 --- a/elk/training/eigen_reporter.py +++ /dev/null @@ -1,237 +0,0 @@ -"""An ELK reporter network.""" - -from dataclasses import dataclass - -import torch -from concept_erasure import LeaceFitter -from einops import rearrange -from torch import Tensor - -from ..utils.math_util import cov_mean_fused -from .common import FitterConfig, Reporter - - -@dataclass -class EigenFitterConfig(FitterConfig): - """Configuration for an EigenFitter.""" - - var_weight: float = 0.0 - """The weight of the variance term in the loss.""" - - neg_cov_weight: float = 0.5 - """The weight of the negative covariance term in the loss.""" - - num_heads: int = 1 - """The number of eigenvectors to compute from the VINC matrix.""" - - save_reporter_stats: bool = False - """Whether to save the reporter statistics to disk in EigenFitter.save(). This - is useful for debugging and analysis, but can take up a lot of disk space.""" - - erase_prompts: bool = False - """Whether to apply concept erasure on the prompt template IDs.""" - - use_centroids: bool = True - """Whether to average hiddens within each cluster before computing covariance.""" - - def __post_init__(self): - if not (0 <= self.neg_cov_weight <= 1): - raise ValueError("neg_cov_weight must be in [0, 1]") - if self.num_heads <= 0: - raise ValueError("num_heads must be positive") - - -class EigenFitter: - """Fit a linear reporter with eigendecomposition. - - Args: - cfg: The reporter configuration. - in_features: The number of input features. - num_classes: The number of classes for tracking the running means. - - Attributes: - config: The reporter configuration. - intercluster_cov_M2: The unnormalized covariance matrix averaged over all - classes. - intracluster_cov: The running mean of the covariance matrices within each - cluster. This doesn't need to be a running sum because it's doesn't use - Welford's algorithm. - contrastive_xcov_M2: Average of the unnormalized cross-covariance matrices - across all pairs of classes (k, k'). - n: The running sum of the number of clusters processed by `update()`. - weight: The reporter weight matrix. Guaranteed to always be orthogonal, and - the columns are sorted in descending order of eigenvalue magnitude. 
- """ - - config: EigenFitterConfig - - intercluster_cov_M2: Tensor # variance - intracluster_cov: Tensor # invariance - contrastive_xcov_M2: Tensor # negative covariance - - n: Tensor - class_means: Tensor - weight: Tensor - - def __init__( - self, - cfg: EigenFitterConfig, - in_features: int, - num_classes: int = 2, - *, - device: str | torch.device | None = None, - dtype: torch.dtype | None = None, - num_variants: int = 1, - ): - super().__init__() - self.config = cfg - self.in_features = in_features - self.num_classes = num_classes - self.num_variants = num_variants - - self.leace = LeaceFitter( - in_features, - num_classes * num_variants if cfg.erase_prompts else num_classes, - device=device, - dtype=dtype, - ) - - # Running statistics - self.n = torch.zeros((), device=device, dtype=torch.long) - self.class_means = torch.zeros( - num_classes, in_features, device=device, dtype=dtype - ) - self.contrastive_xcov_M2 = torch.zeros( - in_features, in_features, device=device, dtype=dtype - ) - self.intercluster_cov_M2 = torch.zeros( - in_features, in_features, device=device, dtype=dtype - ) - self.intracluster_cov = torch.zeros( - in_features, in_features, device=device, dtype=dtype - ) - - @property - def contrastive_xcov(self) -> Tensor: - assert self.n > 0, "Stats not initialized; did you set save_reporter_stats?" - return self.contrastive_xcov_M2 / self.n - - @property - def intercluster_cov(self) -> Tensor: - assert self.n > 0, "Stats not initialized; did you set save_reporter_stats?" - return self.intercluster_cov_M2 / self.n - - @property - def confidence(self) -> Tensor: - return self.weight @ self.intercluster_cov @ self.weight.mT - - @property - def invariance(self) -> Tensor: - assert self.n > 0, "Stats not initialized; did you set save_reporter_stats?" - return -self.weight @ self.intracluster_cov @ self.weight.mT - - @property - def consistency(self) -> Tensor: - return -self.weight @ self.contrastive_xcov @ self.weight.mT - - @torch.no_grad() - def update(self, hiddens: Tensor) -> None: - (n, v, k, d) = hiddens.shape - - # Sanity checks - assert k > 1, "Must provide at least two hidden states" - assert hiddens.ndim == 4, "Must be of shape [batch, variants, choices, dim]" - - self.n += n - - if self.config.erase_prompts: - # Independent indicator for each (template, pseudo-label) pair - indicators = torch.eye(k * v, device=hiddens.device).expand(n, -1, -1) - self.leace.update(x=hiddens, z=indicators) - else: - # Only use indicators for each pseudo-label - indicators = torch.eye(k, device=hiddens.device).expand(n, v, -1, -1) - - self.leace.update(x=hiddens, z=indicators) - - # *** Invariance (intra-cluster) *** - # This is just a standard online *mean* update, since we're computing the - # mean of covariance matrices, not the covariance matrix of means. 
- intra_cov = cov_mean_fused(rearrange(hiddens, "n v k d -> (n k) v d")) - self.intracluster_cov += (n / self.n) * (intra_cov - self.intracluster_cov) - - if self.config.use_centroids: - # VINC style - centroids = hiddens.mean(1) - else: - # CRC-TPC style - centroids = rearrange(hiddens, "n v k d -> (n v) k d") - - deltas, deltas2 = [], [] - - # Iterating over classes - for i, h in enumerate(centroids.unbind(1)): - # Update the running means - delta = h - self.class_means[i] - self.class_means[i] += delta.sum(dim=0) / self.n - - # Post-mean update deltas are used to update the (co)variance - delta2 = h - self.class_means[i] # [n, d] - - # *** Variance (inter-cluster) *** - # See code at https://bit.ly/3YC9BhH and "Welford's online algorithm" - # in https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance. - self.intercluster_cov_M2.addmm_(delta.mT, delta2, alpha=1 / k) - deltas.append(delta) - deltas2.append(delta2) - - # *** Negative covariance (contrastive) *** - # Iterating over pairs of classes (k, k') where k != k' - for i, d in enumerate(deltas): - for j, d_ in enumerate(deltas2): - # Compare to the other classes only - if i == j: - continue - - scale = 1 / (k * (k - 1)) - self.contrastive_xcov_M2.addmm_(d.mT, d_, alpha=scale) - - def fit_streaming(self) -> Reporter: - """Fit the probe using the current streaming statistics.""" - inv_weight = 1 - self.config.neg_cov_weight - A = ( - self.config.var_weight * self.intercluster_cov - - inv_weight * self.intracluster_cov - - self.config.neg_cov_weight * self.contrastive_xcov - ) - - # Remove the subspace responsible for pseudolabel correlations - A = self.leace.eraser.P @ A @ self.leace.eraser.P.mT - try: - L, Q = torch.linalg.eigh(A) - except torch.linalg.LinAlgError: - try: - L, Q = torch.linalg.eig(A) - L, Q = L.real, Q.real - except torch.linalg.LinAlgError as e: - # Check if the matrix has non-finite values - if not A.isfinite().all(): - raise ValueError( - "Fitting the reporter failed because the VINC matrix has " - "non-finite entries. Usually this means the hidden states " - "themselves had non-finite values." - ) from e - else: - raise e - - L, Q = L[-self.config.num_heads :], Q[:, -self.config.num_heads :] - return Reporter(Q.T, self.leace.eraser) - - def fit(self, hiddens: Tensor) -> Reporter: - """Fit the probe to the contrast set `hiddens`. - - Args: - hiddens: The contrast set of shape [batch, variants, choices, dim]. - """ - self.update(hiddens) - return self.fit_streaming() diff --git a/elk/training/losses.py b/elk/training/losses.py deleted file mode 100644 index 8d7e287be..000000000 --- a/elk/training/losses.py +++ /dev/null @@ -1,156 +0,0 @@ -"""Loss functions for training reporters.""" - -import warnings -from inspect import signature - -import torch -from torch import Tensor - -LOSSES = dict() # Registry of loss functions - - -def register(name): - """A decorator to register a function to LOSSES""" - - def decorate(func): - assert signature(func).parameters.keys() == {"logit0", "logit1", "coef"}, ( - f"Loss function {func.__name__} must take arguments " - "`logit0`, `logit1`, and `coef`." - ) - assert ( - name not in LOSSES - ), f"Loss function {name} conflicts with existing function." 
- LOSSES[name] = func - return func - - return decorate - - -def H(p: Tensor) -> Tensor: - """Entropy of Bernoulli distribution(s) with success probability `p`.""" - return torch.nn.functional.binary_cross_entropy(p, p) - - -@register("ccs") -def ccs_squared_loss(logit0: Tensor, logit1: Tensor, coef: float = 1.0) -> Tensor: - """CCS loss from original paper, with squared differences between probabilities. - - The loss is symmetric, so it doesn't matter which argument is the original and - which is the negated proposition. - - Args: - logit0: The log odds for the original proposition. - logit1: The log odds for the negated proposition. - coef: The coefficient to multiply the loss by. - Returns: - The sum of the consistency and confidence losses. - """ - loss = consistency_squared_loss(logit0, logit1) + confidence_squared_loss( - logit0, logit1 - ) - return coef * loss - - -@register("ccs_prompt_var") -def ccs_prompt_var_loss(logit0: Tensor, logit1: Tensor, coef: float = 1.0) -> Tensor: - """CCS loss with prompt variance regularization. - - The loss is symmetric, so it doesn't matter which argument is the original and - which is the negated proposition. - - Args: - logit0: The log odds for the original proposition. Shape ([batch,] n_variants) - logit1: The log odds for the negated proposition. Shape ([batch,] n_variants) - coef: The coefficient to multiply the loss by. - Returns: - The sum of the consistency and confidence losses. - """ - loss = ( - consistency_squared_loss(logit0, logit1) - + confidence_squared_loss(logit0, logit1) - + prompt_var_loss(logit0, logit1) - ) - return coef * loss - - -@register("js") -def js_loss( - logit0: Tensor, - logit1: Tensor, - coef: float = 1.0, -) -> Tensor: - """Negation consistency loss based on the Jensen-Shannon divergence. - - Note that by default we use the base 2 logarithm, so the value is measured in bits. - This ensures the divergence is in the range [0, 1].""" - p0, neg_p1 = logit0.sigmoid(), 1 - logit1.sigmoid() - nats = H((p0 + neg_p1) / 2) - (H(p0) + H(neg_p1)) / 2 - return coef * nats - - -@register("js_confidence") -def js_confidence_loss( - logit0: Tensor, - logit1: Tensor, - coef: float = 1.0, -) -> Tensor: - """Confidence loss based on the Jensen-Shannon divergence. This is the same as the - entropy of the 50/50 mixture of the two distributions. - - Note that by default we use the base 2 logarithm, so the value is measured in bits. - This ensures the divergence is in the range [0, 1].""" - p0, neg_p1 = logit0.sigmoid(), 1 - logit1.sigmoid() - nats = H((p0 + neg_p1) / 2) - return coef * nats - - -@register("consistency_squared") -def consistency_squared_loss( - logit0: Tensor, - logit1: Tensor, - coef: float = 1.0, -) -> Tensor: - """Negation consistency loss based on the squared difference between the - two distributions.""" - p0, p1 = logit0.sigmoid(), logit1.sigmoid() - return coef * p0.sub(1 - p1).square().mean() - - -@register("confidence_squared") -def confidence_squared_loss( - logit0: Tensor, - logit1: Tensor, - coef: float = 1.0, -) -> Tensor: - """Confidence loss based on the squared difference between the two distributions.""" - p0, p1 = logit0.sigmoid(), logit1.sigmoid() - return coef * torch.min(p0, p1).square().mean() - - -@register("prompt_var_squared") -def prompt_var_loss(logit0: Tensor, logit1: Tensor, coef: float = 1.0) -> Tensor: - """ - Prompt-variance loss: the squared difference between the probability - of a proposition and the mean probability over all variants of that - proposition (templates). 
- - The loss is symmetric, so it doesn't matter which argument is the original and - which is the negated proposition. - - Args: - logit0: The log odds for the original proposition. shape ([batch,] n_variants) - logit1: The log odds for the negated proposition. shape ([batch,] n_variants) - coef: The coefficient to multiply the loss by. - """ - assert logit0.shape == logit1.shape - assert len(logit0.shape) in [1, 2] - if logit0.shape[-1] == 1: - warnings.warn( - "Only one variant provided. Prompt variance loss will cause errors." - ) - p0, p1 = logit0.sigmoid(), logit1.sigmoid() - - var0 = p0.var(dim=-1, unbiased=False).mean() - var1 = p1.var(dim=-1, unbiased=False).mean() - prompt_variance = var0 + var1 - return coef * prompt_variance diff --git a/elk/training/supervised.py b/elk/training/supervised.py index d2eef5f7f..b3f100646 100644 --- a/elk/training/supervised.py +++ b/elk/training/supervised.py @@ -1,34 +1,61 @@ import torch +from concept_erasure import LeaceFitter from einops import rearrange, repeat -from ..metrics import to_one_hot +from ..run import LayerData from .classifier import Classifier def train_supervised( - data: dict[str, tuple], device: str, mode: str + data: dict[str, LayerData], + device: str, + mode: str, + erase_paraphrases: bool = False, + max_inlp_iter: int | None = None, ) -> list[Classifier]: + assert not ( + erase_paraphrases and len(data) > 1 + ), "Erasing paraphrases is only supported for single dataset." Xs, train_labels = [], [] - for train_h, labels, _ in data.values(): - (_, v, k, _) = train_h.shape - train_h = rearrange(train_h, "n v k d -> (n v k) d") + leace = None - labels = repeat(labels, "n -> (n v)", v=v) - labels = to_one_hot(labels, k).flatten() + for train_data in data.values(): + (n, v, d) = train_data.hiddens.shape + train_h = rearrange(train_data.hiddens, "n v d -> (n v) d") + + if erase_paraphrases and v > 1: + if leace is None: + leace = LeaceFitter( + d, + v, + device=device, + dtype=train_h.dtype, + ) + # indicators = [0, 1, ..., v-1, 0, 1, ..., v-1, ...] 
to one-hot + indicators = torch.eye(v, device=device, dtype=train_h.dtype).repeat( + n, 1 + ) # (n * v, v) + leace = leace.update(train_h, indicators) + + labels = repeat(train_data.labels, "n -> (n v)", v=v) Xs.append(train_h) train_labels.append(labels) X, train_labels = torch.cat(Xs), torch.cat(train_labels) + eraser = leace.eraser if leace is not None else None + if mode == "cv": - lr_model = Classifier(X.shape[-1], device=device) + lr_model = Classifier(X.shape[-1], device=device, eraser=eraser) lr_model.fit_cv(X, train_labels) return [lr_model] elif mode == "inlp": - return Classifier.inlp(X, train_labels).classifiers + return Classifier.inlp( + X, train_labels, eraser=eraser, max_iter=max_inlp_iter + ).classifiers elif mode == "single": - lr_model = Classifier(X.shape[-1], device=device) + lr_model = Classifier(X.shape[-1], device=device, eraser=eraser) lr_model.fit(X, train_labels) return [lr_model] else: diff --git a/elk/training/sweep.py b/elk/training/sweep.py index e4aca5a00..52e3d664f 100755 --- a/elk/training/sweep.py +++ b/elk/training/sweep.py @@ -1,6 +1,5 @@ -from dataclasses import InitVar, dataclass, replace +from dataclasses import InitVar, dataclass, field, replace -import numpy as np import torch from datasets import get_dataset_config_info from transformers import AutoConfig @@ -9,7 +8,6 @@ from ..extraction import Extract from ..files import memorably_named_dir, sweeps_dir from ..plotting.visualize import visualize_sweep -from ..training.eigen_reporter import EigenFitterConfig from ..utils import colorize from ..utils.constants import BURNS_DATASETS from .train import Elicit @@ -38,11 +36,6 @@ class Sweep: add_pooled: InitVar[bool] = False """Whether to add a dataset that pools all of the other datasets together.""" - hparam_step: float = -1.0 - """The step size for hyperparameter sweeps. Performs a 2D - sweep over a and b in (var_weight, inv_weight, neg_cov_weight) = (a, 1 - b, b) - If negative, no hyperparameter sweeps will be performed. Only valid for Eigen.""" - skip_transfer_eval: bool = False """Whether to perform transfer eval on every pair of datasets.""" @@ -52,25 +45,13 @@ class Sweep: name: str | None = None # A bit of a hack to add all the command line arguments from Elicit - run_template: Elicit = Elicit( - data=Extract( - model="", - datasets=("",), - ) - ) + run_template: Elicit = field(default_factory=Elicit.default) def __post_init__(self, add_pooled: bool): if not self.datasets: raise ValueError("No datasets specified") if not self.models: raise ValueError("No models specified") - # can only use hparam_step if we're using an eigen net - if self.hparam_step > 0 and not isinstance( - self.run_template.net, EigenFitterConfig - ): - raise ValueError("Can only use hparam_step with EigenFitterConfig") - elif self.hparam_step > 1: - raise ValueError("hparam_step must be in [0, 1]") # Check for the magic dataset "burns" which is a shortcut for all of the # datasets used in Burns et al., except Story Cloze, which is not available @@ -115,9 +96,6 @@ def execute(self): } ) - step = self.hparam_step - weights = np.arange(0.0, 1.0 + step, step) if step > 0 else [None] - for i, model in enumerate(self.models): print(colorize(f"===== {model} ({i + 1} of {M}) =====", "magenta")) @@ -127,52 +105,38 @@ def execute(self): # single sweep. 
train_datasets = tuple(ds.strip() for ds in dataset_str.split("+")) - for var_weight in weights: - for neg_cov_weight in weights: - out_dir = sweep_dir / model / dataset_str + out_dir = sweep_dir / model / dataset_str - data = replace( - self.run_template.data, model=model, datasets=train_datasets - ) - run = replace(self.run_template, data=data, out_dir=out_dir) - if var_weight is not None and neg_cov_weight is not None: - assert isinstance(run.net, EigenFitterConfig) - run.net.var_weight = var_weight - run.net.neg_cov_weight = neg_cov_weight - - # Add hyperparameter values to output directory if needed - assert run.out_dir is not None - run.out_dir /= f"var_weight={var_weight:.2f}" - run.out_dir /= f"neg_cov_weight={neg_cov_weight:.2f}" - - try: - run.execute() - except torch.linalg.LinAlgError as e: - print(colorize(f"LinAlgError: {e}", "red")) + data = replace( + self.run_template.data, model=model, datasets=train_datasets + ) + run = replace(self.run_template, data=data, out_dir=out_dir) + + # Add hyperparameter values to output directory if needed + assert run.out_dir is not None + + run.execute() + + if not self.skip_transfer_eval: + if len(eval_datasets) > 1: + print(colorize("== Transfer eval ==", "green")) + + # Now evaluate the reporter on the other datasets + for eval_dataset in eval_datasets: + # We already evaluated on this one during training + if eval_dataset in train_datasets: continue - if not self.skip_transfer_eval: - if len(eval_datasets) > 1: - print(colorize("== Transfer eval ==", "green")) - - # Now evaluate the reporter on the other datasets - for eval_dataset in eval_datasets: - # We already evaluated on this one during training - if eval_dataset in train_datasets: - continue - - assert run.out_dir is not None - eval = Eval( - data=replace( - run.data, model=model, datasets=(eval_dataset,) - ), - source=run.out_dir, - out_dir=run.out_dir / "transfer" / eval_dataset, - num_gpus=run.num_gpus, - min_gpu_mem=run.min_gpu_mem, - skip_supervised=run.supervised == "none", - ) - eval.execute(highlight_color="green") + assert run.out_dir is not None + eval = Eval( + data=replace( + run.data, model=model, datasets=(eval_dataset,) + ), + source=run.out_dir, + out_dir=run.out_dir / "transfer" / eval_dataset, + num_gpus=run.num_gpus, + ) + eval.execute(highlight_color="green") if self.visualize: visualize_sweep(sweep_dir) diff --git a/elk/training/train.py b/elk/training/train.py index 938fcf79e..baa21991f 100644 --- a/elk/training/train.py +++ b/elk/training/train.py @@ -7,181 +7,145 @@ import pandas as pd import torch -from einops import rearrange, repeat -from simple_parsing import subgroups -from simple_parsing.helpers.serialization import save -from ..metrics import evaluate_preds, to_one_hot +from ..extraction import Extract +from ..metrics import evaluate_preds, get_logprobs from ..run import Run from ..training.supervised import train_supervised from ..utils.typing import assert_type -from .ccs_reporter import CcsConfig, CcsReporter -from .common import FitterConfig -from .eigen_reporter import EigenFitter, EigenFitterConfig @dataclass class Elicit(Run): """Full specification of a reporter training run.""" - net: FitterConfig = subgroups( - {"ccs": CcsConfig, "eigen": EigenFitterConfig}, default="eigen" # type: ignore - ) - """Config for building the reporter network.""" + seed: int = 42 - supervised: Literal["none", "single", "inlp", "cv"] = "single" + supervised: Literal["single", "inlp", "cv"] = "single" """Whether to train a supervised classifier, and if so, whether 
to use cross-validation. Defaults to "single", which means to train a single classifier on the training data. "cv" means to use cross-validation.""" + erase_paraphrases: bool = False + """Whether to use LEACE to erase the paraphrase dimensions before training the + classifier.""" + + max_inlp_iter: int | None = None + """Maximum number of iterations for Iterative Nullspace Projection (INLP).""" + + @staticmethod + def default(): + return Elicit( + data=Extract( + model="", + datasets=("",), + ) + ) + def create_models_dir(self, out_dir: Path): - lr_dir = None lr_dir = out_dir / "lr_models" - reporter_dir = out_dir / "reporters" lr_dir.mkdir(parents=True, exist_ok=True) - reporter_dir.mkdir(parents=True, exist_ok=True) - # Save the reporter config separately in the reporter directory - # for convenient loading of reporters later. - save(self.net, reporter_dir / "cfg.yaml", save_dc_types=True) - - return reporter_dir, lr_dir + return lr_dir def apply_to_layer( self, layer: int, devices: list[str], world_size: int, - ) -> dict[str, pd.DataFrame]: + ) -> tuple[dict[str, pd.DataFrame], dict]: """Train a single reporter on a single layer.""" - self.make_reproducible(seed=self.net.seed + layer) + self.make_reproducible(seed=self.seed + layer) device = self.get_device(devices, world_size) train_dict = self.prepare_data(device, layer, "train") val_dict = self.prepare_data(device, layer, "val") - (first_train_h, train_gt, _), *rest = train_dict.values() - (_, v, k, d) = first_train_h.shape - if not all(other_h.shape[-1] == d for other_h, _, _ in rest): + first_train_data, *rest = train_dict.values() + (_, v, d) = first_train_data.hiddens.shape + if not all(other_data.hiddens.shape[-1] == d for other_data in rest): raise ValueError("All datasets must have the same hidden state size") - # For a while we did support datasets with different numbers of classes, but - # we reverted this once we switched to ConceptEraser. There are a few options - # for re-enabling it in the future but they are somewhat complex and it's not - # clear that it's worth it. 
- if not all(other_h.shape[-2] == k for other_h, _, _ in rest): - raise ValueError("All datasets must have the same number of classes") - - reporter_dir, lr_dir = self.create_models_dir(assert_type(Path, self.out_dir)) - train_loss = None - - if isinstance(self.net, CcsConfig): - assert len(train_dict) == 1, "CCS only supports single-task training" - - reporter = CcsReporter(self.net, d, device=device, num_variants=v) - train_loss = reporter.fit(first_train_h) - - (_, v, k, _) = first_train_h.shape - reporter.platt_scale( - to_one_hot(repeat(train_gt, "n -> (n v)", v=v), k).flatten(), - rearrange(first_train_h, "n v k d -> (n v k) d"), - ) - - elif isinstance(self.net, EigenFitterConfig): - fitter = EigenFitter( - self.net, d, num_classes=k, num_variants=v, device=device - ) - - hidden_list, label_list = [], [] - for ds_name, (train_h, train_gt, _) in train_dict.items(): - (_, v, _, _) = train_h.shape - - # Datasets can have different numbers of variants, so we need to - # flatten them here before concatenating - hidden_list.append(rearrange(train_h, "n v k d -> (n v k) d")) - label_list.append( - to_one_hot(repeat(train_gt, "n -> (n v)", v=v), k).flatten() - ) - fitter.update(train_h) - - reporter = fitter.fit_streaming() - reporter.platt_scale( - torch.cat(label_list), - torch.cat(hidden_list), - ) - else: - raise ValueError(f"Unknown reporter config type: {type(self.net)}") - - # Save reporter checkpoint to disk - torch.save(reporter, reporter_dir / f"layer_{layer}.pt") + lr_dir = self.create_models_dir(assert_type(Path, self.out_dir)) # Fit supervised logistic regression model - if self.supervised != "none": - lr_models = train_supervised( - train_dict, - device=device, - mode=self.supervised, - ) - with open(lr_dir / f"layer_{layer}.pt", "wb") as file: - torch.save(lr_models, file) - else: - lr_models = [] + lr_models = train_supervised( + train_dict, + erase_paraphrases=self.erase_paraphrases, + device=device, + mode=self.supervised, + max_inlp_iter=self.max_inlp_iter, + ) + with open(lr_dir / f"layer_{layer}.pt", "wb") as file: + torch.save(lr_models, file) + + out_logprobs = defaultdict(dict) row_bufs = defaultdict(list) for ds_name in val_dict: - val_h, val_gt, val_lm_preds = val_dict[ds_name] - train_h, train_gt, train_lm_preds = train_dict[ds_name] + val, train = val_dict[ds_name], train_dict[ds_name] meta = {"dataset": ds_name, "layer": layer} - val_credences = reporter(val_h) - train_credences = reporter(train_h) - for mode in ("none", "partial", "full"): - row_bufs["eval"].append( - { - **meta, - "ensembling": mode, - **evaluate_preds(val_gt, val_credences, mode).to_dict(), - "train_loss": train_loss, - } + if self.save_logprobs: + out_logprobs[ds_name] = dict( + row_ids=val.row_ids.cpu(), + variant_ids=val.variant_ids, + texts=val.texts, + labels=val.labels.cpu(), + lm=dict(), + lr=dict(), ) - row_bufs["train_eval"].append( - { - **meta, - "ensembling": mode, - **evaluate_preds(train_gt, train_credences, mode).to_dict(), - "train_loss": train_loss, - } - ) + for mode in ("none", "full"): + if val.lm_log_odds is not None: + if self.save_logprobs: + out_logprobs[ds_name]["lm"][mode] = ( + get_logprobs(val.lm_log_odds, mode).detach().cpu() + ) - if val_lm_preds is not None: row_bufs["lm_eval"].append( { **meta, "ensembling": mode, - **evaluate_preds(val_gt, val_lm_preds, mode).to_dict(), + **evaluate_preds( + val.labels, val.lm_log_odds, mode + ).to_dict(), } ) - if train_lm_preds is not None: - row_bufs["train_lm_eval"].append( + if self.save_logprobs: + 
out_logprobs[ds_name]["lr"][mode] = dict() + + for i, model in enumerate(lr_models): + model.eval() + val_log_odds = model(val.hiddens) + train_log_odds = model(train.hiddens) + + if self.save_logprobs: + out_logprobs[ds_name]["lr"][mode][i] = ( + get_logprobs(val_log_odds, mode).detach().cpu() + ) + + row_bufs["lr_eval"].append( { **meta, "ensembling": mode, - **evaluate_preds(train_gt, train_lm_preds, mode).to_dict(), + "inlp_iter": i, + **evaluate_preds(val.labels, val_log_odds, mode).to_dict(), } ) - for i, model in enumerate(lr_models): - row_bufs["lr_eval"].append( + row_bufs["train_lr_eval"].append( { **meta, "ensembling": mode, "inlp_iter": i, - **evaluate_preds(val_gt, model(val_h), mode).to_dict(), + **evaluate_preds( + train.labels, train_log_odds, mode + ).to_dict(), } ) - return {k: pd.DataFrame(v) for k, v in row_bufs.items()} + return {k: pd.DataFrame(v) for k, v in row_bufs.items()}, out_logprobs diff --git a/elk/truncated_eigh.py b/elk/truncated_eigh.py deleted file mode 100644 index fe2298375..000000000 --- a/elk/truncated_eigh.py +++ /dev/null @@ -1,228 +0,0 @@ -from typing import Literal, NamedTuple, Optional - -import torch -import torch.nn.functional as F -from torch import Tensor - - -class ConvergenceError(Exception): - """Raised when the Lanczos iteration fails to converge.""" - - -class Eigendecomposition(NamedTuple): - """A namedtuple containing eigenpairs of a matrix.""" - - eigenvalues: Tensor - eigenvectors: Tensor - - -def truncated_eigh( - A: Tensor, - k: int = 1, - *, - max_iter: Optional[int] = None, - ncv: Optional[int] = None, - tol: float = 1e-3, - seed: Optional[int] = None, - which: Literal["LA", "SA"] = "LA", - verbose: bool = False, -) -> Eigendecomposition: - """Compute the leading `k` eigenpairs of `A` with the thick-restart Lanczos method. - - Algorithm proposed by Wu & Simon (1998) https://www.osti.gov/servlets/purl/895499. - For matrices 256 x 256 or smaller, we short-circuit to the naive method of calling - `torch.linalg.eigh` and discarding all but the requested number of eigenpairs. - Empirically this is faster than our Lanczos implementation for such small matrices. - - Args: - A (Tensor): The matrix or batch of matrices of shape `[..., n, n]` for which to - compute eigenpairs. Must be symmetric, but need not be positive definite. - k (int): The number of eigenpairs to compute. - max_iter (int, optional): The maximum number of iterations to perform. - ncv (int, optional): The number of Lanczos vectors generated. Must be - greater than k and smaller than n - 1. - tol (float, optional): The tolerance for the residual. Defaults to the machine - precision of `A.dtype` or 1e-4, whichever is larger. - seed (int, optional): The random seed to use for the starting vector. - which (str, optional): Which k eigenvalues and eigenvectors to compute. - Must be one of 'LA', or 'SA'. - 'LA': compute the k largest (algebraic) eigenvalues. - 'SA': compute the k smallest (algebraic) eigenvalues. - verbose (bool, optional): Whether to print progress information. - - Returns: - (Tensor, Tensor): A tuple containing the eigenvalues and eigenvectors. - - Raises: - ConvergenceError: If the Lanczos iteration fails to converge. - """ - *leading, n, m = A.shape - assert n == m, "A must be a square matrix or a batch of square matrices." - - # Short circuit if the matrix is too small or if we're asked for too many - # eigenpairs; we can't outcompete the naive method. 
- if k > 10 or n <= 256: - L, Q = torch.linalg.eigh(A) - if which == "LA": - return Eigendecomposition(L[..., -k:], Q[..., :, -k:]) - elif which == "SA": - return Eigendecomposition(L[..., :k], Q[..., :, :k]) - - if ncv is None: - # This is the default used by SciPy; CuPy uses min(n - 1, max(2 * k, k + 32)). - # Empirically the SciPy default seems to converge better. - ncv = min(n, max(2 * k + 1, 20)) - else: - ncv = min(max(ncv, k + 2), n - 1) - - if max_iter is None: - max_iter = 10 * n - - # Diagonal and off-diagonal elements of the tridiagonal matrix - alpha = A.new_zeros([*leading, ncv]) - beta = A.new_zeros([*leading, ncv]) - - # Lanczos vector basis for the Krylov subspace - Q = A.new_empty([*leading, ncv, n]) - - # Initialize Lanczos vector - rng = torch.Generator(A.device) - if seed is not None: - rng.manual_seed(seed) - - r_k = torch.randn(*leading, n, dtype=A.dtype, device=A.device, generator=rng) - - Q[..., 0, :] = F.normalize(r_k, dim=-1) - _lanczos_inner_loop(A, Q, r_k, alpha, beta, 0, ncv) - - # Compute the Ritz vectors and values - cur_iter = ncv - w, s = _solve_ritz_pairs(alpha, beta, None, k, which) - x = torch.einsum("...ij,...ik->...jk", Q, s) - - # Compute the residual. Note that we take the max over the batch dimensions, - # to ensure that we don't terminate early for any element in the batch. - beta_k = beta[..., -1, None] * s[..., -1, :] - first_res = res = beta_k.norm(dim=-1).max() - - # Keep restarting until we converge or hit the iteration limit - while res > tol and cur_iter < max_iter: - # Setup for thick-restart - alpha[..., :k] = w - beta[..., :k] = 0 - Q[..., :k, :] = x.mT - - # Compute the next Lanczos vector - _gram_schmidt(r_k, Q[..., :k, :]) - Q[..., k, :] = F.normalize(r_k, dim=-1) - - r_k[:] = torch.einsum("...ij,...j->...i", A, Q[..., k, :]) - alpha[..., k] = torch.einsum("...i,...i->...", Q[..., k, :], r_k) - _gram_schmidt(r_k, Q[..., : k + 1, :]) - - beta[..., k] = r_k.square().sum(dim=-1).sqrt() # TorchScript-friendly norm - Q[..., k + 1, :] = r_k / beta[..., k, None] - - # Inner loop - _lanczos_inner_loop(A, Q, r_k, alpha, beta, k + 1, ncv) - - w, s = _solve_ritz_pairs(alpha, beta, beta_k, k, which) - x = Q.mT @ s - - # Compute the residual - beta_k = beta[..., -1, None] * s[..., -1, :] - new_res = beta_k.square().sum(dim=-1).sqrt().max() # TorchScript-friendly norm - cur_iter += ncv - k - - # Check for divergence. This may happen in edge cases where our _gram_schmidt - # is not run for a sufficient number of iterations. Most implementations would - # use a dynamic number of iterations, but we don't have a good way to do that - # in TorchScript, so we use a fixed number (2). - if new_res > 2 * first_res: - break - else: - res = new_res - - if verbose: - print(f"Residual: {res} after {cur_iter} iterations.") - - if res > tol: - raise ConvergenceError( - f"Failed to converge after {cur_iter} iterations. " - f"Residual: {res}, initial residual: {first_res}." - ) - - # We use the torch.autocast decorator above to speed up the algorithm, but - # make sure the returned values are in the same dtype as the input. - return Eigendecomposition(w.type_as(A), x.type_as(A)) - - -@torch.jit.script -def _solve_ritz_pairs(diag, off_diag, beta_k: Optional[Tensor], k: int, which: str): - """Solve the standard eigenvalue problem for the Ritz values and vectors. - - Args: - diag (Tensor): The diagonal elements of the tridiagonal matrix. - off_diag (Tensor): The off-diagonal elements of the tridiagonal matrix. - beta_k (Tensor, optional): ??? 
- k (int): The number of eigenpairs to compute. - which (str): Which k eigenvalues and eigenvectors to compute. - Must be one of 'LA', or 'SA'. - """ - # Create tri-diagonal matrix - t = diag.diag_embed() - t += off_diag[..., :-1].diag_embed(1) - t += off_diag[..., :-1].diag_embed(-1) - - if beta_k is not None: - t[..., k, :k] = beta_k - t[..., :k, k] = beta_k - - # The eigenpairs are already sorted ascending by algebraic value - w, s = torch.linalg.eigh(t) - - # Pick-up k ritz-values and ritz-vectors - if which == "LA": - wk = w[..., -k:] - sk = s[..., -k:] - elif which == "SA": - wk = w[..., :k] - sk = s[..., :k] - else: - raise ValueError("`which` must be LA or SA") - - return wk, sk - - -@torch.jit.script -def _gram_schmidt(z: Tensor, Q: Tensor, num_iter: int = 2): - """Iteratively make vector `z` orthogonal to the semi-orthogonal basis `Q`. - - Wu & Simon (1998) define "semi-orthogonal" to mean that the largest off-diagonal - element of the Gram matrix is no greater than sqrt(machine epsilon). They prove - that if `Q` is semi-orthogonal, then applying Gram-Schmidt to `z` will make it - closer to being orthogonal to the span of `Q`. Multiple iterations of this process - may be necessary to achieve orthogonality up to machine precision. See pages 12-14 - of Wu & Simon (1998) for details. - """ - for _ in range(num_iter): - proj = torch.einsum("...ij,...j->...i", Q.conj(), z) - z -= torch.einsum("...ij,...i->...j", Q, proj) - - -@torch.jit.script -def _lanczos_inner_loop(A, krylov, q, alpha, beta, k: int, end: int): - """Step 2 of Algorithm 3 in Wu & Simon (1998).""" - - for i in range(k, end): - # Compute the next matrix-vector product Au - q[:] = torch.einsum("...ij,...j->...i", A, krylov[..., i, :]) - alpha[..., i] = torch.einsum("...i,...i->...", krylov[..., i, :], q) - - # Project away from the current Krylov subspace - _gram_schmidt(q, krylov[..., : i + 1, :]) - - # Record how much is left after projection - beta[..., i] = q.square().sum(dim=-1).sqrt() # TorchScript-friendly norm - if i < end - 1: - krylov[..., i + 1, :] = q / beta[..., i, None] diff --git a/elk/utils/gpu_utils.py b/elk/utils/gpu_utils.py index a42942980..2999aaacb 100644 --- a/elk/utils/gpu_utils.py +++ b/elk/utils/gpu_utils.py @@ -19,7 +19,9 @@ # this assumption, but we never do that. @cache def select_usable_devices( - num_gpus: int = -1, *, min_memory: int | None = None + num_gpus: int = -1, + *, + min_memory: int | None = None, ) -> list[str]: """Select a set of devices that have at least `min_memory` bytes of free memory. Blocks until at least `num_gpus` devices are available. 
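
Aside (usage sketch, not part of the change): this mirrors how Run.execute consumes the reflowed signature above; the argument values are placeholders.

from elk.utils.gpu_utils import select_usable_devices

# num_gpus=-1 means "every usable GPU"; the call blocks until that many devices
# report at least min_memory bytes free (Run now passes min_gpu_mem, default 0).
devices = select_usable_devices(-1, min_memory=0)
print(devices)  # e.g. ["cuda:0", "cuda:1"]
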
diff --git a/tests/test_eigen_reporter.py b/tests/test_eigen_reporter.py deleted file mode 100644 index 6303cb039..000000000 --- a/tests/test_eigen_reporter.py +++ /dev/null @@ -1,47 +0,0 @@ -import torch - -from elk.training import EigenFitter, EigenFitterConfig -from elk.utils import batch_cov, cov_mean_fused - - -def test_eigen_reporter(): - num_clusters = 5 - hidden_size = 10 - N = 100 - - x = torch.randn(N, num_clusters, 2, hidden_size, dtype=torch.float64) - x1, x2 = x.chunk(2, dim=0) - x_neg, x_pos = x.unbind(2) - - reporter = EigenFitter( - EigenFitterConfig(), - hidden_size, - dtype=torch.float64, - num_variants=num_clusters, - ) - reporter.update(x1) - reporter.update(x2) - - # Check that the streaming mean is correct - neg_mu, pos_mu = x_neg.mean(dim=(0, 1)), x_pos.mean(dim=(0, 1)) - - assert reporter.class_means is not None - torch.testing.assert_close(reporter.class_means[0], neg_mu) - torch.testing.assert_close(reporter.class_means[1], pos_mu) - - # Check that the streaming covariance is correct - neg_centroids, pos_centroids = x_neg.mean(dim=1), x_pos.mean(dim=1) - true_cov = 0.5 * (batch_cov(neg_centroids) + batch_cov(pos_centroids)) - torch.testing.assert_close(reporter.intercluster_cov, true_cov) - - # Check that the streaming negative covariance is correct - true_xcov = (neg_centroids - neg_mu).mT @ (pos_centroids - pos_mu) / N - true_xcov = 0.5 * (true_xcov + true_xcov.mT) - torch.testing.assert_close(reporter.contrastive_xcov, true_xcov) - - # Check that the streaming invariance (intra-cluster variance) is correct. - # This is actually the same whether or not we track class means. - expected_invariance = 0.5 * (cov_mean_fused(x_neg) + cov_mean_fused(x_pos)) - torch.testing.assert_close(reporter.intracluster_cov, expected_invariance) - - assert reporter.n == N diff --git a/tests/test_encodings.py b/tests/test_encodings.py new file mode 100644 index 000000000..790c00658 --- /dev/null +++ b/tests/test_encodings.py @@ -0,0 +1,55 @@ +from datasets import load_dataset +from transformers import AutoTokenizer + +from elk.extraction import Extract, tokenize_dataset + + +def test_get_encodings(): + dataset_name = "imdb" + model_path = "sshleifer/tiny-gpt2" + + seed = 42 + cfg = Extract( + model=model_path, + datasets=(dataset_name,), + max_examples=(10, 10), + template_path="_default", + get_lm_preds=True, + statement_column="text", + balance=False, + seed=seed, + ) + split_type = "train" + encodings = tokenize_dataset(cfg, split_type) + + tokenizer = AutoTokenizer.from_pretrained(model_path, truncation_side="left") + ds = load_dataset(dataset_name, split=split_type) + ds = ds.add_column("row_id", range(len(ds))) # type: ignore + ds = ds.shuffle(seed=seed).select(range(10)) # type: ignore + + suffix = '\n\n\nQ: Is the above statement "True" or "False"?\n\nA:' + suffix_tokens = tokenizer.encode(suffix, add_special_tokens=False) + + def map_fn(ex: dict) -> dict: + out_record = { + "row_id": ex["row_id"], + "label": ex["label"], + "variant_id": "_default", + "text": ex["text"] + suffix, + "num_suffix_tokens": len(suffix_tokens), + } + input_ids = tokenizer(ex["text"], add_special_tokens=True)["input_ids"] + out_record["input_ids"] = [input_ids + suffix_tokens] # type: ignore + answer_ids = [ + tokenizer.encode(s, add_special_tokens=False)[0] for s in ["False", "True"] + ] + out_record["answer_ids"] = answer_ids + return out_record + + ds = ds.map(map_fn, batched=False, remove_columns=ds.column_names, num_proc=1) + gt_ds = ds.filter(lambda ex: len(ex["input_ids"]) <= 
tokenizer.model_max_length) + + assert len(encodings) == len(gt_ds) + assert set(encodings.column_names) == set(gt_ds.column_names) + for col in encodings.column_names: + assert encodings[col] == gt_ds[col] diff --git a/tests/test_inference_server.py b/tests/test_inference_server.py new file mode 100644 index 000000000..b02acdfef --- /dev/null +++ b/tests/test_inference_server.py @@ -0,0 +1,53 @@ +import pytest +import torch +from datasets import Dataset +from transformers import AutoModelForCausalLM, AutoTokenizer + +from elk.extraction.inference_server import InferenceServer + + +@pytest.mark.gpu +@pytest.mark.filterwarnings("ignore:Unable to find a decoding function") +def test_inference_server(): + model_name = "EleutherAI/pythia-70m" + tokenizer = AutoTokenizer.from_pretrained(model_name) + ds = Dataset.from_dict({"text": ["Lorem", "ipsum", "dolor", "sit", "amet"]}) + + ds = ds.map( + lambda x: tokenizer(x["text"], padding="max_length", truncation=True), + batched=True, + remove_columns=["text"], + ) + + device = "cuda" + model = AutoModelForCausalLM.from_pretrained(model_name).to(device) + with ds.formatted_as("torch"): + + def gt_inference(ex: dict): + with torch.no_grad(): + ex = {k: v.to(device).unsqueeze(0) for k, v in ex.items()} + return {"out": model(**ex)} + + gt_out_ds = ds.map(gt_inference, batched=False, remove_columns=ds.column_names) + gt_outs = gt_out_ds["out"] + + def test_config(fsdp: bool, num_workers: int, cpu_offload: bool = True): + with InferenceServer( + model_str=model_name, + fsdp=fsdp, + num_workers=num_workers, + cpu_offload=cpu_offload, + ) as server: + outs = server.map_forward(ds) + assert len(outs) == len(gt_outs) + for out, gt_out in zip(outs, gt_outs): + out_logits = out["logits"] + assert torch.allclose(out_logits, gt_out["logits"].cpu()) + + test_config(fsdp=False, num_workers=-1) + test_config(fsdp=False, num_workers=1) + test_config(fsdp=False, num_workers=2) + test_config(fsdp=True, num_workers=-1, cpu_offload=False) + test_config(fsdp=True, num_workers=-1) + test_config(fsdp=True, num_workers=1) + test_config(fsdp=True, num_workers=2) diff --git a/tests/test_samplers.py b/tests/test_samplers.py index 9a1f694e2..44d8bad52 100644 --- a/tests/test_samplers.py +++ b/tests/test_samplers.py @@ -14,7 +14,7 @@ def test_output_batches_are_balanced(): IterableDataset, load_dataset("super_glue", "boolq", split="train", streaming=True), ) - label_col = infer_label_column(dataset.features) + label_col = infer_label_column(dataset.features) # type: ignore # Start with an even number of shots; make sure they're exactly balanced sampler = FewShotSampler(dataset, 6, rng=Random(42)) @@ -40,7 +40,7 @@ def test_output_is_roughly_balanced(): load_dataset("super_glue", "boolq", split="train", streaming=True), ) - col = infer_label_column(dataset.features) + col = infer_label_column(dataset.features) # type: ignore reservoir = BalancedSampler(dataset, {0, 1}) # Count the number of samples for each label diff --git a/tests/test_smoke_elicit.py b/tests/test_smoke_elicit.py index bac0f3989..b22913b25 100644 --- a/tests/test_smoke_elicit.py +++ b/tests/test_smoke_elicit.py @@ -1,56 +1,25 @@ from pathlib import Path +import pytest + from elk import Extract -from elk.training import CcsConfig, EigenFitterConfig from elk.training.train import Elicit -def test_smoke_elicit_run_tiny_gpt2_ccs(tmp_path: Path): - # we need about 5 mb of gpu memory to run this test - model_path, min_mem = "sshleifer/tiny-gpt2", 10 * 1024**2 - dataset_name = "imdb" - elicit = Elicit( - data=Extract( 
- model=model_path, - datasets=(dataset_name,), - max_examples=(10, 10), - # run on all layers, tiny-gpt only has 2 layers - ), - num_gpus=2, - min_gpu_mem=min_mem, - net=CcsConfig(), - out_dir=tmp_path, - ) - elicit.execute() - # get the files in the tmp_path - files: list[Path] = list(tmp_path.iterdir()) - created_file_names = {file.name for file in files} - expected_files = [ - "cfg.yaml", - "fingerprints.yaml", - "lr_models", - "reporters", - "eval.csv", - ] - for file in expected_files: - assert file in created_file_names - - -def test_smoke_elicit_run_tiny_gpt2_eigen(tmp_path: Path): +@pytest.mark.gpu +def test_smoke_elicit_run_tiny_gpt2(tmp_path: Path): # we need about 5 mb of gpu memory to run this test - model_path, min_mem = "sshleifer/tiny-gpt2", 10 * 1024**2 + model_path = "sshleifer/tiny-gpt2" dataset_name = "imdb" elicit = Elicit( data=Extract( model=model_path, datasets=(dataset_name,), max_examples=(10, 10), - # run on all layers, tiny-gpt only has 2 layers ), num_gpus=2, - min_gpu_mem=min_mem, - net=EigenFitterConfig(), out_dir=tmp_path, + min_gpu_mem=5_000_000, ) elicit.execute() # get the files in the tmp_path @@ -60,8 +29,6 @@ def test_smoke_elicit_run_tiny_gpt2_eigen(tmp_path: Path): "cfg.yaml", "fingerprints.yaml", "lr_models", - "reporters", - "eval.csv", ] for file in expected_files: assert file in created_file_names diff --git a/tests/test_smoke_eval.py b/tests/test_smoke_eval.py index 4efd7112d..40cd17fa7 100644 --- a/tests/test_smoke_eval.py +++ b/tests/test_smoke_eval.py @@ -1,16 +1,19 @@ from pathlib import Path +import numpy as np import pandas as pd +import pytest +import torch +from sklearn.metrics import roc_auc_score from elk import Extract from elk.evaluation import Eval -from elk.training import CcsConfig, EigenFitterConfig from elk.training.train import Elicit EVAL_EXPECTED_FILES = [ "cfg.yaml", "fingerprints.yaml", - "eval.csv", + "lr_eval.csv", ] @@ -19,8 +22,6 @@ def setup_elicit( tmp_path: Path, dataset_name="imdb", model_path="sshleifer/tiny-gpt2", - min_mem=10 * 1024**2, - is_ccs: bool = True, ) -> Elicit: """Setup elicit config for testing, execute elicit, and save output to tmp_path. Returns the elicit run configuration. @@ -30,12 +31,10 @@ def setup_elicit( model=model_path, datasets=(dataset_name,), max_examples=(10, 10), - # run on all layers, tiny-gpt only has 2 layers ), - num_gpus=2, - min_gpu_mem=min_mem, - net=CcsConfig() if is_ccs else EigenFitterConfig(), + num_gpus=1, out_dir=tmp_path, + min_gpu_mem=5_000_000, ) elicit.execute() return elicit @@ -58,18 +57,18 @@ def eval_run(elicit: Elicit, transfer_datasets: tuple[str, ...] = ()) -> float: assert tmp_path is not None # record elicit modification time as reference. - start_time_sec = (tmp_path / "eval.csv").stat().st_mtime + start_time_sec = (tmp_path / "lr_eval.csv").stat().st_mtime if transfer_datasets: # update datasets to a different dataset extract.datasets = transfer_datasets - eval = Eval(data=extract, source=tmp_path) + eval = Eval(data=extract, source=tmp_path, save_logprobs=True) eval.execute() return start_time_sec -def eval_assert_files_created(elicit: Elicit, transfer_datasets: tuple[str, ...] = ()): +def eval_assert_files_good(elicit: Elicit, transfer_datasets: tuple[str, ...] = ()): tmp_path = elicit.out_dir assert tmp_path is not None @@ -77,34 +76,46 @@ def eval_assert_files_created(elicit: Elicit, transfer_datasets: tuple[str, ...] 
     assert eval_dir.exists(), f"transfer eval dir {eval_dir} does not exist"
     check_contains_files(eval_dir, EVAL_EXPECTED_FILES)
     # read "eval.csv" into a df
-    df = pd.read_csv(eval_dir / "eval.csv")
+    df = pd.read_csv(eval_dir / "lr_eval.csv")
     # get the "dataset" column
     dataset_col = df["dataset"]
 
+    logprobs_dict = torch.load(eval_dir / "logprobs.pt")
+
     for tfr_dataset in transfer_datasets:
         # assert that the dataset column contains the transfer dataset
         assert tfr_dataset in dataset_col.values
+        assert tfr_dataset in logprobs_dict
 
-"""TESTS"""
+        # make sure that auroc computed from logprobs matches the auroc in lr_eval.csv
+        logprobs = logprobs_dict[tfr_dataset]
+        for layer in df["layer"].unique():
+            mode = "full"
+            current_df = df[
+                (df["dataset"] == tfr_dataset)
+                & (df["layer"] == layer)
+                & (df["ensembling"] == mode)
+                & (df["inlp_iter"] == 0)
+            ]
+            assert len(current_df) == 1
+            eval_auroc = current_df["auroc_estimate"].iloc[0]
+            # get the logprobs for the current layer and mode
+            lr_logprobs = logprobs["lr"][layer][mode][0]
+            labels = logprobs["labels"]
 
-def test_smoke_tfr_eval_run_tiny_gpt2_ccs(tmp_path: Path):
-    elicit = setup_elicit(tmp_path)
-    transfer_datasets = ("christykoh/imdb_pt",)
-    eval_run(elicit, transfer_datasets=transfer_datasets)
-    eval_assert_files_created(elicit, transfer_datasets=transfer_datasets)
+            auroc = roc_auc_score(labels, lr_logprobs)
+            np.testing.assert_almost_equal(auroc, eval_auroc)
 
 
-def test_smoke_eval_run_tiny_gpt2_eigen(tmp_path: Path):
-    elicit = setup_elicit(tmp_path, is_ccs=False)
-    transfer_datasets = ("christykoh/imdb_pt",)
-    eval_run(elicit, transfer_datasets=transfer_datasets)
-    eval_assert_files_created(elicit, transfer_datasets=transfer_datasets)
+"""TESTS"""
 
 
-def test_smoke_multi_eval_run_tiny_gpt2_ccs(tmp_path: Path):
+
+@pytest.mark.gpu
+def test_smoke_eval_run_tiny_gpt2(tmp_path: Path):
     elicit = setup_elicit(tmp_path)
-    transfer_datasets = ("christykoh/imdb_pt", "super_glue:boolq")
+    transfer_datasets = ("christykoh/imdb_pt",)
     eval_run(elicit, transfer_datasets=transfer_datasets)
-    eval_assert_files_created(elicit, transfer_datasets=transfer_datasets)
+    eval_assert_files_good(elicit, transfer_datasets=transfer_datasets)
diff --git a/tests/test_truncated_eigh.py b/tests/test_truncated_eigh.py
deleted file mode 100644
index 5241f1c0d..000000000
--- a/tests/test_truncated_eigh.py
+++ /dev/null
@@ -1,58 +0,0 @@
-import numpy as np
-import pytest
-import torch
-from scipy.sparse.linalg import eigsh
-
-from elk.truncated_eigh import truncated_eigh
-
-
-def random_symmetric_matrix(n: int, k: int) -> torch.Tensor:
-    """Random symmetric matrix with `k` nonzero eigenvalues centered around zero."""
-    assert k <= n, "Rank k should be less than or equal to the matrix size n."
-
-    # Generate random n x k matrix A with elements drawn from a uniform distribution
-    A = torch.rand(n, k) / k**0.5
-
-    # Create a diagonal matrix D with k eigenvalues evenly distributed around zero
-    eigenvalues = torch.linspace(-1, 1, k)
-    D = torch.diag(eigenvalues)
-
-    # Compute the product A * D * A.T to obtain a symmetric matrix with the desired
-    # eigenvalue distribution
-    symm_matrix = A @ D @ A.T
-
-    return symm_matrix
-
-
-@pytest.mark.parametrize("n", [32, 768, 6144])
-@pytest.mark.parametrize("full_rank", [False, True])
-@pytest.mark.parametrize("which", ["LA", "SA"])
-def test_truncated_eigh(n: int, full_rank: bool, which):
-    torch.manual_seed(42)
-
-    if full_rank:
-        A = torch.randn(n, n)
-    else:
-        # Generate a random symmetric matrix with rank n // 2
-        A = random_symmetric_matrix(n, n // 2)
-
-    A = A + A.T
-
-    # Compute the top k eigenpairs using our implementation
-    w, v = truncated_eigh(A, k=6, which=which, tol=1e-5)
-
-    # Compute the top k eigenpairs using scipy
-    w_scipy, v_scipy = eigsh(A.numpy(), which=which)
-
-    # Check that the eigenvalues match to within the tolerance
-    torch.testing.assert_close(w, torch.from_numpy(w_scipy), atol=1e-3, rtol=1e-3)
-
-    # Normalize the sign of the eigenvectors
-    for i in range(v.shape[-1]):
-        if v[torch.argmax(torch.abs(v[:, i])), i] < 0:
-            v[:, i] *= -1
-        if v_scipy[np.argmax(np.abs(v_scipy[:, i])), i] < 0:
-            v_scipy[:, i] *= -1
-
-    # Check that the eigenvectors match to within the tolerance
-    torch.testing.assert_close(v, torch.from_numpy(v_scipy), atol=1e-3, rtol=1e-3)
diff --git a/tests/test_viz.py b/tests/test_viz.py
deleted file mode 100644
index fe3214b01..000000000
--- a/tests/test_viz.py
+++ /dev/null
@@ -1,31 +0,0 @@
-from pathlib import Path
-
-import pytest
-
-from elk.plotting.visualize import SweepVisualization
-
-
-@pytest.fixture
-def setup_fs(fs):
-    test_dir = "/sweep1"
-    fs.create_dir(test_dir)
-    fs.create_dir(f"{test_dir}/huggyllama/llama-13b/imdb")
-    fs.create_file(f"{test_dir}/huggyllama/llama-13b/imdb/eval.csv")
-    fs.create_dir(f"{test_dir}/huggyllama/llama-12b/news")
-    fs.create_file(f"{test_dir}/huggyllama/llama-12b/news/eval.csv")
-    fs.create_file(f"{test_dir}/gpt2-medium/imdb/eval.csv")
-
-    return Path(test_dir)
-
-
-def test_get_model_paths(setup_fs):
-    test_dir = setup_fs
-    result = SweepVisualization._get_model_paths(test_dir)
-
-    root = Path(test_dir)
-    for path in root.rglob("*"):
-        print(path)
-    assert len(result) == 3
-    assert any([p.name == "llama-13b" for p in result])
-    assert any([p.name == "llama-12b" for p in result])
-    assert any([p.name == "gpt2-medium" for p in result])