Binarized meta-templates; some extraction refactoring #218

Closed

wants to merge 24 commits

Changes from 1 commit

Commits (24):
bbee489
Initial support for FEVER
norabelrose Apr 22, 2023
5ba1ddd
Start saving and fitting a reporter to the input embeddings
norabelrose Apr 22, 2023
3b1f74d
Merge branch 'input-embeddings' into template-filtering
norabelrose Apr 22, 2023
51ba54f
Rename layer 0 to 'input' to make it more clear
norabelrose Apr 22, 2023
544b485
Actually rename layer 0 correctly
norabelrose Apr 22, 2023
43da44e
Handle layer_stride correctly
norabelrose Apr 22, 2023
9056e00
Merge branch 'input-embeddings' into template-filtering
norabelrose Apr 22, 2023
756fa53
label_choices
norabelrose Apr 22, 2023
93b7ae0
Clean up train and eval commands; do transfer in sweep
norabelrose Apr 22, 2023
57d0b8b
Support INLP and split eval output into multiple CSVs
norabelrose Apr 22, 2023
228a6a0
Merge branch 'inlp' into template-filtering
norabelrose Apr 22, 2023
b086f0b
Merge branch 'inlp' into template-filtering
norabelrose Apr 25, 2023
934cd54
Log ensembled metrics
norabelrose Apr 26, 2023
dff69bf
Fixing pyright version
norabelrose Apr 26, 2023
b181d3e
Merge remote-tracking branch 'origin/main' into ensembling
norabelrose Apr 26, 2023
15254bf
Merge main
norabelrose Apr 26, 2023
69c2d55
Tons of stuff, preparing for sciq_binary experiment
norabelrose Apr 27, 2023
960ff01
Support --binarize again
norabelrose Apr 27, 2023
c9e62ea
Partial support for truthful_qa
norabelrose Apr 27, 2023
eb71a6c
Merge branch 'main' into template-filtering
norabelrose Apr 29, 2023
88bb15e
Merge remote-tracking branch 'origin/main' into template-filtering
norabelrose Apr 29, 2023
c648ff0
Remove crap
norabelrose Apr 29, 2023
ef12130
EleutherAI/truthful_qa_mc
norabelrose Apr 29, 2023
5d60ebd
Update templates
norabelrose Apr 30, 2023
Merge branch 'inlp' into template-filtering
norabelrose committed Apr 25, 2023
commit b086f0b31df478395b6bf3819ee2883a2d9be11f
elk/evaluation/evaluate.py: 3 additions, 0 deletions

@@ -24,6 +24,9 @@ def __post_init__(self):
         transfer_dir = elk_reporter_dir() / self.source / "transfer_eval"
         self.out_dir = transfer_dir / "+".join(self.data.prompts.datasets)
 
+    def execute(self, highlight_color: str = "cyan"):
+        return super().execute(highlight_color, split_type="val")
+
     def apply_to_layer(
         self, layer: int, devices: list[str], world_size: int
     ) -> dict[str, pd.DataFrame]:
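Passing split_type="val" here means an evaluation run only ever extracts the validation split. A minimal runnable illustration of the override pattern, with toy classes standing in for elk's Run and Eval (the real classes carry much more state):

from typing import Literal


class Run:
    # Toy stand-in: the real Run.execute extracts hidden states.
    def execute(
        self,
        highlight_color: str = "cyan",
        split_type: Literal["train", "val", None] = None,
    ) -> str:
        return f"extracting split_type={split_type!r}"


class Eval(Run):
    # Evaluation pins split_type, so callers can't accidentally extract train.
    def execute(self, highlight_color: str = "cyan") -> str:
        return super().execute(highlight_color, split_type="val")


print(Eval().execute())  # extracting split_type='val'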
elk/extraction/extraction.py: 40 additions, 38 deletions

@@ -12,6 +12,7 @@
     Array2D,
     Array3D,
     DatasetDict,
+    DatasetInfo,
     DownloadMode,
     Features,
     Sequence,
@@ -35,6 +36,7 @@
     instantiate_model,
     instantiate_tokenizer,
     is_autoregressive,
+    select_split,
     select_train_val_splits,
     select_usable_devices,
 )
@@ -264,39 +266,8 @@ def _extraction_worker(**kwargs):
     yield from extract_hiddens(**{k: v[0] for k, v in kwargs.items()})
 
 
-def extract(
-    cfg: "Extract",
-    *,
-    disable_cache: bool = False,
-    highlight_color: str = "cyan",
-    num_gpus: int = -1,
-    min_gpu_mem: int | None = None,
-) -> DatasetDict:
-    """Extract hidden states from a model and return a `DatasetDict` containing them."""
-
-    def get_splits() -> SplitDict:
-        available_splits = assert_type(SplitDict, info.splits)
-        train_name, val_name = select_train_val_splits(available_splits)
-
-        pretty_name = colorize(assert_type(str, info.builder_name), highlight_color)
-        print(
-            f"{pretty_name}: using '{train_name}' for training "
-            f"and '{val_name}' for validation"
-        )
-        limit_list = cfg.prompts.max_examples
-
-        return SplitDict(
-            {
-                k: SplitInfo(
-                    name=k,
-                    num_examples=min(limit, v.num_examples) * len(cfg.prompts.datasets),
-                    dataset_name=v.dataset_name,
-                )
-                for limit, (k, v) in zip(limit_list, available_splits.items())
-            },
-            dataset_name=available_splits.dataset_name,
-        )
-
+def hidden_features(cfg: Extract) -> tuple[DatasetInfo, Features]:
+    """Return the HuggingFace `Features` corresponding to an `Extract` config."""
     model_cfg = AutoConfig.from_pretrained(cfg.model)
 
     ds_name, _, config_name = cfg.prompts.datasets[0].partition(" ")
@@ -345,16 +316,47 @@ def get_splits() -> SplitDict:
             dtype="float32",
         )
 
+    return info, Features({**layer_cols, **other_cols})
+
+
+def extract(
+    cfg: "Extract",
+    *,
+    disable_cache: bool = False,
+    highlight_color: str = "cyan",
+    num_gpus: int = -1,
+    min_gpu_mem: int | None = None,
+    split_type: Literal["train", "val", None] = None,
+) -> DatasetDict:
+    """Extract hidden states from a model and return a `DatasetDict` containing them."""
+    info, features = hidden_features(cfg)
+
     devices = select_usable_devices(num_gpus, min_memory=min_gpu_mem)
+    limit_list = cfg.prompts.max_examples
+    splits = assert_type(SplitDict, info.splits)
+
+    if split_type is None:
+        train, val = select_train_val_splits(splits)
+        pretty_name = colorize(assert_type(str, info.builder_name), highlight_color)
+        print(f"{pretty_name}: using '{train}' for training and '{val}' for validation")
+        splits = SplitDict({train: splits[train], val: splits[val]})
+    else:
+        # Remove the split we're not using
+        del limit_list[1 if split_type == "train" else 0]
+        split_name = select_split(splits, split_type)
+        splits = SplitDict({split_name: splits[split_name]})
+
     builders = {
         split_name: _GeneratorBuilder(
             builder_name=info.builder_name,
             config_name=info.config_name,
             cache_dir=None,
-            features=Features({**layer_cols, **other_cols}),
+            features=features,
             generator=_extraction_worker,
             split_name=split_name,
-            split_info=split_info,
+            split_info=SplitInfo(
+                name=split_name,
+                num_examples=min(limit, v.num_examples) * len(cfg.prompts.datasets),
+                dataset_name=v.dataset_name,
+            ),
             gen_kwargs=dict(
                 cfg=[cfg] * len(devices),
                 device=devices,
@@ -363,7 +365,7 @@
                 world_size=[len(devices)] * len(devices),
            ),
        )
-        for (split_name, split_info) in get_splits().items()
+        for limit, (split_name, v) in zip(limit_list, splits.items())
     }
     import multiprocess as mp
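The refactor splits the old get_splits closure in two: hidden_features(cfg) computes the DatasetInfo and Features up front, and extract() itself narrows info.splits based on split_type. A runnable toy of the narrowing step, with plain dicts standing in for SplitDict and a hard-coded stand-in for select_split:

from typing import Literal

# Toy split sizes; limit_list mirrors cfg.prompts.max_examples ([train, val]).
SPLIT_SIZES = {"train": 1000, "validation": 200}


def choose_splits(
    split_type: Literal["train", "val", None],
    limit_list: list[int],
) -> tuple[dict[str, int], list[int]]:
    splits = dict(SPLIT_SIZES)
    if split_type is None:
        # Keep both splits; limits stay paired as [train, val].
        return splits, limit_list
    # Drop the limit for the split we're not using, as in the diff above.
    del limit_list[1 if split_type == "train" else 0]
    name = "train" if split_type == "train" else "validation"  # select_split stand-in
    return {name: splits[name]}, limit_list


print(choose_splits(None, [100, 50]))   # ({'train': 1000, 'validation': 200}, [100, 50])
print(choose_splits("val", [100, 50]))  # ({'validation': 200}, [50])

Each surviving split is then zipped with its limit to build a SplitInfo whose num_examples is min(limit, split_size) * len(cfg.prompts.datasets), as the builders comprehension above shows.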
elk/extraction/generator.py: 0 additions, 7 deletions

@@ -58,8 +58,6 @@ class _GeneratorBuilder(GeneratorBasedBuilder):
 
     def __init__(
         self,
-        builder_name: str | None,
-        config_name: str | None,
         split_name: str,
         split_info: SplitInfo,
         **kwargs,
@@ -69,11 +67,6 @@ def __init__(
 
         super().__init__(**kwargs)
 
-        # Weirdly we need to set DatasetInfo.builder_name and DatasetInfo.config_name
-        # here, not in _info, because super().__init__ modifies them
-        self.info.builder_name = builder_name
-        self.info.config_name = config_name
-
     def _info(self):
         # Use the same builder and config name as the original builder
         return DatasetInfo(features=self.config.features)
elk/run.py: 6 additions, 1 deletion

@@ -47,14 +47,19 @@ class Run(ABC, Serializable):
     out_dir: Path | None = None
     disable_cache: bool = field(default=False, to_dict=False)
 
-    def execute(self, highlight_color: str = "cyan"):
+    def execute(
+        self,
+        highlight_color: str = "cyan",
+        split_type: Literal["train", "val", None] = None,
+    ):
         self.datasets = [
             extract(
                 cfg,
                 disable_cache=self.disable_cache,
                 highlight_color=highlight_color,
                 num_gpus=self.num_gpus,
                 min_gpu_mem=self.min_gpu_mem,
+                split_type=split_type,
             )
             for cfg in self.data.explode()
         ]
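Run.execute simply threads split_type through to every extract() call it makes, one per exploded dataset config. A toy mirror of that flow (extract here is a stub, not the real extraction function):

from typing import Literal


def extract(cfg: str, *, split_type: Literal["train", "val", None] = None) -> str:
    # Stub: report which split(s) would be extracted for this config.
    return f"{cfg}:{split_type or 'train+val'}"


def execute(configs: list[str], split_type: Literal["train", "val", None] = None):
    return [extract(cfg, split_type=split_type) for cfg in configs]


print(execute(["imdb", "boolq"]))         # ['imdb:train+val', 'boolq:train+val']
print(execute(["imdb", "boolq"], "val"))  # ['imdb:val', 'boolq:val']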
elk/utils/__init__.py: 2 additions, 0 deletions

@@ -6,6 +6,7 @@
     has_multiple_configs,
     infer_label_column,
     infer_num_classes,
+    select_split,
     select_train_val_splits,
 )
 from .gpu_utils import select_usable_devices
@@ -33,6 +34,7 @@
     "int16_to_float32",
     "is_autoregressive",
     "pytree_map",
+    "select_split",
     "select_train_val_splits",
     "select_usable_devices",
     "stochastic_round_constrained",
elk/utils/data_utils.py: 17 additions, 8 deletions

@@ -1,7 +1,7 @@
 import copy
 from functools import cache
 from random import Random
-from typing import Any, Iterable
+from typing import Any, Iterable, Literal
 
 from datasets import (
     ClassLabel,
@@ -15,6 +15,13 @@
 from ..promptsource.templates import Template
 from .typing import assert_type
 
+# Lower values are more "train-like" and higher values are more "test-like".
+PRIORITIES = {
+    Split.TRAIN: 0,
+    Split.VALIDATION: 1,
+    Split.TEST: 2,
+}
+
 
 def get_columns_all_equal(dataset: DatasetDict) -> list[str]:
     """Get columns of a `DatasetDict`, asserting all splits have the same columns."""
@@ -50,16 +57,18 @@ def has_multiple_configs(ds_name: str) -> bool:
     return len(get_dataset_config_names(ds_name)) > 1
 
 
+def select_split(raw_splits: Iterable[str], split_type: Literal["train", "val"]) -> str:
+    """Return the train or validation split to use, given an Iterable of splits."""
+    assert split_type in ("train", "val")
+
+    reduce_fn = min if split_type == "train" else max
+    return reduce_fn(raw_splits, key=lambda k: PRIORITIES.get(k, 100))  # type: ignore
+
+
 def select_train_val_splits(raw_splits: Iterable[str]) -> tuple[str, str]:
     """Return splits to use for train and validation, given an Iterable of splits."""
 
-    priorities = {
-        Split.TRAIN: 0,
-        Split.VALIDATION: 1,
-        Split.TEST: 2,
-    }
-
-    splits = sorted(raw_splits, key=lambda k: priorities.get(k, 100))  # type: ignore
+    splits = sorted(raw_splits, key=lambda k: PRIORITIES.get(k, 100))  # type: ignore
     assert len(splits) >= 2, "Must have at least two of train, val, and test splits"
 
     return tuple(splits[:2])
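The shared PRIORITIES table makes both helpers easy to check by hand. A plain-string re-implementation for illustration (the real table is keyed on datasets.Split constants; plain strings stand in for them here):

PRIORITIES = {"train": 0, "validation": 1, "test": 2}


def select_split(raw_splits, split_type):
    assert split_type in ("train", "val")
    reduce_fn = min if split_type == "train" else max
    # Unknown split names get priority 100, so they lose to the known splits
    # for "train" but win for "val".
    return reduce_fn(raw_splits, key=lambda k: PRIORITIES.get(k, 100))


def select_train_val_splits(raw_splits):
    splits = sorted(raw_splits, key=lambda k: PRIORITIES.get(k, 100))
    assert len(splits) >= 2
    return tuple(splits[:2])


splits = ["test", "validation", "train"]
print(select_train_val_splits(splits))  # ('train', 'validation')
print(select_split(splits, "train"))    # 'train'
print(select_split(splits, "val"))      # 'test' (the most "test-like" split wins)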