Skip to content

Commit

Permalink
refactor generics
Browse files Browse the repository at this point in the history
  • Loading branch information
thejaminator committed Mar 26, 2023
1 parent 1fdcce1 commit d2b5da7
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 8 deletions.
11 changes: 6 additions & 5 deletions elk/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@
from elk.evaluation.evaluate import Eval
from elk.training.train import Elicit

A = TypeVar("A", EvalLog, ElicitLog)
"""A generic log type that contains a layer field
The layer field is used to sort the logs by layer."""
Log = TypeVar("Log", EvalLog, ElicitLog)


@dataclass
Expand Down Expand Up @@ -97,19 +99,18 @@ def prepare_data(

def apply_to_layers(
self,
func: Callable[[int], A],
func: Callable[[int], Log],
num_devices: int,
to_csv_line: Callable[[A], list[str]],
to_csv_line: Callable[[Log], list[str]],
csv_columns: list[str],
):
"""Apply a function to each layer of the dataset in parallel
and writes the results to a CSV file."""
self.out_dir = assert_type(Path, self.out_dir)
with mp.Pool(num_devices) as pool, open(self.out_dir / "eval.csv", "w") as f:
# Partially apply so the function will just take the layer as an argument
layers: list[int] = get_layers(self.dataset)
mapper = pool.imap_unordered if num_devices > 1 else map
iterator: Iterator[A] = tqdm(mapper(func, layers), total=len(layers)) # type: ignore
iterator: Iterator[Log] = tqdm(mapper(func, layers), total=len(layers)) # type: ignore
write_iterator_to_file(
iterator=iterator,
file=f,
Expand Down
6 changes: 3 additions & 3 deletions elk/utils/csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@
from datasets import DatasetDict

from elk.logging import save_debug_log
from elk.run import A
from elk.run import Log


def write_iterator_to_file(
iterator: Iterator[A],
iterator: Iterator[Log],
csv_columns: list[str],
to_csv_line: Callable[[A], list[str]],
to_csv_line: Callable[[Log], list[str]],
file: TextIO,
debug: bool,
dataset: DatasetDict,
Expand Down

0 comments on commit d2b5da7

Please sign in to comment.