train probe per prompt #271

Open · wants to merge 20 commits into base: main
Changes from 1 commit
add more types and sorting
derpyplops committed Jul 18, 2023
commit 7701c291fda586b5263096ad5f514992105a3256
13 changes: 10 additions & 3 deletions elk/run.py
@@ -30,6 +30,8 @@
select_usable_devices,
)

PreparedData = dict[str, tuple[Tensor, Tensor, Tensor | None]]


@dataclass
class Run(ABC, Serializable):
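As context for the new alias: each PreparedData value maps a dataset name to a tuple which, judging from how apply_to_layer unpacks it in train.py below, holds hidden states, ground-truth labels, and optional LM predictions. A minimal sketch with dummy tensors; the field names and shapes here are assumptions, not the library's API:

import torch
from torch import Tensor

PreparedData = dict[str, tuple[Tensor, Tensor, Tensor | None]]

# Hypothetical example; the (hiddens, labels, lm_preds) reading is inferred
# from the (train_h, train_gt, lm_preds) unpacking in apply_to_layer.
n, n_prompts, n_choices, d = 8, 4, 2, 16
example: PreparedData = {
    "imdb": (
        torch.randn(n, n_prompts, n_choices, d),  # hidden states
        torch.randint(0, 2, (n,)),                # ground-truth labels
        None,                                     # LM predictions may be absent
    )
}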
@@ -132,7 +134,7 @@ def get_device(self, devices, world_size: int) -> str:

def prepare_data(
self, device: str, layer: int, split_type: Literal["train", "val"]
) -> dict[str, tuple[Tensor, Tensor, Tensor | None]]:
) -> PreparedData:
"""Prepare data for the specified layer and split type."""
out = {}

@@ -196,9 +198,14 @@ def apply_to_layers(
sortby = ["layer", "ensembling"]
if "prompt_index" in dfs[0].columns:
sortby.append("prompt_index")
# make the prompt index third col

df = pd.concat(dfs).sort_values(by=sortby)

# Move prompt_index to the third column (index 2)
cols = list(df.columns)
cols.insert(2, cols.pop(cols.index("prompt_index")))
df = df.reindex(columns=cols)

# Save the CSV
out_path = self.out_dir / f"{name}.csv"
df.round(4).to_csv(out_path, index=False)
if self.debug:
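To see the prompt_index reordering from apply_to_layers in isolation, a minimal sketch with a dummy DataFrame (the metric column name is made up):

import pandas as pd

df = pd.DataFrame(columns=["layer", "ensembling", "acc_estimate", "prompt_index"])

# Pop prompt_index from wherever it sits and re-insert it at position 2,
# making it the third column after layer and ensembling.
cols = list(df.columns)
cols.insert(2, cols.pop(cols.index("prompt_index")))
df = df.reindex(columns=cols)

print(list(df.columns))  # ['layer', 'ensembling', 'prompt_index', 'acc_estimate']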
72 changes: 34 additions & 38 deletions elk/training/train.py
@@ -13,13 +13,13 @@

from ..evaluation import Eval
from ..metrics import evaluate_preds, to_one_hot
from ..run import Run
from ..run import PreparedData, Run
from ..training.supervised import train_supervised
from . import Classifier
from .ccs_reporter import CcsConfig, CcsReporter
from .common import FitterConfig, Reporter
from .eigen_reporter import EigenFitter, EigenFitterConfig

# Type alias covering both reporter implementations
AnyReporter = CcsReporter | Reporter


@@ -33,25 +33,27 @@ class MultiReporter:
def __init__(self, reporter_results: list[ReporterTrainResult]):
self.reporter_results: list[ReporterTrainResult] = reporter_results
self.reporters = [r.reporter for r in reporter_results]
train_losses = [r.train_loss for r in reporter_results] if reporter_results[
0].train_loss \
is not None else None
self.train_loss = sum(train_losses) / len(
train_losses
) if train_losses is not None else None
train_losses = (
[r.train_loss for r in reporter_results]
if reporter_results[0].train_loss is not None
else None
)
self.train_loss = (
sum(train_losses) / len(train_losses) if train_losses is not None else None
)

def __call__(self, h):
credences = [r(h) for r in self.reporters]
return torch.stack(credences).mean(dim=0)
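To make the ensembling concrete: __call__ stacks the per-prompt credences along a new leading dimension and averages it away. A minimal sketch with stand-in reporters (plain callables; real reporter internals are elided):

import torch

# Stand-ins for trained per-prompt reporters: anything mapping hiddens to credences.
reporters = [lambda h, w=w: h.sum(dim=-1) * w for w in (0.9, 1.0, 1.1)]

h = torch.randn(8, 16)
credences = [r(h) for r in reporters]           # three tensors of shape (8,)
ensembled = torch.stack(credences).mean(dim=0)  # (3, 8) -> (8,)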


def evaluate_and_save(
train_loss,
train_loss: float | None,
reporter: AnyReporter | MultiReporter,
train_dict,
val_dict,
lr_models,
layer,
train_dict: PreparedData,
val_dict: PreparedData,
lr_models: list[Classifier],
layer: int,
):
row_bufs = defaultdict(list)
for ds_name in val_dict:
@@ -102,9 +104,7 @@ def eval_all(
{
**meta,
"ensembling": mode,
**evaluate_preds(
train_gt, train_lm_preds, mode
).to_dict(),
**evaluate_preds(train_gt, train_lm_preds, mode).to_dict(),
**prompt_index,
}
)
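For reference, eval_all appends these row dicts to the defaultdict(list) keyed by output table (row_bufs above), which apply_to_layers later turns into sorted DataFrames. A minimal sketch of that pattern; the metric values are fabricated:

from collections import defaultdict

import pandas as pd

row_bufs: defaultdict[str, list[dict]] = defaultdict(list)
for prompt_index in range(2):
    row_bufs["train_eval"].append(
        {"layer": 5, "ensembling": "full", "acc": 0.8, "prompt_index": prompt_index}
    )
dfs = [pd.DataFrame(rows) for rows in row_bufs.values()]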
@@ -121,9 +121,7 @@
)

if isinstance(reporter, MultiReporter):
for prompt_index, reporter_result in enumerate(
reporter.reporter_results
):
for prompt_index, reporter_result in enumerate(reporter.reporter_results):
eval_all(reporter_result.reporter, prompt_index)

eval_all(reporter, "multi")
@@ -238,7 +236,7 @@ def train_and_save_reporter(

return ReporterTrainResult(reporter, train_loss)

def train_lr_model(self, train_dict, device, layer, out_dir):
def train_lr_model(self, train_dict, device, layer, out_dir) -> list[Classifier]:
if self.supervised != "none":
lr_models = train_supervised(
train_dict,
@@ -281,9 +279,9 @@ def apply_to_layer(
train_dicts = [
{
ds_name: (
train_h[:, i: i + 1, ...],
train_h[:, i : i + 1, ...],
train_gt,
lm_preds[:, i: i + 1, ...] if lm_preds is not None else None,
lm_preds[:, i : i + 1, ...] if lm_preds is not None else None,
)
}
for ds_name, (train_h, _, lm_preds) in train_dict.items()
@@ -292,8 +290,12 @@

results = []
for i, train_dict in enumerate(train_dicts):
reporters_path = self.out_dir / str(i) / "reporters"
lr_path = self.out_dir / str(i) / "lr_models"
# Format i as a two-digit string; assumes there will never be more
# than 100 prompts.
str_i = str(i).zfill(2)
base = self.out_dir / "reporters" / f"prompt_{str_i}"
reporters_path = base / "reporters"
lr_path = base / "lr_models"
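As an aside, the zero-padding keeps the per-prompt directories lexicographically ordered; a sketch of the resulting layout, with a hypothetical output root:

from pathlib import Path

out_dir = Path("output")  # hypothetical root
for i in (0, 1, 10):
    print(out_dir / "reporters" / f"prompt_{str(i).zfill(2)}")

# output/reporters/prompt_00
# output/reporters/prompt_01
# output/reporters/prompt_10  <- sorts after prompt_01 thanks to the padding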

reporter_train_result = self.train_and_save_reporter(
device, layer, reporters_path, train_dict
@@ -302,29 +304,23 @@

lr_models = self.train_lr_model(train_dict, device, layer, lr_path)

multi_reporter = MultiReporter(results)
train_loss = multi_reporter.train_loss
maybe_multi_reporter = MultiReporter(results)
train_loss = maybe_multi_reporter.train_loss

# TODO fix lr_models

return evaluate_and_save(
train_loss,
multi_reporter,
train_dict,
val_dict,
lr_models, # TODO I don't care about this right now but
layer,
)
else:
reporter_train_result = self.train_and_save_reporter(
device, layer, self.out_dir / "reporters", train_dict
)

reporter = reporter_train_result.reporter
maybe_multi_reporter = reporter_train_result.reporter
train_loss = reporter_train_result.train_loss

lr_models = self.train_lr_model(
train_dict, device, layer, self.out_dir / "lr_models"
)

return evaluate_and_save(
train_loss, reporter, train_dict, val_dict, lr_models, layer
)
return evaluate_and_save(
train_loss, maybe_multi_reporter, train_dict, val_dict, lr_models, layer
)
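Finally, the per-prompt split at the top of apply_to_layer depends on i : i + 1 slicing, which keeps the prompt dimension at size 1 instead of dropping it, so each single-prompt dict still matches the PreparedData layout. A small sketch; the shapes are made up:

import torch

train_h = torch.randn(8, 4, 2, 16)  # (examples, prompts, choices, hidden) -- assumed layout
i = 2
per_prompt = train_h[:, i : i + 1, ...]
assert per_prompt.shape == (8, 1, 2, 16)  # prompt dim kept, unlike train_h[:, i]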