Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main' into eval_dirs
Browse files Browse the repository at this point in the history
  • Loading branch information
ChristyKoh committed Apr 13, 2023
2 parents f35a6a8 + 361fb9b commit ba090b7
Show file tree
Hide file tree
Showing 34 changed files with 1,542 additions and 811 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/cpu_ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ jobs:
run-tests:
strategy:
matrix:
python-versions: [ 3.9, "3.10", "3.11" ]
python-versions: [ "3.10", "3.11" ]
os: [ ubuntu-latest, macos-latest ]
runs-on: ${{ matrix.os }}
steps:
Expand Down
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ repos:
hooks:
- id: black
- repo: https://github.com/charliermarsh/ruff-pre-commit
rev: 'v0.0.257'
rev: 'v0.0.261'
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
Expand Down
10 changes: 9 additions & 1 deletion elk/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,11 @@
from .extraction import Extract, extract_hiddens
from .training import EigenReporter, EigenReporterConfig
from .truncated_eigh import truncated_eigh

__all__ = ["extract_hiddens", "Extract"]
__all__ = [
"EigenReporter",
"EigenReporterConfig",
"extract_hiddens",
"Extract",
"truncated_eigh",
]
3 changes: 1 addition & 2 deletions elk/__main__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""Main entry point for `elk`."""

from dataclasses import dataclass
from typing import Union

from simple_parsing import ArgumentParser

Expand All @@ -14,7 +13,7 @@
class Command:
"""Some top-level command"""

command: Union[Elicit, Eval, Extract]
command: Elicit | Eval | Extract

def execute(self):
return self.command.execute()
Expand Down
230 changes: 0 additions & 230 deletions elk/eigsh.py

This file was deleted.

56 changes: 32 additions & 24 deletions elk/evaluation/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,16 @@
from pathlib import Path
from typing import Callable, Literal, Optional

import pandas as pd
import torch
from simple_parsing.helpers import Serializable, field

from elk.evaluation.evaluate_log import EvalLog
from elk.extraction.extraction import Extract
from elk.run import Run
from elk.training import Reporter

from ..extraction.extraction import Extract
from ..files import elk_reporter_dir
from ..utils import (
select_usable_devices,
)
from ..run import Run
from ..training import Reporter
from ..training.baseline import evaluate_baseline, load_baseline
from ..utils import select_usable_devices


@dataclass
Expand All @@ -41,7 +39,7 @@ class Eval(Serializable):
debug: bool = False
out_dir: Optional[Path] = None
num_gpus: int = -1

skip_baseline: bool = False
concatenated_layer_offset: int = 0

def execute(self):
Expand All @@ -60,18 +58,18 @@ class Evaluate(Run):

def evaluate_reporter(
self, layer: int, devices: list[str], world_size: int = 1
) -> EvalLog:
) -> pd.Series:
"""Evaluate a single reporter on a single layer."""
device = self.get_device(devices, world_size)

_, _, test_x0, test_x1, _, test_labels = self.prepare_data(
_, _, test_x0, test_x1, _, test_labels, _ = self.prepare_data(
device,
layer,
)

reporter_path = (
elk_reporter_dir() / self.cfg.source / "reporters" / f"layer_{layer}.pt"
)
experiment_dir = elk_reporter_dir() / self.cfg.source

reporter_path = experiment_dir / "reporters" / f"layer_{layer}.pt"
reporter: Reporter = torch.load(reporter_path, map_location=device)
reporter.eval()

Expand All @@ -81,24 +79,34 @@ def evaluate_reporter(
test_x1,
)

return EvalLog(
layer=layer,
eval_result=test_result,
stats_row = pd.Series(
{
"layer": layer,
**test_result._asdict(),
}
)

lr_dir = experiment_dir / "lr_models"
if not self.cfg.skip_baseline and lr_dir.exists():
lr_model = load_baseline(lr_dir, layer)
lr_model.eval()
lr_auroc, lr_acc = evaluate_baseline(
lr_model.cuda(), test_x0.cuda(), test_x1.cuda(), test_labels
)

stats_row["lr_auroc"] = lr_auroc
stats_row["lr_acc"] = lr_acc

return stats_row

def evaluate(self):
"""Evaluate the reporter on all layers."""
devices = select_usable_devices(
self.cfg.num_gpus, min_memory=self.cfg.data.min_gpu_mem
)

num_devices = len(devices)
func: Callable[[int], EvalLog] = partial(
func: Callable[[int], pd.Series] = partial(
self.evaluate_reporter, devices=devices, world_size=num_devices
)
self.apply_to_layers(
func=func,
num_devices=num_devices,
to_csv_line=lambda item: item.to_csv_line(),
csv_columns=EvalLog.csv_columns(),
)
self.apply_to_layers(func=func, num_devices=num_devices)
Loading

0 comments on commit ba090b7

Please sign in to comment.