Skip to content

Commit

Permalink
dep inject columns and writing
Browse files Browse the repository at this point in the history
  • Loading branch information
thejaminator committed Mar 25, 2023
1 parent b277817 commit a204748
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 33 deletions.
18 changes: 15 additions & 3 deletions elk/evaluation/evaluate.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from dataclasses import dataclass
from functools import partial
from pathlib import Path
from typing import Literal, Optional
from typing import Literal, Optional, Callable

import torch
from simple_parsing import Serializable, field
Expand All @@ -9,6 +10,7 @@
from elk.files import elk_reporter_dir
from elk.run import Run
from elk.training.train_result import EvalStatResult
from elk.utils import select_usable_devices


@dataclass
Expand Down Expand Up @@ -49,7 +51,7 @@ def execute(self):
class Evaluate(Run):
cfg: Eval

def apply_to_single_layer(
def evaluate_reporter(
self, layer: int, devices: list[str], world_size: int = 1
) -> EvalStatResult:
"""Evaluate a single reporter on a single layer."""
Expand Down Expand Up @@ -77,4 +79,14 @@ def apply_to_single_layer(

def evaluate(self):
    """Evaluate the reporter on all layers.

    Selects the usable devices, binds them into `evaluate_reporter` so the
    resulting callable only needs a layer index, and hands that callable to
    `apply_to_layers` together with the CSV header and row formatter of
    `EvalStatResult`.
    """
    devices = select_usable_devices(self.cfg.num_gpus)
    num_devices = len(devices)

    # Partially apply so the per-layer function takes just the layer index.
    func: Callable[[int], EvalStatResult] = partial(
        self.evaluate_reporter, devices=devices, world_size=num_devices
    )

    # NOTE: the pre-refactor argument-less `self.apply_to_layers()` call is
    # intentionally gone; `apply_to_layers` now requires these arguments.
    self.apply_to_layers(
        func=func,
        num_devices=num_devices,
        to_csv_line=EvalStatResult.to_csv_line,
        csv_columns=EvalStatResult.csv_columns(),
    )
51 changes: 27 additions & 24 deletions elk/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,16 @@
from dataclasses import dataclass, field
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Optional, Union, Iterator, TextIO, Callable, Sequence
from typing import (
TYPE_CHECKING,
Optional,
Union,
Iterator,
TextIO,
Callable,
Sequence,
TypeVar,
)

import numpy as np
import torch
Expand All @@ -18,7 +27,7 @@
from elk.files import create_output_directory, save_config, save_meta
from elk.logging import save_debug_log
from elk.training.preprocessing import normalize
from elk.training.train_result import ElicitStatResult, EvalStatResult
from elk.training.train_result import ElicitStatResult, EvalStatResult, StatResult
from elk.utils.data_utils import get_layers, select_train_val_splits
from elk.utils.gpu_utils import select_usable_devices
from elk.utils.typing import assert_type, int16_to_float32
Expand Down Expand Up @@ -88,59 +97,53 @@ def prepare_data(

return x0, x1, val_x0, val_x1, train_labels, val_labels

def apply_to_layers(
    self,
    func: Callable[[int], StatResult],
    num_devices: int,
    to_csv_line: Callable[[StatResult], list[str]],
    csv_columns: list[str],
):
    """Apply a function to each layer of the dataset in parallel and write
    the results to a CSV file.

    Args:
        func: Callable invoked once per layer index returned by
            `get_layers(self.dataset)`.
        num_devices: Size of the multiprocessing pool. When greater than 1
            the layers are mapped with `pool.imap_unordered`; otherwise the
            builtin `map` is used in the current process.
        to_csv_line: Converts one result of `func` into a CSV row.
        csv_columns: Header row written before any result rows.

    The output is written to `eval.csv` inside `self.out_dir`.
    """
    self.out_dir = assert_type(Path, self.out_dir)
    csv_path = self.out_dir / "eval.csv"

    # newline="" is what the csv module expects for files handed to
    # csv.writer; without it some platforms emit extra blank lines.
    with mp.Pool(num_devices) as pool, open(csv_path, "w", newline="") as f:
        layers: list[int] = get_layers(self.dataset)
        mapper = pool.imap_unordered if num_devices > 1 else map
        # Typed as sequence for covariant typing
        iterator: Sequence[StatResult] = tqdm(mapper(func, layers), total=len(layers))  # type: ignore
        write_func_to_file(
            iterator=iterator,
            file=f,
            debug=self.cfg.debug,
            dataset=self.dataset,
            out_dir=self.out_dir,
            skip_baseline=self.cfg.skip_baseline,
            csv_columns=csv_columns,
            to_csv_line=to_csv_line,
        )


def write_func_to_file(
    iterator: Sequence["StatResult"],
    csv_columns: list[str],
    to_csv_line: Callable[["StatResult"], list[str]],
    file: TextIO,
    debug: bool,
    dataset: "DatasetDict",
    out_dir: Path,
    skip_baseline: bool,
) -> None:
    """Drain `iterator` and write its results to `file` as CSV.

    Args:
        iterator: Yields one result per layer; results must be mutually
            orderable because they are sorted before being written.
        csv_columns: Header row, written first.
        to_csv_line: Converts a single result into a CSV row.
        file: Open, writable text file that receives the CSV.
        debug: When true, `save_debug_log(dataset, out_dir)` is called
            after the rows have been written.
        dataset: Passed through to `save_debug_log`.
        out_dir: Passed through to `save_debug_log`.
        skip_baseline: Unused here now that the header and row formatter
            are injected; kept so existing callers keep working.

    Any exception raised while iterating propagates to the caller, but the
    results collected up to that point are still written.
    """
    writer = csv.writer(file)
    # The header comes from the caller, so it always matches `to_csv_line`.
    writer.writerow(csv_columns)

    collected = []
    try:
        for result in iterator:
            collected.append(result)
    finally:
        # Make sure the CSV is written even if we crash or get interrupted
        for result in sorted(collected):
            writer.writerow(to_csv_line(result))
        if debug:
            save_debug_log(dataset, out_dir)
17 changes: 15 additions & 2 deletions elk/training/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from .eigen_reporter import EigenReporter, EigenReporterConfig
from .reporter import OptimConfig, Reporter, ReporterConfig
from .train_result import ElicitStatResult
from ..utils import select_usable_devices


@dataclass
Expand Down Expand Up @@ -110,7 +111,7 @@ def save_baseline(self, lr_dir: Path, layer: int, lr_model: Classifier):
with open(lr_dir / f"layer_{layer}.pt", "wb") as file:
pickle.dump(lr_model, file)

def apply_to_single_layer(
def train_reporter(
self,
layer: int,
devices: list[str],
Expand Down Expand Up @@ -187,4 +188,16 @@ def get_pseudo_auroc(

def train(self):
    """Train a reporter on each layer of the network.

    Selects the usable devices, binds them into `train_reporter` so the
    resulting callable only needs a layer index, and hands that callable to
    `apply_to_layers` together with the CSV header and row formatter of
    `ElicitStatResult`.
    """
    devices = select_usable_devices(self.cfg.num_gpus)
    num_devices = len(devices)
    # Read once so the header and every row agree on the baseline columns.
    skip_baseline = self.cfg.skip_baseline

    # Partially apply so the per-layer function takes just the layer index.
    func: Callable[[int], ElicitStatResult] = partial(
        self.train_reporter, devices=devices, world_size=num_devices
    )

    # NOTE: the pre-refactor argument-less `self.apply_to_layers()` call is
    # intentionally gone; `apply_to_layers` now requires these arguments.
    self.apply_to_layers(
        func=func,
        num_devices=num_devices,
        to_csv_line=lambda item: item.to_csv_line(skip_baseline=skip_baseline),
        csv_columns=ElicitStatResult.csv_columns(skip_baseline),
    )
17 changes: 13 additions & 4 deletions elk/training/train_result.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import Optional
from typing import Optional, TypeVar

from simple_parsing import Serializable

Expand All @@ -20,8 +20,9 @@ class ElicitStatResult:
lr_acc: Optional[float] = None

cols = ["layer", "loss", "acc", "cal_acc", "auroc"]

@staticmethod
def to_csv_columns(skip_baseline: bool) -> list[str]:
def csv_columns(skip_baseline: bool) -> list[str]:
"""Return a CSV header with the column names."""
cols = [
"layer",
Expand All @@ -36,8 +37,6 @@ def to_csv_columns(skip_baseline: bool) -> list[str]:
cols += ["lr_auroc", "lr_acc"]
return cols



def to_csv_line(self, skip_baseline: bool) -> list[str]:
"""Return a CSV line with the evaluation results."""
items = [
Expand All @@ -58,3 +57,13 @@ def to_csv_line(self, skip_baseline: bool) -> list[str]:
@dataclass
class EvalStatResult:
...

@staticmethod
def csv_columns() -> list[str]:
...

def to_csv_line(self) -> list[str]:
...


StatResult = TypeVar("StatResult", ElicitStatResult, EvalStatResult)

0 comments on commit a204748

Please sign in to comment.