Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

train probe per prompt #271

Open
wants to merge 20 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
resolved circular import
  • Loading branch information
derpyplops committed Jul 19, 2023
commit 96a3dabc783d0a3f803bdadaebc59deb880148b6
94 changes: 63 additions & 31 deletions elk/evaluation/evaluate.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from collections import defaultdict
from dataclasses import dataclass
from pathlib import Path
from typing import Literal

import pandas as pd
import torch
Expand All @@ -9,6 +10,7 @@
from ..files import elk_reporter_dir
from ..metrics import evaluate_preds
from ..run import Run
from ..training.multi_reporter import AnyReporter, MultiReporter
from ..utils import Color


Expand Down Expand Up @@ -38,39 +40,69 @@

experiment_dir = elk_reporter_dir() / self.source

reporter_path = experiment_dir / "reporters" / f"layer_{layer}.pt"
reporter = torch.load(reporter_path, map_location=device)
def load_reporter() -> AnyReporter | MultiReporter:
    """Load the reporter saved for `layer` from `experiment_dir / "reporters"`.

    Two on-disk layouts are handled:

    * flat: the directory holds `.pt` files, and `layer_{layer}.pt` is loaded
      directly with `torch.load`;
    * nested: the directory holds no `.pt` files, and loading is delegated to
      `MultiReporter.load`.

    Returns:
        The object deserialized by `torch.load` for the flat layout, or the
        `MultiReporter` built by `MultiReporter.load` for the nested layout.
    """
    reporters_dir = experiment_dir / "reporters"

    # Inspect every entry instead of only the first one yielded by
    # `iterdir()`: that order is arbitrary, so a stray non-checkpoint entry
    # (or an empty directory, which used to leak a bare StopIteration) must
    # not decide which layout we assume.
    has_flat_checkpoints = any(
        entry.suffix == ".pt" for entry in reporters_dir.iterdir()
    )
    if has_flat_checkpoints:
        return torch.load(reporters_dir / f"layer_{layer}.pt", map_location=device)

    return MultiReporter.load(reporters_dir, layer, device=device)

reporter = load_reporter()

row_bufs = defaultdict(list)
for ds_name, (val_h, val_gt, _) in val_output.items():
meta = {"dataset": ds_name, "layer": layer}

val_credences = reporter(val_h)
for mode in ("none", "partial", "full"):
row_bufs["eval"].append(
{
**meta,
"ensembling": mode,
**evaluate_preds(val_gt, val_credences, mode).to_dict(),
}
)

lr_dir = experiment_dir / "lr_models"
if not self.skip_supervised and lr_dir.exists():
with open(lr_dir / f"layer_{layer}.pt", "rb") as f:
lr_models = torch.load(f, map_location=device)
if not isinstance(lr_models, list): # backward compatibility
lr_models = [lr_models]

for i, model in enumerate(lr_models):
model.eval()
row_bufs["lr_eval"].append(
{
"ensembling": mode,
"inlp_iter": i,
**meta,
**evaluate_preds(val_gt, model(val_h), mode).to_dict(),
}
)
def eval_all(
    reporter: AnyReporter | MultiReporter,
    prompt_index: int | Literal["multi"] | None = None,
):
    """Evaluate `reporter` on every validation dataset and buffer the rows.

    For each dataset in `val_output` and each ensembling mode, one row is
    appended to `row_bufs["eval"]`. When supervised evaluation is enabled and
    saved LR models exist, one row per LR model is appended to
    `row_bufs["lr_eval"]` as well.

    Args:
        reporter: Callable applied to the validation hidden states.
        prompt_index: Value stored in the `"prompt_index"` column of the
            `"eval"` rows. The column is omitted entirely when this is `None`.
    """
    # Keep the optional column in its own variable. Rebinding `prompt_index`
    # to a dict contradicts its declared type and is what the type checker
    # rejected.
    prompt_meta = {"prompt_index": prompt_index} if prompt_index is not None else {}

    # The LR checkpoint depends only on `layer`, so read it once up front
    # instead of once per (dataset, mode) pair.
    lr_models = None
    lr_dir = experiment_dir / "lr_models"
    if not self.skip_supervised and lr_dir.exists():
        with open(lr_dir / f"layer_{layer}.pt", "rb") as f:
            lr_models = torch.load(f, map_location=device)
        if not isinstance(lr_models, list):  # backward compatibility
            lr_models = [lr_models]
        for model in lr_models:
            model.eval()

    # NOTE(review): the "lr_eval" rows use neither `reporter` nor
    # `prompt_index`, so calling this helper once per reporter appends
    # identical LR rows each time. Confirm whether that is intended.
    for ds_name, (val_h, val_gt, _) in val_output.items():
        meta = {"dataset": ds_name, "layer": layer}

        val_credences = reporter(val_h)
        for mode in ("none", "partial", "full"):
            row_bufs["eval"].append(
                {
                    **meta,
                    "ensembling": mode,
                    **evaluate_preds(val_gt, val_credences, mode).to_dict(),
                    **prompt_meta,
                }
            )

            if lr_models is None:
                continue

            for i, model in enumerate(lr_models):
                row_bufs["lr_eval"].append(
                    {
                        "ensembling": mode,
                        "inlp_iter": i,
                        **meta,
                        **evaluate_preds(val_gt, model(val_h), mode).to_dict(),
                    }
                )

