Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

train probe per prompt #271

Open
wants to merge 20 commits into
base: main
Choose a base branch
from
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
implemented multi probe for elicit
  • Loading branch information
derpyplops committed Jul 18, 2023
commit 898c3f1c7a6094b71e5166ce43e9aa98a20f0aea
151 changes: 89 additions & 62 deletions elk/training/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,30 @@
from .common import FitterConfig, Reporter
from .eigen_reporter import EigenFitter, EigenFitterConfig

# declare AnyReporter as CcsReporter | Reporter type alias
AnyReporter = CcsReporter | Reporter


@dataclass
class ReporterTrainResult:
    """Outcome of training a single reporter."""

    # The trained reporter (either reporter flavour covered by `AnyReporter`).
    reporter: AnyReporter
    # Loss associated with training this reporter; may be None.
    train_loss: float | None


class MultiReporter:
    """Ensemble of individually trained reporters that averages their outputs.

    Instances are callable like a single reporter: the input is passed to every
    wrapped reporter and the stacked results are averaged.
    """

    def __init__(self, reporter_results: "list[ReporterTrainResult]"):
        """Store the per-reporter results and derive the mean training loss.

        Args:
            reporter_results: Training results to ensemble; must be non-empty.

        Raises:
            ValueError: If `reporter_results` is empty.
        """
        if not reporter_results:
            raise ValueError("MultiReporter requires at least one reporter result")

        self.reporter_results: list[ReporterTrainResult] = reporter_results
        self.reporters = [r.reporter for r in reporter_results]

        # Drop missing losses up front instead of inspecting only the first
        # entry: `sum` then receives a plain list of floats (which also
        # satisfies the type checker) and a None anywhere in the list can no
        # longer trigger a TypeError.
        known_losses = [
            r.train_loss for r in reporter_results if r.train_loss is not None
        ]
        # The mean is only meaningful when every reporter reported a loss.
        self.train_loss: float | None = (
            sum(known_losses) / len(known_losses)
            if len(known_losses) == len(reporter_results)
            else None
        )

    def __call__(self, h):
        """Return the mean over all wrapped reporters' outputs for `h`."""
        credences = [r(h) for r in self.reporters]
        return torch.stack(credences).mean(dim=0)


@dataclass
class Elicit(Run):
"""Full specification of a reporter training run."""
Expand All @@ -43,75 +60,86 @@
def evaluate_and_save(
self,
train_loss,
reporter,
reporter: AnyReporter | MultiReporter,
train_dict,
val_dict,
lr_models,
layer,
prompt_index=None,
):
row_bufs = defaultdict(list)
for ds_name in val_dict:
val_h, val_gt, val_lm_preds = val_dict[ds_name]
train_h, train_gt, train_lm_preds = train_dict[ds_name]
meta = {"dataset": ds_name, "layer": layer}

val_credences = reporter(val_h)
train_credences = reporter(train_h)
maybe_prompt_index = (
{} if prompt_index is None else {"prompt_index": prompt_index}
)
for mode in ("none", "partial", "full"):
row_bufs["eval"].append(
{
**meta,
"ensembling": mode,
**evaluate_preds(val_gt, val_credences, mode).to_dict(),
"train_loss": train_loss,
**maybe_prompt_index,
}
)

row_bufs["train_eval"].append(
{
**meta,
"ensembling": mode,
**evaluate_preds(train_gt, train_credences, mode).to_dict(),
"train_loss": train_loss,
**maybe_prompt_index,
}
)

