diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 4cfbdb86..fc333d7d 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -14,8 +14,8 @@ repos:
   rev: 23.3.0
   hooks:
   - id: black
-- repo: https://github.com/charliermarsh/ruff-pre-commit
-  rev: 'v0.0.276'
+- repo: https://github.com/astral-sh/ruff-pre-commit
+  rev: 'v0.0.277'
   hooks:
   - id: ruff
     args: [--fix, --exit-non-zero-on-fix]
diff --git a/elk/evaluation/evaluate.py b/elk/evaluation/evaluate.py
index d6054e33..8462cc00 100644
--- a/elk/evaluation/evaluate.py
+++ b/elk/evaluation/evaluate.py
@@ -30,7 +30,7 @@ def execute(self, highlight_color: Color = "cyan"):
 
     @torch.inference_mode()
     def apply_to_layer(
-        self, layer: int, devices: list[str], world_size: int
+        self, layer: int, devices: list[str], world_size: int, probe_per_prompt: bool
     ) -> dict[str, pd.DataFrame]:
         """Evaluate a single reporter on a single layer."""
         device = self.get_device(devices, world_size)
diff --git a/elk/extraction/extraction.py b/elk/extraction/extraction.py
index d4e6f4ea..2975823d 100644
--- a/elk/extraction/extraction.py
+++ b/elk/extraction/extraction.py
@@ -192,6 +192,7 @@ def extract_hiddens(
             template_path=cfg.template_path,
             rank=rank,
             world_size=world_size,
+            seed=cfg.seed,
         )
 
     # Add one to the number of layers to account for the embedding layer
diff --git a/elk/run.py b/elk/run.py
index fb8903cc..b444d80b 100644
--- a/elk/run.py
+++ b/elk/run.py
@@ -46,6 +46,10 @@ class Run(ABC, Serializable):
     prompt_indices: tuple[int, ...] = ()
     """The indices of the prompt templates to use. If empty, all prompts are used."""
 
+    probe_per_prompt: bool = False
+    """If true, a probe is trained per prompt template. Otherwise, a single probe is
+    trained for all prompt templates."""
+
     concatenated_layer_offset: int = 0
     debug: bool = False
     min_gpu_mem: int | None = None  # in bytes
@@ -99,13 +103,16 @@
         devices = select_usable_devices(self.num_gpus, min_memory=self.min_gpu_mem)
         num_devices = len(devices)
-        func: Callable[[int], dict[str, pd.DataFrame]] = partial(
-            self.apply_to_layer, devices=devices, world_size=num_devices
+        func: Callable[[int], list[dict[str, pd.DataFrame]]] = partial(
+            self.apply_to_layer,
+            devices=devices,
+            world_size=num_devices,
+            probe_per_prompt=self.probe_per_prompt,
         )
         self.apply_to_layers(func=func, num_devices=num_devices)
 
     @abstractmethod
     def apply_to_layer(
-        self, layer: int, devices: list[str], world_size: int
-    ) -> dict[str, pd.DataFrame]:
+        self, layer: int, devices: list[str], world_size: int, probe_per_prompt: bool
+    ) -> list[dict[str, pd.DataFrame]]:
         """Train or eval a reporter on a single layer."""
 
@@ -180,13 +187,19 @@
         df_buffers = defaultdict(list)
 
         try:
-            for df_dict in tqdm(mapper(func, layers), total=len(layers)):
-                for k, v in df_dict.items():
-                    df_buffers[k].append(v)
+            for df_dicts in tqdm(mapper(func, layers), total=len(layers)):
+                for df_dict in df_dicts:
+                    for k, v in df_dict.items():
+                        df_buffers[k].append(v)
         finally:
             # Make sure the CSVs are written even if we crash or get interrupted
             for name, dfs in df_buffers.items():
-                df = pd.concat(dfs).sort_values(by=["layer", "ensembling"])
-                df.round(4).to_csv(self.out_dir / f"{name}.csv", index=False)
+                sortby = ["layer", "ensembling"]
+                if "prompt_index" in dfs[0].columns:
+                    sortby.append("prompt_index")
+                # TODO: make prompt_index the third column
+                df = pd.concat(dfs).sort_values(by=sortby)
+                out_path = self.out_dir / f"{name}.csv"
+                df.round(4).to_csv(out_path, index=False)
             if self.debug:
                 save_debug_log(self.datasets, self.out_dir)
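Note on the run.py change above: each `apply_to_layer` call now returns a list of `{csv_name: DataFrame}` dicts (one per trained probe) instead of a single dict, and the buffering loop in `apply_to_layers` flattens that extra level before writing CSVs, sorting on `prompt_index` when the column is present. A minimal sketch of the new buffering contract, with made-up toy rows standing in for real metric results:

from collections import defaultdict

import pandas as pd

# Toy results for a single layer with two per-prompt probes; in the real run
# each inner dict maps a CSV name (e.g. "eval") to a DataFrame of metric rows.
layer_results = [
    [
        {"eval": pd.DataFrame({"layer": [1], "ensembling": ["full"], "prompt_index": [i]})}
        for i in range(2)
    ]
]

df_buffers = defaultdict(list)
for df_dicts in layer_results:  # one list per layer
    for df_dict in df_dicts:  # one dict per probe
        for name, df in df_dict.items():
            df_buffers[name].append(df)

for name, dfs in df_buffers.items():
    sortby = ["layer", "ensembling"]
    if "prompt_index" in dfs[0].columns:
        sortby.append("prompt_index")
    print(name, pd.concat(dfs).sort_values(by=sortby), sep="\n")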
diff --git a/elk/training/train.py b/elk/training/train.py
index 8392f2d9..2d85cc13 100644
--- a/elk/training/train.py
+++ b/elk/training/train.py
@@ -1,7 +1,7 @@
 """Main training loop."""
 
 from collections import defaultdict
-from dataclasses import dataclass
+from dataclasses import dataclass, replace
 from pathlib import Path
 from typing import Literal
 
@@ -11,15 +11,21 @@
 from simple_parsing import subgroups
 from simple_parsing.helpers.serialization import save
 
+from ..evaluation import Eval
 from ..metrics import evaluate_preds, to_one_hot
 from ..run import Run
 from ..training.supervised import train_supervised
-from ..utils.typing import assert_type
 from .ccs_reporter import CcsConfig, CcsReporter
-from .common import FitterConfig
+from .common import FitterConfig, Reporter
 from .eigen_reporter import EigenFitter, EigenFitterConfig
 
 
+@dataclass
+class ReporterTrainResult:
+    reporter: CcsReporter | Reporter
+    train_loss: float | None
+
+
 @dataclass
 class Elicit(Run):
     """Full specification of a reporter training run."""
@@ -34,6 +40,81 @@ class Elicit(Run):
     cross-validation. Defaults to "single", which means to train a single classifier
     on the training data. "cv" means to use cross-validation."""
 
+    def evaluate_and_save(
+        self,
+        train_loss,
+        reporter,
+        train_dict,
+        val_dict,
+        lr_models,
+        layer,
+        prompt_index=None,
+    ):
+        row_bufs = defaultdict(list)
+        for ds_name in val_dict:
+            val_h, val_gt, val_lm_preds = val_dict[ds_name]
+            train_h, train_gt, train_lm_preds = train_dict[ds_name]
+            meta = {"dataset": ds_name, "layer": layer}
+
+            val_credences = reporter(val_h)
+            train_credences = reporter(train_h)
+            maybe_prompt_index = (
+                {} if prompt_index is None else {"prompt_index": prompt_index}
+            )
+            for mode in ("none", "partial", "full"):
+                row_bufs["eval"].append(
+                    {
+                        **meta,
+                        "ensembling": mode,
+                        **evaluate_preds(val_gt, val_credences, mode).to_dict(),
+                        "train_loss": train_loss,
+                        **maybe_prompt_index,
+                    }
+                )
+
+                row_bufs["train_eval"].append(
+                    {
+                        **meta,
+                        "ensembling": mode,
+                        **evaluate_preds(train_gt, train_credences, mode).to_dict(),
+                        "train_loss": train_loss,
+                        **maybe_prompt_index,
+                    }
+                )
+
+                if val_lm_preds is not None:
+                    row_bufs["lm_eval"].append(
+                        {
+                            **meta,
+                            "ensembling": mode,
+                            **evaluate_preds(val_gt, val_lm_preds, mode).to_dict(),
+                            **maybe_prompt_index,
+                        }
+                    )
+
+                if train_lm_preds is not None:
+                    row_bufs["train_lm_eval"].append(
+                        {
+                            **meta,
+                            "ensembling": mode,
+                            **evaluate_preds(train_gt, train_lm_preds, mode).to_dict(),
+                            **maybe_prompt_index,
+                        }
+                    )
+
+                for i, model in enumerate(lr_models):
+                    row_bufs["lr_eval"].append(
+                        {
+                            **meta,
+                            "ensembling": mode,
+                            "inlp_iter": i,
+                            **evaluate_preds(val_gt, model(val_h), mode).to_dict(),
+                            **maybe_prompt_index,
+                        }
+                    )
+
+        return {k: pd.DataFrame(v) for k, v in row_bufs.items()}
+
     def create_models_dir(self, out_dir: Path):
         lr_dir = None
         lr_dir = out_dir / "lr_models"
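The `maybe_prompt_index` dict in `evaluate_and_save` attaches a prompt index to each row only in per-prompt mode, so the single-probe CSV schema is unchanged. A standalone sketch of the merge pattern, with placeholder metric values:

def make_row(prompt_index=None):
    # Only attach a prompt_index key when training one probe per prompt.
    maybe_prompt_index = {} if prompt_index is None else {"prompt_index": prompt_index}
    return {
        "dataset": "imdb",  # placeholder metadata
        "layer": 5,
        "ensembling": "full",
        "auroc": 0.91,  # placeholder metric
        **maybe_prompt_index,
    }

print(make_row())  # single-probe mode: no prompt_index key
print(make_row(prompt_index=3))  # per-prompt mode: prompt_index == 3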
@@ -48,21 +129,31 @@ def create_models_dir(self, out_dir: Path):
         return reporter_dir, lr_dir
 
-    def apply_to_layer(
-        self,
-        layer: int,
-        devices: list[str],
-        world_size: int,
-    ) -> dict[str, pd.DataFrame]:
-        """Train a single reporter on a single layer."""
-
-        self.make_reproducible(seed=self.net.seed + layer)
-        device = self.get_device(devices, world_size)
-
-        train_dict = self.prepare_data(device, layer, "train")
-        val_dict = self.prepare_data(device, layer, "val")
-
-        (first_train_h, train_gt, _), *rest = train_dict.values()
+    def make_eval(self, model, eval_dataset):
+        assert self.out_dir is not None
+        return Eval(
+            data=replace(
+                self.data,
+                model=model,
+                datasets=(eval_dataset,),
+            ),
+            source=self.out_dir,
+            out_dir=self.out_dir / "transfer" / eval_dataset,
+            num_gpus=self.num_gpus,
+            min_gpu_mem=self.min_gpu_mem,
+            skip_supervised=self.supervised == "none",
+            prompt_indices=self.prompt_indices,
+            concatenated_layer_offset=self.concatenated_layer_offset,
+            # datasets isn't needed because it's immediately overwritten
+            debug=self.debug,
+            disable_cache=self.disable_cache,
+        )
+
+    # Reporter training, factored out of apply_to_layer so it can run per prompt.
+    def train_and_save_reporter(
+        self, device, layer, out_dir, train_dict
+    ) -> ReporterTrainResult:
+        (first_train_h, train_gt, _), *rest = train_dict.values()  # TODO: can this be removed?
         (_, v, k, d) = first_train_h.shape
         if not all(other_h.shape[-1] == d for other_h, _, _ in rest):
             raise ValueError("All datasets must have the same hidden state size")
@@ -74,16 +165,12 @@
         if not all(other_h.shape[-2] == k for other_h, _, _ in rest):
             raise ValueError("All datasets must have the same number of classes")
 
-        reporter_dir, lr_dir = self.create_models_dir(assert_type(Path, self.out_dir))
         train_loss = None
-
         if isinstance(self.net, CcsConfig):
             assert len(train_dict) == 1, "CCS only supports single-task training"
-
             reporter = CcsReporter(self.net, d, device=device, num_variants=v)
             train_loss = reporter.fit(first_train_h)
-
             (_, v, k, _) = first_train_h.shape
             reporter.platt_scale(
                 to_one_hot(repeat(train_gt, "n -> (n v)", v=v), k).flatten(),
                 rearrange(first_train_h, "n v k d -> (n v k) d"),
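For context on the shapes above: hidden states are laid out as (n, v, k, d), i.e. n examples, v prompt variants, k classes, and hidden size d, and the Platt-scaling call flattens the variant and class axes with einops. A shape-only sketch using random tensors (all sizes are arbitrary):

import torch
from einops import rearrange, repeat

n, v, k, d = 8, 3, 2, 16  # examples, variants, classes, hidden size (arbitrary)
first_train_h = torch.randn(n, v, k, d)
train_gt = torch.randint(0, k, (n,))  # one label per example

labels = repeat(train_gt, "n -> (n v)", v=v)  # (24,): each label copied per variant
flat_h = rearrange(first_train_h, "n v k d -> (n v k) d")  # (48, 16)
print(labels.shape, flat_h.shape)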
@@ -115,73 +202,102 @@
             raise ValueError(f"Unknown reporter config type: {type(self.net)}")
 
         # Save reporter checkpoint to disk
-        torch.save(reporter, reporter_dir / f"layer_{layer}.pt")
+        # TODO have to change this
+        out_dir.mkdir(parents=True, exist_ok=True)
+        torch.save(reporter, out_dir / f"layer_{layer}.pt")
 
-        # Fit supervised logistic regression model
+        return ReporterTrainResult(reporter, train_loss)
+
+    def train_lr_model(self, train_dict, device, layer, out_dir):
         if self.supervised != "none":
             lr_models = train_supervised(
                 train_dict,
                 device=device,
                 mode=self.supervised,
             )
-            with open(lr_dir / f"layer_{layer}.pt", "wb") as file:
+            # Make the directory if it doesn't exist
+            out_dir.mkdir(parents=True, exist_ok=True)
+            with open(out_dir / f"layer_{layer}.pt", "wb") as file:
                 torch.save(lr_models, file)
         else:
             lr_models = []
 
-        row_bufs = defaultdict(list)
-        for ds_name in val_dict:
-            val_h, val_gt, val_lm_preds = val_dict[ds_name]
-            train_h, train_gt, train_lm_preds = train_dict[ds_name]
-            meta = {"dataset": ds_name, "layer": layer}
+        return lr_models
 
-            val_credences = reporter(val_h)
-            train_credences = reporter(train_h)
-            for mode in ("none", "partial", "full"):
-                row_bufs["eval"].append(
-                    {
-                        **meta,
-                        "ensembling": mode,
-                        **evaluate_preds(val_gt, val_credences, mode).to_dict(),
-                        "train_loss": train_loss,
-                    }
-                )
+    def apply_to_layer(
+        self,
+        layer: int,
+        devices: list[str],
+        world_size: int,
+        probe_per_prompt: bool,
+    ) -> list[dict[str, pd.DataFrame]]:
+        """Train a reporter on a single layer, or one per prompt if probe_per_prompt."""
 
-                row_bufs["train_eval"].append(
-                    {
-                        **meta,
-                        "ensembling": mode,
-                        **evaluate_preds(train_gt, train_credences, mode).to_dict(),
-                        "train_loss": train_loss,
-                    }
-                )
+        self.make_reproducible(seed=self.net.seed + layer)
+        device = self.get_device(devices, world_size)
 
-                if val_lm_preds is not None:
-                    row_bufs["lm_eval"].append(
-                        {
-                            **meta,
-                            "ensembling": mode,
-                            **evaluate_preds(val_gt, val_lm_preds, mode).to_dict(),
-                        }
-                    )
+        train_dict = self.prepare_data(device, layer, "train")
+        val_dict = self.prepare_data(device, layer, "val")
 
-                if train_lm_preds is not None:
-                    row_bufs["train_lm_eval"].append(
-                        {
-                            **meta,
-                            "ensembling": mode,
-                            **evaluate_preds(train_gt, train_lm_preds, mode).to_dict(),
-                        }
-                    )
+        (first_train_h, train_gt, _), *rest = train_dict.values()
+        (_, v, k, d) = first_train_h.shape
+
+        # TODO is this even needed
+        # reporter_dir, lr_dir = self.create_models_dir(assert_type(Path, self.out_dir))
+
+        if probe_per_prompt:
+            train_dicts = [
+                {
+                    ds_name: (
+                        train_h[:, i : i + 1, ...],
+                        train_gt,
+                        lm_preds[:, i : i + 1, ...] if lm_preds is not None else None,
+                    )
+                }
+                for ds_name, (train_h, _, lm_preds) in train_dict.items()
+                for i in range(v)  # v is the number of variants
+            ]
+
+            res = []
+            for i, train_dict in enumerate(train_dicts):
+                reporters_path = self.out_dir / str(i) / "reporters"
+                lr_path = self.out_dir / str(i) / "lr_models"
+
+                reporter_train_result = self.train_and_save_reporter(
+                    device, layer, reporters_path, train_dict
+                )
 
-                for i, model in enumerate(lr_models):
-                    row_bufs["lr_eval"].append(
-                        {
-                            **meta,
-                            "ensembling": mode,
-                            "inlp_iter": i,
-                            **evaluate_preds(val_gt, model(val_h), mode).to_dict(),
-                        }
-                    )
+                reporter = reporter_train_result.reporter
+                train_loss = reporter_train_result.train_loss
+
+                lr_models = self.train_lr_model(train_dict, device, layer, lr_path)
+
+                res.append(
+                    self.evaluate_and_save(
+                        train_loss,
+                        reporter,
+                        train_dict,
+                        val_dict,
+                        lr_models,
+                        layer,
+                        prompt_index=i,
+                    )
+                )
+            return res
+        else:
+            reporter_train_result = self.train_and_save_reporter(
+                device, layer, self.out_dir / "reporters", train_dict
+            )
 
-        return {k: pd.DataFrame(v) for k, v in row_bufs.items()}
+            reporter = reporter_train_result.reporter
+            train_loss = reporter_train_result.train_loss
+
+            lr_models = self.train_lr_model(
+                train_dict, device, layer, self.out_dir / "lr_models"
+            )
+
+            return [
+                self.evaluate_and_save(
+                    train_loss, reporter, train_dict, val_dict, lr_models, layer
+                )
+            ]
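The per-prompt branch slices hidden states with `train_h[:, i : i + 1, ...]` rather than `train_h[:, i, ...]`, so each single-prompt training dict keeps the four-dimensional (n, v, k, d) layout the fitters expect, just with v == 1. A quick sketch of that slicing (tensor sizes are arbitrary):

import torch

n, v, k, d = 8, 3, 2, 16  # arbitrary sizes
train_h = torch.randn(n, v, k, d)

for i in range(v):
    per_prompt_h = train_h[:, i : i + 1, ...]  # slice keeps the variant axis
    assert per_prompt_h.shape == (n, 1, k, d)

Because `probe_per_prompt` is a field on `Run`, simple_parsing should expose it as a command-line option on `elk elicit` automatically; the exact flag spelling is an assumption that depends on the simple_parsing version in use.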