Multiple datasets refactor #189

Merged (17 commits, Apr 14, 2023)

Changes from 8 commits

5 changes: 2 additions & 3 deletions elk/__main__.py
@@ -5,22 +5,21 @@
 from simple_parsing import ArgumentParser

 from elk.evaluation.evaluate import Eval
-from elk.extraction.extraction import Extract
 from elk.training.train import Elicit


 @dataclass
 class Command:
     """Some top-level command"""

-    command: Elicit | Eval | Extract
+    command: Elicit | Eval

     def execute(self):
         return self.command.execute()


 def run():
-    parser = ArgumentParser(add_help=False)
+    parser = ArgumentParser(add_help=False, add_config_path_arg=True)
     parser.add_arguments(Command, dest="run")
     args = parser.parse_args()
     run: Command = args.run
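
A note on the `add_config_path_arg=True` change: this is a `simple_parsing` option that (per that library's documented behavior) adds a config-file argument whose contents serve as argument defaults, which explicit CLI flags still override. A minimal sketch, with an illustrative `Options` dataclass rather than the real `Elicit`/`Eval` configs:

```python
from dataclasses import dataclass

from simple_parsing import ArgumentParser


@dataclass
class Options:
    model: str = "gpt2"  # illustrative fields, not the real config
    num_gpus: int = -1


parser = ArgumentParser(add_help=False, add_config_path_arg=True)
parser.add_arguments(Options, dest="run")

args = parser.parse_args(["--model", "gpt2-xl"])
print(args.run)  # Options(model='gpt2-xl', num_gpus=-1)
```
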
64 changes: 31 additions & 33 deletions elk/evaluation/evaluate.py
@@ -1,7 +1,7 @@
 from dataclasses import dataclass
 from functools import partial
 from pathlib import Path
-from typing import Callable, Literal, Optional
+from typing import Callable

 import pandas as pd
 import torch
@@ -11,7 +11,7 @@
 from ..files import elk_reporter_dir
 from ..run import Run
 from ..training import Reporter
-from ..training.baseline import evaluate_baseline, load_baseline
+from ..training.supervised import evaluate_supervised
 from ..utils import select_usable_devices


@@ -34,11 +34,11 @@ class Eval(Serializable):

     data: Extract
     source: str = field(positional=True)
-    normalization: Literal["legacy", "none", "elementwise", "meanonly"] = "meanonly"

     debug: bool = False
-    out_dir: Optional[Path] = None
+    out_dir: Path | None = None
     num_gpus: int = -1
+    min_gpu_mem: int | None = None
     skip_baseline: bool = False
     concatenated_layer_offset: int = 0

@@ -58,55 +58,53 @@ class Evaluate(Run):

     def evaluate_reporter(
         self, layer: int, devices: list[str], world_size: int = 1
-    ) -> pd.Series:
+    ) -> pd.DataFrame:
         """Evaluate a single reporter on a single layer."""
         device = self.get_device(devices, world_size)

-        _, _, test_x0, test_x1, _, test_labels, _ = self.prepare_data(
-            device,
-            layer,
-        )
+        val_output = self.prepare_data(device, layer, "val")

         experiment_dir = elk_reporter_dir() / self.cfg.source

         reporter_path = experiment_dir / "reporters" / f"layer_{layer}.pt"
         reporter: Reporter = torch.load(reporter_path, map_location=device)
         reporter.eval()

-        test_result = reporter.score(
-            test_labels,
-            test_x0,
-            test_x1,
-        )
-
-        stats_row = pd.Series(
-            {
-                "layer": layer,
-                **test_result._asdict(),
-            }
-        )
-
-        lr_dir = experiment_dir / "lr_models"
-        if not self.cfg.skip_baseline and lr_dir.exists():
-            lr_model = load_baseline(lr_dir, layer)
-            lr_model.eval()
-            lr_auroc, lr_acc = evaluate_baseline(
-                lr_model.cuda(), test_x0.cuda(), test_x1.cuda(), test_labels
-            )
-
-            stats_row["lr_auroc"] = lr_auroc
-            stats_row["lr_acc"] = lr_acc
-
-        return stats_row
+        row_buf = []
+        for ds_name, (val_x0, val_x1, val_gt, _) in val_output.items():
+            val_result = reporter.score(
+                val_gt,
+                val_x0,
+                val_x1,
+            )
+
+            stats_row = pd.Series(
+                {
+                    "dataset": ds_name,
+                    "layer": layer,
+                    **val_result._asdict(),
+                }
+            )
+
+            lr_dir = experiment_dir / "lr_models"
+            if not self.cfg.skip_baseline and lr_dir.exists():
+                with open(lr_dir / f"layer_{layer}.pt", "rb") as f:
+                    lr_model = torch.load(f, map_location=device).eval()
+
+                lr_auroc, lr_acc = evaluate_supervised(lr_model, val_x0, val_x1, val_gt)
+
+                stats_row["lr_auroc"] = lr_auroc
+                stats_row["lr_acc"] = lr_acc
+
+            row_buf.append(stats_row)
+
+        return pd.DataFrame(row_buf)

     def evaluate(self):
         """Evaluate the reporter on all layers."""
         devices = select_usable_devices(
-            self.cfg.num_gpus, min_memory=self.cfg.data.min_gpu_mem
+            self.cfg.num_gpus, min_memory=self.cfg.min_gpu_mem
         )

         num_devices = len(devices)
-        func: Callable[[int], pd.Series] = partial(
+        func: Callable[[int], pd.DataFrame] = partial(
             self.evaluate_reporter, devices=devices, world_size=num_devices
         )
         self.apply_to_layers(func=func, num_devices=num_devices)
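
This return-type change is the crux of the multi-dataset refactor: `evaluate_reporter` used to produce a single `pd.Series` per layer, and now produces a `pd.DataFrame` with one row per evaluated dataset. A toy sketch of the aggregation pattern, with made-up dataset names and metric values:

```python
import pandas as pd

# Stand-in for the per-dataset output of prepare_data(device, layer, "val")
val_output = {
    "imdb": {"acc": 0.91, "auroc": 0.95},
    "amazon_polarity": {"acc": 0.88, "auroc": 0.93},
}

row_buf = []
for ds_name, metrics in val_output.items():
    # One row per (dataset, layer) pair instead of one Series per layer
    row_buf.append(pd.Series({"dataset": ds_name, "layer": 12, **metrics}))

df = pd.DataFrame(row_buf)  # each Series becomes one row
print(df)
#            dataset  layer   acc  auroc
# 0             imdb     12  0.91   0.95
# 1  amazon_polarity     12  0.88   0.93
```
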
2 changes: 1 addition & 1 deletion elk/extraction/balanced_sampler.py
@@ -6,8 +6,8 @@
 from datasets import Features, IterableDataset
 from torch.utils.data import IterableDataset as TorchIterableDataset

-from ..math_util import stochastic_round_constrained
 from ..utils import infer_label_column
+from ..utils.math_util import stochastic_round_constrained
 from ..utils.typing import assert_type


52 changes: 33 additions & 19 deletions elk/extraction/extraction.py
@@ -4,7 +4,7 @@
 from copy import copy
 from dataclasses import InitVar, dataclass
 from itertools import islice
-from typing import Any, Iterable, Literal, Optional
+from typing import Any, Iterable, Literal

 import torch
 from datasets import (
@@ -23,6 +23,7 @@
 from transformers import AutoConfig, AutoTokenizer
 from transformers.modeling_outputs import Seq2SeqLMOutput

+from ..promptsource import DatasetTemplates
 from ..utils import (
     assert_type,
     convert_span,
@@ -48,7 +49,6 @@ class Extract(Serializable):
         layer_stride: Shortcut for setting `layers` to `range(0, num_layers, stride)`.
         token_loc: The location of the token to extract hidden states from. Can be
             either "first", "last", or "mean". Defaults to "last".
-        min_gpu_mem: Minimum amount of free memory (in bytes) required to select a GPU.
     """

     prompts: PromptConfig
@@ -57,8 +57,6 @@
     layers: tuple[int, ...] = ()
     layer_stride: InitVar[int] = 1
     token_loc: Literal["first", "last", "mean"] = "last"
-    min_gpu_mem: Optional[int] = None
-    num_gpus: int = -1

     def __post_init__(self, layer_stride: int):
         if self.layers and layer_stride > 1:
@@ -74,8 +72,16 @@ def __post_init__(self, layer_stride: int):
             )
             self.layers = tuple(range(0, config.num_hidden_layers, layer_stride))

-    def execute(self):
-        extract(cfg=self, num_gpus=self.num_gpus)
+    def explode(self) -> list["Extract"]:
+        """Explode this config into a list of configs, one per dataset."""
+        copies = []
+
+        for prompt_cfg in self.prompts.explode():
+            cfg = copy(self)
+            cfg.prompts = prompt_cfg
+            copies.append(cfg)
+
+        return copies


@torch.no_grad()
@@ -94,8 +100,11 @@ def extract_hiddens(
     if rank != 0:
         logging.disable(logging.CRITICAL)

+    ds_names = cfg.prompts.datasets
+    assert len(ds_names) == 1, "Can only extract hiddens from one dataset at a time."
+
     prompt_ds = load_prompts(
-        *cfg.prompts.datasets,
+        ds_names[0],
         split_type=split_type,
         stream=cfg.prompts.stream,
         rank=rank,
@@ -240,14 +249,19 @@ def _extraction_worker(**kwargs):
     yield from extract_hiddens(**{k: v[0] for k, v in kwargs.items()})


-def extract(cfg: "Extract", num_gpus: int = -1) -> DatasetDict:
+def extract(
+    cfg: "Extract", num_gpus: int = -1, min_gpu_mem: int | None = None
+) -> DatasetDict:
     """Extract hidden states from a model and return a `DatasetDict` containing them."""

     def get_splits() -> SplitDict:
         available_splits = assert_type(SplitDict, info.splits)
         train_name, val_name = select_train_val_splits(available_splits)
-        print(f"Using '{train_name}' for training and '{val_name}' for validation")
+        print(
+            # Cyan color for dataset name
+            f"\033[36m{info.builder_name}\033[0m: using '{train_name}' for training and"
+            f" '{val_name}' for validation"
+        )

[Collaborator comment on the colored print]: this is a nice touch :)

         limit_list = cfg.prompts.max_examples

         return SplitDict(
@@ -263,11 +277,15 @@ def get_splits() -> SplitDict:
         )

     model_cfg = AutoConfig.from_pretrained(cfg.model)
-    num_variants = cfg.prompts.num_variants

     ds_name, _, config_name = cfg.prompts.datasets[0].partition(" ")
     info = get_dataset_config_info(ds_name, config_name or None)

+    num_variants = cfg.prompts.num_variants
+    if num_variants < 0:
+        prompter = DatasetTemplates(ds_name, config_name)
+        num_variants = len(prompter.templates)
+
     layer_cols = {
         f"hidden_{layer}": Array3D(
             dtype="int16",
@@ -297,22 +315,18 @@ def get_splits() -> SplitDict:
             length=num_variants,
         )

-    devices = select_usable_devices(num_gpus, min_memory=cfg.min_gpu_mem)
-
-    # Prevent the GPU-related config options from invalidating the cache
-    _cfg = copy(cfg)
-    _cfg.min_gpu_mem = None
-    _cfg.num_gpus = -1
-
+    devices = select_usable_devices(num_gpus, min_memory=min_gpu_mem)
     builders = {
         split_name: _GeneratorBuilder(
+            builder_name=info.builder_name,
+            config_name=info.config_name,
             cache_dir=None,
             features=Features({**layer_cols, **other_cols}),
             generator=_extraction_worker,
             split_name=split_name,
             split_info=split_info,
             gen_kwargs=dict(
-                cfg=[_cfg] * len(devices),
+                cfg=[cfg] * len(devices),
                 device=devices,
                 rank=list(range(len(devices))),
                 split_type=[split_name] * len(devices),
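
The new `explode()` method is what reconciles multi-dataset runs with the new one-dataset assertion in `extract_hiddens`: the top-level config fans out into one `Extract` per dataset before extraction begins. A self-contained sketch of the pattern, with `PromptConfig` reduced to a bare dataset list (the real class carries many more fields):

```python
from copy import copy
from dataclasses import dataclass, field


@dataclass
class PromptConfig:
    datasets: list[str] = field(default_factory=list)

    def explode(self) -> list["PromptConfig"]:
        # One single-dataset prompt config per requested dataset
        return [PromptConfig(datasets=[ds]) for ds in self.datasets]


@dataclass
class Extract:
    prompts: PromptConfig
    model: str = "gpt2"

    def explode(self) -> list["Extract"]:
        copies = []
        for prompt_cfg in self.prompts.explode():
            cfg = copy(self)  # shallow copy shares model/layer settings
            cfg.prompts = prompt_cfg
            copies.append(cfg)
        return copies


cfg = Extract(prompts=PromptConfig(datasets=["imdb", "amazon_polarity"]))
subs = cfg.explode()
assert all(len(s.prompts.datasets) == 1 for s in subs)  # satisfies the new assert
```
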
42 changes: 30 additions & 12 deletions elk/extraction/generator.py
@@ -1,17 +1,22 @@
 from copy import deepcopy
 from dataclasses import dataclass, field
-from typing import Any, Callable, Dict, Optional
-
-import datasets
-from datasets import Features
+from typing import Any, Callable
+
+from datasets import (
+    BuilderConfig,
+    DatasetInfo,
+    Features,
+    GeneratorBasedBuilder,
+    SplitInfo,
+)
 from datasets.splits import NamedSplit


 @dataclass
-class _GeneratorConfig(datasets.BuilderConfig):
-    generator: Optional[Callable] = None
+class _GeneratorConfig(BuilderConfig):
+    generator: Callable | None = None
     gen_kwargs: dict[str, Any] = field(default_factory=dict)
-    features: Optional[datasets.Features] = None
+    features: Features | None = None

     def create_config_id(
         self, config_kwargs: dict, custom_features: Features | None
@@ -37,28 +42,41 @@ class _SplitGenerator:
"""

name: str
split_info: datasets.SplitInfo
gen_kwargs: Dict = field(default_factory=dict)
split_info: SplitInfo
gen_kwargs: dict = field(default_factory=dict)

def __post_init__(self):
self.name = str(self.name) # Make sure we convert NamedSplits in strings
NamedSplit(self.name) # check that it's a valid split name


-class _GeneratorBuilder(datasets.GeneratorBasedBuilder):
+class _GeneratorBuilder(GeneratorBasedBuilder):
     """Patched version of `datasets.Generator` allowing for splits besides `train`"""

     BUILDER_CONFIG_CLASS = _GeneratorConfig
     config: _GeneratorConfig

-    def __init__(self, split_name: str, split_info: datasets.SplitInfo, **kwargs):
+    def __init__(
+        self,
+        builder_name: str | None,
+        config_name: str | None,
+        split_name: str,
+        split_info: SplitInfo,
+        **kwargs,
+    ):
         self.split_name = split_name
         self.split_info = split_info

         super().__init__(**kwargs)

+        # Weirdly we need to set DatasetInfo.builder_name and DatasetInfo.config_name
+        # here, not in _info, because super().__init__ modifies them
+        self.info.builder_name = builder_name
+        self.info.config_name = config_name
+
     def _info(self):
-        return datasets.DatasetInfo(features=self.config.features)
+        # Use the same builder and config name as the original builder
+        return DatasetInfo(features=self.config.features)

     def _split_generators(self, dl_manager):
         return [
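
The "Weirdly we need to set ... after `super().__init__`" comment describes a general ordering pitfall: the parent constructor rebuilds `self.info`, so anything written to it beforehand is clobbered. A toy illustration of the same pattern (not the actual `datasets` internals):

```python
class Parent:
    def __init__(self):
        # The parent rebuilds this attribute unconditionally
        self.info = {"builder_name": None, "config_name": None}


class Child(Parent):
    def __init__(self, builder_name: str):
        super().__init__()
        # Safe only *after* super().__init__(), which would overwrite it
        self.info["builder_name"] = builder_name


assert Child("imdb").info["builder_name"] == "imdb"
```
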