if val_lm_preds is not None:
row_bufs["lm_eval"].append(
def eval_all(
reporter: AnyReporter | MultiReporter,
prompt_index: int | Literal["multi"],
):
val_credences = reporter(val_h)
train_credences = reporter(train_h)
prompt_index = {"prompt_index": prompt_index}

Check failure on line 81 in elk/training/train.py

View workflow job for this annotation

GitHub Actions / run-tests (3.11, macos-latest)

Expression of type "dict[str, int | str]" cannot be assigned to declared type "int | Literal['multi']"   Type "dict[str, int | str]" cannot be assigned to type "int | Literal['multi']"     "dict[str, int | str]" is incompatible with "int"     "dict[str, int | str]" cannot be assigned to type "Literal['multi']" (reportGeneralTypeIssues)
for mode in ("none", "partial", "full"):
row_bufs["eval"].append(
{
**meta,
"ensembling": mode,
**evaluate_preds(val_gt, val_lm_preds, mode).to_dict(),
**maybe_prompt_index,
**evaluate_preds(val_gt, val_credences, mode).to_dict(),
"train_loss": train_loss,
**prompt_index,

Check failure on line 89 in elk/training/train.py

View workflow job for this annotation

GitHub Actions / run-tests (3.11, macos-latest)

Expected mapping for dictionary unpack operator (reportGeneralTypeIssues)
}
)

if train_lm_preds is not None:
row_bufs["train_lm_eval"].append(
row_bufs["train_eval"].append(
{
**meta,
"ensembling": mode,
**evaluate_preds(train_gt, train_lm_preds, mode).to_dict(),
**maybe_prompt_index,
**evaluate_preds(train_gt, train_credences, mode).to_dict(),
"train_loss": train_loss,
**prompt_index,

Check failure on line 99 in elk/training/train.py

View workflow job for this annotation

GitHub Actions / run-tests (3.11, macos-latest)

Expected mapping for dictionary unpack operator (reportGeneralTypeIssues)
}
)

for i, model in enumerate(lr_models):
row_bufs["lr_eval"].append(
{
**meta,
"ensembling": mode,
"inlp_iter": i,
**evaluate_preds(val_gt, model(val_h), mode).to_dict(),
**maybe_prompt_index,
}
)
if val_lm_preds is not None:
row_bufs["lm_eval"].append(
{
**meta,
"ensembling": mode,
**evaluate_preds(val_gt, val_lm_preds, mode).to_dict(),
**prompt_index,

Check failure on line 109 in elk/training/train.py

View workflow job for this annotation

GitHub Actions / run-tests (3.11, macos-latest)

Expected mapping for dictionary unpack operator (reportGeneralTypeIssues)
}
)

if train_lm_preds is not None:
row_bufs["train_lm_eval"].append(
{
**meta,
"ensembling": mode,
**evaluate_preds(
train_gt, train_lm_preds, mode
).to_dict(),
**prompt_index,

Check failure on line 121 in elk/training/train.py

View workflow job for this annotation

GitHub Actions / run-tests (3.11, macos-latest)

Expected mapping for dictionary unpack operator (reportGeneralTypeIssues)
}
)

for i, model in enumerate(lr_models):
row_bufs["lr_eval"].append(
{
**meta,
"ensembling": mode,
"inlp_iter": i,
**evaluate_preds(val_gt, model(val_h), mode).to_dict(),
**prompt_index,

Check failure on line 132 in elk/training/train.py

View workflow job for this annotation

GitHub Actions / run-tests (3.11, macos-latest)

Expected mapping for dictionary unpack operator (reportGeneralTypeIssues)
}
)

if isinstance(reporter, MultiReporter):
for prompt_index, reporter_result in enumerate(
reporter.reporter_results
):
eval_all(reporter_result.reporter, prompt_index)

eval_all(reporter, "multi")

return {k: pd.DataFrame(v) for k, v in row_bufs.items()}

Expand Down Expand Up @@ -261,32 +289,31 @@
for i in range(v) # v is number of variants
]

res = []
results = []
for i, train_dict in enumerate(train_dicts):
reporters_path = self.out_dir / str(i) / "reporters"
lr_path = self.out_dir / str(i) / "lr_models"

reporter_train_result = self.train_and_save_reporter(
device, layer, reporters_path, train_dict
)

reporter = reporter_train_result.reporter
train_loss = reporter_train_result.train_loss
results.append(reporter_train_result)

lr_models = self.train_lr_model(train_dict, device, layer, lr_path)

res.append(
self.evaluate_and_save(
train_loss,
reporter,
train_dict,
val_dict,
lr_models,
layer,
prompt_index=i,
)
multi_reporter = MultiReporter(results)
train_loss = multi_reporter.train_loss

return [
self.evaluate_and_save(
train_loss,
multi_reporter,
train_dict,
val_dict,
lr_models, # TODO I don't care about this right now but

Check failure on line 313 in elk/training/train.py

View workflow job for this annotation

GitHub Actions / run-tests (3.11, macos-latest)

"lr_models" is possibly unbound (reportUnboundVariable)
layer,
)
return res
]
else:
reporter_train_result = self.train_and_save_reporter(
device, layer, self.out_dir / "reporters", train_dict
Expand Down
Loading