if isinstance(reporter, MultiReporter):
for prompt_index, single_reporter in enumerate(reporter.reporters):
eval_all(single_reporter, prompt_index)
eval_all(reporter, "multi")
else:
eval_all(reporter)

return {k: pd.DataFrame(v) for k, v in row_bufs.items()}
44 changes: 44 additions & 0 deletions elk/training/multi_reporter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from dataclasses import dataclass
from pathlib import Path

import torch as t

from elk.training import CcsReporter
from elk.training.common import Reporter

AnyReporter = CcsReporter | Reporter


@dataclass
class ReporterTrainResult:
    """Pairs one reporter with the training loss recorded for it."""

    # The reporter itself.
    reporter: AnyReporter
    # Its training loss, or None when no loss is available.
    train_loss: float | None


class MultiReporter:
    """A group of reporters that is called like a single reporter.

    Calling the instance runs every member on the same input and returns the
    element-wise mean of their outputs.
    """

    def __init__(self, reporter_results: "list[ReporterTrainResult]"):
        """Store the members and compute their mean training loss.

        Args:
            reporter_results: One result per member reporter; must be
                non-empty.

        Raises:
            ValueError: If `reporter_results` is empty.
        """
        if not reporter_results:
            raise ValueError("MultiReporter requires at least one reporter")

        self.reporter_results = reporter_results
        self.reporters = [r.reporter for r in reporter_results]

        # Filtering out None narrows the element type to `float`, which is
        # what `sum` needs. The mean is reported only when every member has a
        # loss; a partially-missing set yields None rather than a TypeError.
        known_losses = [
            r.train_loss for r in reporter_results if r.train_loss is not None
        ]
        self.train_loss = (
            sum(known_losses) / len(known_losses)
            if len(known_losses) == len(reporter_results)
            else None
        )

    def __call__(self, h):
        """Return the mean over members of `member(h)`, reduced along dim 0."""
        credences = [r(h) for r in self.reporters]
        return t.stack(credences).mean(dim=0)

    @staticmethod
    def load(path: Path, layer: int, device: str) -> "MultiReporter":
        """Load one reporter per sub-directory of `path` for the given layer.

        Each sub-directory is expected to contain
        `reporters/layer_{layer}.pt`.

        Args:
            path: Directory whose sub-directories each hold one reporter.
            layer: Layer index used to build the checkpoint file name.
            device: Passed to `torch.load` as `map_location`.

        Returns:
            A `MultiReporter` whose members follow the sorted folder order and
            whose train losses are all None.
        """
        import re  # local import keeps this block self-contained

        def natural_key(folder: Path) -> list:
            # Splitting on a captured digit run alternates text and number
            # tokens, so odd positions are always numeric. Comparing those as
            # ints makes "x10" sort after "x2".
            tokens = re.split(r"(\d+)", folder.name)
            return [int(tok) if i % 2 else tok for i, tok in enumerate(tokens)]

        # `iterdir()` yields entries in arbitrary order, and callers treat a
        # member's list position as its prompt index, so the order must be
        # deterministic.
        # NOTE(review): assumes any number in the folder name is the prompt
        # index; confirm against the code that writes these folders.
        prompt_folders = sorted(
            (p for p in path.iterdir() if p.is_dir()), key=natural_key
        )

        # NOTE: `torch.load` unpickles its input; only use on trusted files.
        reporters = [
            t.load(folder / "reporters" / f"layer_{layer}.pt", map_location=device)
            for folder in prompt_folders
        ]
        # TODO for now I don't care about the train losses
        return MultiReporter([ReporterTrainResult(r, None) for r in reporters])
29 changes: 2 additions & 27 deletions elk/training/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,34 +17,9 @@
from ..training.supervised import train_supervised
from . import Classifier
from .ccs_reporter import CcsConfig, CcsReporter
from .common import FitterConfig, Reporter
from .common import FitterConfig
from .eigen_reporter import EigenFitter, EigenFitterConfig

AnyReporter = CcsReporter | Reporter


