refactor reporter training
derpyplops committed Jul 13, 2023
1 parent 47bcfb2 commit a50fe57
Showing 1 changed file with 65 additions and 23 deletions.
88 changes: 65 additions & 23 deletions elk/training/train.py
@@ -1,7 +1,7 @@
 """Main training loop."""
 
 from collections import defaultdict
-from dataclasses import dataclass
+from dataclasses import dataclass, replace
 from pathlib import Path
 from typing import Literal
 
@@ -11,15 +11,22 @@
 from simple_parsing import subgroups
 from simple_parsing.helpers.serialization import save
 
+from ..evaluation import Eval
 from ..metrics import evaluate_preds, to_one_hot
 from ..run import Run
 from ..training.supervised import train_supervised
 from ..utils.typing import assert_type
 from .ccs_reporter import CcsConfig, CcsReporter
-from .common import FitterConfig
+from .common import FitterConfig, Reporter
 from .eigen_reporter import EigenFitter, EigenFitterConfig
 
 
+@dataclass
+class ReporterTrainResult:
+    reporter: CcsReporter | Reporter
+    train_loss: float | None
+
+
 @dataclass
 class Elicit(Run):
     """Full specification of a reporter training run."""
@@ -48,22 +55,31 @@ def create_models_dir(self, out_dir: Path):
 
         return reporter_dir, lr_dir
 
-    def apply_to_layer(
-        self,
-        layer: int,
-        devices: list[str],
-        world_size: int,
-        probe_per_prompt: bool,
-    ) -> dict[str, pd.DataFrame]:
-        """Train a single reporter on a single layer."""
-
-        self.make_reproducible(seed=self.net.seed + layer)
-        device = self.get_device(devices, world_size)
-
+    def make_eval(self, model, eval_dataset):
+        assert self.out_dir is not None
+        return Eval(
+            data=replace(
+                self.data,
+                model=model,
+                datasets=(eval_dataset,),
+            ),
+            source=self.out_dir,
+            out_dir=self.out_dir / "transfer" / eval_dataset,
+            num_gpus=self.num_gpus,
+            min_gpu_mem=self.min_gpu_mem,
+            skip_supervised=self.supervised == "none",
+            prompt_indices=self.prompt_indices,
+            concatenated_layer_offset=self.concatenated_layer_offset,
+            # datasets isn't needed because it's immediately overwritten
+            debug=self.debug,
+            disable_cache=self.disable_cache,
+        )
+
+    # Create a separate function to handle the reporter training.
+    def train_reporter(self, device, layer, out_dir) -> ReporterTrainResult:
         train_dict = self.prepare_data(device, layer, "train")
         val_dict = self.prepare_data(device, layer, "val")
 
-        (first_train_h, train_gt, _), *rest = train_dict.values()
+        (first_train_h, train_gt, _), *rest = train_dict.values()  # TODO can remove?
         (_, v, k, d) = first_train_h.shape
         if not all(other_h.shape[-1] == d for other_h, _, _ in rest):
             raise ValueError("All datasets must have the same hidden state size")
@@ -75,16 +91,12 @@ def apply_to_layer(
         if not all(other_h.shape[-2] == k for other_h, _, _ in rest):
             raise ValueError("All datasets must have the same number of classes")
 
-        reporter_dir, lr_dir = self.create_models_dir(assert_type(Path, self.out_dir))
         train_loss = None
 
         if isinstance(self.net, CcsConfig):
             assert len(train_dict) == 1, "CCS only supports single-task training"
 
             reporter = CcsReporter(self.net, d, device=device, num_variants=v)
             train_loss = reporter.fit(first_train_h)
 
-            (_, v, k, _) = first_train_h.shape
             reporter.platt_scale(
                 to_one_hot(repeat(train_gt, "n -> (n v)", v=v), k).flatten(),
                 rearrange(first_train_h, "n v k d -> (n v k) d"),
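The platt_scale call needs one target per flattened row: the n ground-truth labels are repeated across the v prompt variants, one-hot encoded over k classes, and matched against the hidden states reshaped to (n*v*k, d). A shape-only sketch of that bookkeeping, with torch's built-in one_hot standing in for elk's to_one_hot helper:

    import torch
    from einops import rearrange, repeat

    n, v, k, d = 4, 3, 2, 8               # examples, variants, classes, hidden size
    train_gt = torch.randint(0, k, (n,))  # one label per example
    hiddens = torch.randn(n, v, k, d)

    labels = repeat(train_gt, "n -> (n v)", v=v)          # (n*v,)
    one_hot = torch.nn.functional.one_hot(labels, k)      # (n*v, k)
    flat_h = rearrange(hiddens, "n v k d -> (n v k) d")   # (n*v*k, d)
    assert one_hot.flatten().shape[0] == flat_h.shape[0]  # one target per row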
@@ -116,20 +128,50 @@ def apply_to_layer(
             raise ValueError(f"Unknown reporter config type: {type(self.net)}")
 
         # Save reporter checkpoint to disk
-        torch.save(reporter, reporter_dir / f"layer_{layer}.pt")
+        torch.save(reporter, out_dir / f"layer_{layer}.pt")
 
-        # Fit supervised logistic regression model
+        return ReporterTrainResult(reporter, train_loss)
+
+    def train_lr_model(self, train_dict, device, layer, out_dir):
         if self.supervised != "none":
             lr_models = train_supervised(
                 train_dict,
                 device=device,
                 mode=self.supervised,
             )
-            with open(lr_dir / f"layer_{layer}.pt", "wb") as file:
+            with open(out_dir / f"layer_{layer}.pt", "wb") as file:
                 torch.save(lr_models, file)
         else:
             lr_models = []
 
+        return lr_models
+
+    def apply_to_layer(
+        self,
+        layer: int,
+        devices: list[str],
+        world_size: int,
+        probe_per_prompt: bool,
+    ) -> dict[str, pd.DataFrame]:
+        """Train a single reporter on a single layer."""
+
+        self.make_reproducible(seed=self.net.seed + layer)
+        device = self.get_device(devices, world_size)
+
+        train_dict = self.prepare_data(device, layer, "train")
+        val_dict = self.prepare_data(device, layer, "val")
+
+        (first_train_h, train_gt, _), *rest = train_dict.values()
+        (_, v, k, d) = first_train_h.shape
+
+        reporter_dir, lr_dir = self.create_models_dir(assert_type(Path, self.out_dir))
+
+        reporter_train_result = self.train_reporter(device, layer, reporter_dir)
+        reporter = reporter_train_result.reporter
+        train_loss = reporter_train_result.train_loss
+
+        lr_models = self.train_lr_model(train_dict, device, layer, lr_dir)
+
         row_bufs = defaultdict(list)
         for ds_name in val_dict:
             val_h, val_gt, val_lm_preds = val_dict[ds_name]
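After the refactor, apply_to_layer is a thin orchestrator: seed, pick a device, prepare data, create output directories, then delegate to train_reporter and train_lr_model before scoring on each validation set. Both helpers persist one checkpoint per layer with torch.save; reading one back follows the usual pickle-based pattern (path illustrative, and recent PyTorch versions may require weights_only=False because these checkpoints are full objects, not bare state dicts):

    import torch

    # torch.save() was handed the reporter object itself, so load returns it directly
    reporter = torch.load("out/reporters/layer_12.pt", map_location="cpu")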