train probe per prompt #271

Open · wants to merge 20 commits into base: main
Changes from 1 commit
fixed index passing
derpyplops committed Jul 19, 2023
commit 9c2def0df61dd966a5901df29a5e2c1519c6673d
9 changes: 9 additions & 0 deletions elk/evaluation/evaluate.py
@@ -38,6 +38,15 @@ def apply_to_layer(
device = self.get_device(devices, world_size)
val_output = self.prepare_data(device, layer, "val")

val_output = {
ds_name: (
train_h[:, self.prompt_indices, ...],
train_gt,
lm_preds[:, self.prompt_indices, ...] if lm_preds is not None else None,
)
for ds_name, (train_h, train_gt, lm_preds) in val_output.items()
}

experiment_dir = elk_reporter_dir() / self.source

def load_reporter() -> AnyReporter | MultiReporter:
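
For context on what this new block does: indexing the variant axis with a list of prompt indices keeps that axis rather than squeezing it, so downstream code still sees a 4-D tensor. A minimal sketch, assuming the (n, v, k, d) layout unpacked later in train.py; the concrete shapes below are made up:

import torch

n, v, k, d = 4, 3, 2, 8            # examples, variants, classes, hidden dim (illustrative)
hiddens = torch.randn(n, v, k, d)
prompt_indices = [0, 2]            # the user-selected subset of prompts (assumed)

# list indexing on the variant axis keeps the dimension, now of length 2
subset = hiddens[:, prompt_indices, ...]
assert subset.shape == (n, len(prompt_indices), k, d)
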
2 changes: 0 additions & 2 deletions elk/run.py
@@ -143,8 +143,6 @@ def prepare_data(
split = ds[key].with_format("torch", device=device, dtype=torch.int16)
labels = assert_type(Tensor, split["label"])
hiddens = int16_to_float32(assert_type(Tensor, split[f"hidden_{layer}"]))
if self.prompt_indices:
hiddens = hiddens[:, self.prompt_indices]

with split.formatted_as("torch", device=device):
has_preds = "model_logits" in split.features
1 change: 1 addition & 0 deletions elk/training/multi_reporter.py
@@ -13,6 +13,7 @@
class ReporterTrainResult:
reporter: AnyReporter
train_loss: float | None
prompt_index: int | None


class MultiReporter:
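
The new field records which prompt a reporter was trained on (None when a single reporter is trained across all prompts). A quick illustrative sketch of how the updated dataclass gets populated, with the reporter type stubbed out since it is not part of this diff:

from __future__ import annotations

from dataclasses import dataclass
from typing import Any

@dataclass
class ReporterTrainResult:
    reporter: Any               # AnyReporter in the real code
    train_loss: float | None
    prompt_index: int | None    # None on the single-reporter path

# per-prompt path: each result remembers the prompt it was trained on
result = ReporterTrainResult(reporter=object(), train_loss=0.12, prompt_index=3)
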
70 changes: 48 additions & 22 deletions elk/training/train.py
@@ -38,10 +38,16 @@ def evaluate_and_save(

def eval_all(
reporter: AnyReporter | MultiReporter,
prompt_index: int | Literal["multi"],
prompt_index: int | Literal["multi"] | None = None,
i: int = 0,
):
val_credences = reporter(val_h)
train_credences = reporter(train_h)
if isinstance(prompt_index, int):
val_credences = reporter(val_h[:, [prompt_index], :, :])
train_credences = reporter(train_h[:, [prompt_index], :, :])
else:
# TODO implement diagonal
val_credences = reporter(val_h)
train_credences = reporter(train_h)
prompt_index = {"prompt_index": prompt_index}
for mode in ("none", "partial", "full"):
row_bufs["eval"].append(
@@ -96,10 +102,11 @@ eval_all(
)

if isinstance(reporter, MultiReporter):
for prompt_index, reporter_result in enumerate(reporter.reporter_results):
eval_all(reporter_result.reporter, prompt_index)

eval_all(reporter, "multi")
for reporter_result in reporter.reporter_results:
eval_all(reporter_result.reporter, reporter_result.prompt_index)
eval_all(reporter, prompt_index="multi")
else:
eval_all(reporter, prompt_index=None)

return {k: pd.DataFrame(v) for k, v in row_bufs.items()}
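
To make the new dispatch concrete, here is a small self-contained sketch of the control flow (the reporter is stubbed as a plain callable and the shapes are illustrative assumptions, not the project's real API):

import torch

def eval_all_sketch(reporter, val_h, prompt_index=None):
    if isinstance(prompt_index, int):
        # single-prompt evaluation keeps a length-1 variant axis
        return reporter(val_h[:, [prompt_index], :, :])
    # prompt_index is "multi" or None: evaluate across all prompts at once
    return reporter(val_h)

def identity(h):
    return h

val_h = torch.randn(4, 3, 2, 8)  # assumed (n, v, k, d) layout
assert eval_all_sketch(identity, val_h, prompt_index=1).shape == (4, 1, 2, 8)
assert eval_all_sketch(identity, val_h, prompt_index="multi").shape == (4, 3, 2, 8)
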

@@ -154,9 +161,10 @@ def make_eval(self, model, eval_dataset),

# Create a separate function to handle the reporter training.
def train_and_save_reporter(
self, device, layer, out_dir, train_dict
self, device, layer, out_dir, train_dict, prompt_index=None
) -> ReporterTrainResult:
(first_train_h, train_gt, _), *rest = train_dict.values() # TODO can remove?
breakpoint()
(_, v, k, d) = first_train_h.shape
if not all(other_h.shape[-1] == d for other_h, _, _ in rest):
raise ValueError("All datasets must have the same hidden state size")
Expand Down Expand Up @@ -209,7 +217,7 @@ def train_and_save_reporter(
out_dir.mkdir(parents=True, exist_ok=True)
torch.save(reporter, out_dir / f"layer_{layer}.pt")

return ReporterTrainResult(reporter, train_loss)
return ReporterTrainResult(reporter, train_loss, prompt_index)

def train_lr_model(self, train_dict, device, layer, out_dir) -> list[Classifier]:
if self.supervised != "none":
@@ -241,7 +249,8 @@ def apply_to_layer(
self.make_reproducible(seed=self.net.seed + layer)
device = self.get_device(devices, world_size)

train_dict = self.prepare_data(device, layer, "train")
train_dict = self.prepare_data(device, layer, "train")
# prepare_data no longer does any prompt-index filtering
val_dict = self.prepare_data(device, layer, "val")

(first_train_h, train_gt, _), *rest = train_dict.values()
@@ -251,49 +260,66 @@
# reporter_dir, lr_dir = self.create_models_dir(assert_type(Path, self.out_dir))

if probe_per_prompt:
train_dicts = [
prompt_indices = self.prompt_indices if self.prompt_indices else range(v)
prompt_train_dicts = [
{
ds_name: (
train_h[:, i : i + 1, ...],
train_h[:, [prompt_index], ...],
train_gt,
lm_preds[:, i : i + 1, ...] if lm_preds is not None else None,
lm_preds[:, [prompt_index], ...]
if lm_preds is not None
else None,
)
}
for ds_name, (train_h, _, lm_preds) in train_dict.items()
for i in range(v) # v is number of variants
for prompt_index in prompt_indices  # defaults to range(v), i.e. all variants
]

results = []
for i, train_dict in enumerate(train_dicts):
# format i as a 2 digit string, assumes that there will never be more
# than 100 prompts
str_i = str(i).zfill(2)

for prompt_index, prompt_train_dict in zip(
prompt_indices, prompt_train_dicts
):
assert prompt_index < 100  # prompt_index is formatted as a 2-digit string
str_i = str(prompt_index).zfill(2)
base = self.out_dir / "reporters" / f"prompt_{str_i}"
reporters_path = base / "reporters"
lr_path = base / "lr_models"

reporter_train_result = self.train_and_save_reporter(
device, layer, reporters_path, train_dict
device, layer, reporters_path, prompt_train_dict, prompt_index
)
results.append(reporter_train_result)

lr_models = self.train_lr_model(train_dict, device, layer, lr_path)
lr_models = self.train_lr_model(
prompt_train_dict, device, layer, lr_path
)

maybe_multi_reporter = MultiReporter(results)
train_loss = maybe_multi_reporter.train_loss

# TODO fix lr_models

else:
prompt_train_dict = {
ds_name: (
train_h[:, self.prompt_indices, ...],
train_gt,
lm_preds[:, self.prompt_indices, ...]
if lm_preds is not None
else None,
)
for ds_name, (train_h, _, lm_preds) in train_dict.items()
}
reporter_train_result = self.train_and_save_reporter(
device, layer, self.out_dir / "reporters", train_dict
device, layer, self.out_dir / "reporters", prompt_train_dict
)

maybe_multi_reporter = reporter_train_result.reporter
train_loss = reporter_train_result.train_loss

lr_models = self.train_lr_model(
train_dict, device, layer, self.out_dir / "lr_models"
prompt_train_dict, device, layer, self.out_dir / "lr_models"
)

return evaluate_and_save(
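
Putting the probe-per-prompt path together, a condensed, self-contained sketch of the per-prompt split (dataset name, shapes, and the simplified dict comprehension are illustrative assumptions, not the exact code above):

import torch

# one toy dataset: (hiddens, labels, lm_preds) in an assumed (n, v, k, d) layout
train_dict = {
    "imdb": (torch.randn(4, 3, 2, 8), torch.randint(0, 2, (4,)), None),
}
prompt_indices = range(3)  # falls back to every variant when no indices are given

prompt_train_dicts = [
    {
        ds_name: (
            train_h[:, [prompt_index], ...],
            train_gt,
            lm_preds[:, [prompt_index], ...] if lm_preds is not None else None,
        )
        for ds_name, (train_h, train_gt, lm_preds) in train_dict.items()
    }
    for prompt_index in prompt_indices
]

assert len(prompt_train_dicts) == 3
assert prompt_train_dicts[0]["imdb"][0].shape == (4, 1, 2, 8)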