@dataclass
class ReporterTrainResult:
    """Pairs one reporter with the training loss recorded for it."""

    # The reporter itself.
    reporter: AnyReporter
    # Its training loss, or None when no loss is available.
    train_loss: float | None


class MultiReporter:
    """A group of reporters that is called like a single reporter.

    Calling the instance runs every member on the same input and returns the
    element-wise mean of their outputs.
    """

    def __init__(self, reporter_results: "list[ReporterTrainResult]"):
        """Store the members and compute their mean training loss.

        Args:
            reporter_results: One result per member reporter; must be
                non-empty.

        Raises:
            ValueError: If `reporter_results` is empty.
        """
        if not reporter_results:
            raise ValueError("MultiReporter requires at least one reporter")

        self.reporter_results = reporter_results
        self.reporters = [r.reporter for r in reporter_results]

        # Filtering out None narrows the element type to `float`, which is
        # what `sum` needs. The mean is reported only when every member has a
        # loss; a partially-missing set yields None rather than a TypeError.
        known_losses = [
            r.train_loss for r in reporter_results if r.train_loss is not None
        ]
        self.train_loss = (
            sum(known_losses) / len(known_losses)
            if len(known_losses) == len(reporter_results)
            else None
        )

    def __call__(self, h):
        """Return the mean over members of `member(h)`, reduced along dim 0."""
        credences = [r(h) for r in self.reporters]
        return torch.stack(credences).mean(dim=0)
from .multi_reporter import AnyReporter, MultiReporter, ReporterTrainResult


def evaluate_and_save(
Expand All @@ -67,7 +42,7 @@
):
val_credences = reporter(val_h)
train_credences = reporter(train_h)
prompt_index = {"prompt_index": prompt_index}

Check failure on line 45 in elk/training/train.py

View workflow job for this annotation

GitHub Actions / run-tests (3.10, macos-latest)

Expression of type "dict[str, int | str]" cannot be assigned to declared type "int | Literal['multi']"   Type "dict[str, int | str]" cannot be assigned to type "int | Literal['multi']"     "dict[str, int | str]" is incompatible with "int"     "dict[str, int | str]" cannot be assigned to type "Literal['multi']" (reportGeneralTypeIssues)
for mode in ("none", "partial", "full"):
row_bufs["eval"].append(
{
Expand All @@ -75,7 +50,7 @@
"ensembling": mode,
**evaluate_preds(val_gt, val_credences, mode).to_dict(),
"train_loss": train_loss,
**prompt_index,

Check failure on line 53 in elk/training/train.py

View workflow job for this annotation

GitHub Actions / run-tests (3.10, macos-latest)

Expected mapping for dictionary unpack operator (reportGeneralTypeIssues)
}
)

Expand All @@ -85,7 +60,7 @@
"ensembling": mode,
**evaluate_preds(train_gt, train_credences, mode).to_dict(),
"train_loss": train_loss,
**prompt_index,

Check failure on line 63 in elk/training/train.py

View workflow job for this annotation

GitHub Actions / run-tests (3.10, macos-latest)

Expected mapping for dictionary unpack operator (reportGeneralTypeIssues)
}
)

Expand All @@ -95,7 +70,7 @@
**meta,
"ensembling": mode,
**evaluate_preds(val_gt, val_lm_preds, mode).to_dict(),
**prompt_index,

Check failure on line 73 in elk/training/train.py

View workflow job for this annotation

GitHub Actions / run-tests (3.10, macos-latest)

Expected mapping for dictionary unpack operator (reportGeneralTypeIssues)
}
)

Expand All @@ -105,7 +80,7 @@
**meta,
"ensembling": mode,
**evaluate_preds(train_gt, train_lm_preds, mode).to_dict(),
**prompt_index,

Check failure on line 83 in elk/training/train.py

View workflow job for this annotation

GitHub Actions / run-tests (3.10, macos-latest)

Expected mapping for dictionary unpack operator (reportGeneralTypeIssues)
}
)

Expand All @@ -116,7 +91,7 @@
"ensembling": mode,
"inlp_iter": i,
**evaluate_preds(val_gt, model(val_h), mode).to_dict(),
**prompt_index,

Check failure on line 94 in elk/training/train.py

View workflow job for this annotation

GitHub Actions / run-tests (3.10, macos-latest)

Expected mapping for dictionary unpack operator (reportGeneralTypeIssues)
}
)

Expand Down Expand Up @@ -322,5 +297,5 @@
)

return evaluate_and_save(
train_loss, maybe_multi_reporter, train_dict, val_dict, lr_models, layer

Check failure on line 300 in elk/training/train.py

View workflow job for this annotation

GitHub Actions / run-tests (3.10, macos-latest)

"lr_models" is possibly unbound (reportUnboundVariable)
)
Loading