Binarized meta-templates; some extraction refactoring #218

Closed · wants to merge 24 commits

Commits (24):
bbee489  Initial support for FEVER (norabelrose, Apr 22, 2023)
5ba1ddd  Start saving and fitting a reporter to the input embeddings (norabelrose, Apr 22, 2023)
3b1f74d  Merge branch 'input-embeddings' into template-filtering (norabelrose, Apr 22, 2023)
51ba54f  Rename layer 0 to 'input' to make it more clear (norabelrose, Apr 22, 2023)
544b485  Actually rename layer 0 correctly (norabelrose, Apr 22, 2023)
43da44e  Handle layer_stride correctly (norabelrose, Apr 22, 2023)
9056e00  Merge branch 'input-embeddings' into template-filtering (norabelrose, Apr 22, 2023)
756fa53  label_choices (norabelrose, Apr 22, 2023)
93b7ae0  Clean up train and eval commands; do transfer in sweep (norabelrose, Apr 22, 2023)
57d0b8b  Support INLP and split eval output into multiple CSVs (norabelrose, Apr 22, 2023)
228a6a0  Merge branch 'inlp' into template-filtering (norabelrose, Apr 22, 2023)
b086f0b  Merge branch 'inlp' into template-filtering (norabelrose, Apr 25, 2023)
934cd54  Log ensembled metrics (norabelrose, Apr 26, 2023)
dff69bf  Fixing pyright version (norabelrose, Apr 26, 2023)
b181d3e  Merge remote-tracking branch 'origin/main' into ensembling (norabelrose, Apr 26, 2023)
15254bf  Merge main (norabelrose, Apr 26, 2023)
69c2d55  Tons of stuff, preparing for sciq_binary experiment (norabelrose, Apr 27, 2023)
960ff01  Support --binarize again (norabelrose, Apr 27, 2023)
c9e62ea  Partial support for truthful_qa (norabelrose, Apr 27, 2023)
eb71a6c  Merge branch 'main' into template-filtering (norabelrose, Apr 29, 2023)
88bb15e  Merge remote-tracking branch 'origin/main' into template-filtering (norabelrose, Apr 29, 2023)
c648ff0  Remove crap (norabelrose, Apr 29, 2023)
ef12130  EleutherAI/truthful_qa_mc (norabelrose, Apr 29, 2023)
5d60ebd  Update templates (norabelrose, Apr 30, 2023)
Changes from 1 commit:

Tons of stuff, preparing for sciq_binary experiment
norabelrose committed Apr 27, 2023
commit 69c2d557ddda039a0db364503f65425d2ffa7126
7 changes: 6 additions & 1 deletion elk/evaluation/evaluate.py
@@ -1,5 +1,6 @@
from collections import defaultdict
from dataclasses import dataclass
from pathlib import Path

import pandas as pd
import torch
@@ -15,7 +16,10 @@
class Eval(Run):
"""Full specification of a reporter evaluation run."""

source: str = field(default="", positional=True)
# Using None as a default here is a hack; we actually raise an error if it's not
# specified in __post_init__. TODO: Maybe this is an indication we should be using
# composition and not inheritance here?
source: Path | None = field(default=None, positional=True)
skip_supervised: bool = False

def __post_init__(self):
@@ -38,6 +42,7 @@ def apply_to_layer(
device = self.get_device(devices, world_size)
val_output = self.prepare_data(device, layer, "val")

assert self.source, "Must specify a source experiment."
experiment_dir = elk_reporter_dir() / self.source

reporter_path = experiment_dir / "reporters" / f"layer_{layer}.pt"
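
As an aside on the TODO above: the Optional-default-plus-__post_init__ pattern looks like this in isolation (a minimal sketch; EvalLike and its fields are stand-ins, not the real Eval/Run signature):

    from dataclasses import dataclass
    from pathlib import Path

    @dataclass
    class EvalLike:
        # None is only a placeholder default; a missing value fails fast here
        # instead of propagating an empty string through the run.
        source: Path | None = None

        def __post_init__(self):
            if self.source is None:
                raise ValueError("Must specify a source experiment.")
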
3 changes: 1 addition & 2 deletions elk/extraction/__init__.py
@@ -1,7 +1,7 @@
from .balanced_sampler import BalancedSampler, FewShotSampler
from .extraction import Extract, extract, extract_hiddens
from .generator import _GeneratorBuilder, _GeneratorConfig
from .prompt_loading import PromptConfig, load_prompts
from .prompt_loading import load_prompts

__all__ = [
"BalancedSampler",
@@ -11,6 +11,5 @@
"extract",
"_GeneratorConfig",
"_GeneratorBuilder",
"PromptConfig",
"load_prompts",
]
2 changes: 1 addition & 1 deletion elk/extraction/dataset_name.py
@@ -5,7 +5,7 @@

def extract_dataset_name_and_config(dataset_config_str: str) -> tuple[str, str]:
"""Extract the dataset name and config name from the dataset prompt."""
ds_name, _, config_name = dataset_config_str.partition(" ")
ds_name, _, config_name = dataset_config_str.partition(":")
return ds_name, config_name


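For illustration, the new ":" separator in action (str.partition is standard-library behavior, so this is just a worked example):

    >>> extract_dataset_name_and_config("super_glue:boolq")
    ('super_glue', 'boolq')
    >>> extract_dataset_name_and_config("imdb")  # no config name present
    ('imdb', '')
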
122 changes: 84 additions & 38 deletions elk/extraction/extraction.py
@@ -1,9 +1,8 @@
"""Functions for extracting the hidden states of a model."""
import logging
import os
from copy import copy
from dataclasses import InitVar, dataclass
from itertools import islice
from dataclasses import InitVar, dataclass, replace
from itertools import islice, zip_longest
from typing import Any, Iterable, Literal
from warnings import filterwarnings

@@ -45,31 +44,70 @@
extract_dataset_name_and_config,
)
from .generator import _GeneratorBuilder
from .prompt_loading import PromptConfig, load_prompts
from .prompt_loading import load_prompts


@dataclass
class Extract(Serializable):
"""
Args:
model: HuggingFace model string identifying the language model to extract
hidden states from.
prompts: The configuration for the prompts.
layers: The layers to extract hidden states from.
layer_stride: Shortcut for setting `layers` to `range(0, num_layers, stride)`.
token_loc: The location of the token to extract hidden states from. Can be
either "first", "last", or "mean". Defaults to "last".
"""

prompts: PromptConfig
"""Config for extracting hidden states from a language model."""

model: str = field(positional=True)
"""HF model string identifying the language model to extract hidden states from."""

datasets: tuple[str, ...] = field(positional=True)
"""Names of HF datasets to use, e.g. `"super_glue:boolq"` or `"imdb"`"""

data_dirs: tuple[str, ...] = ()
"""Directory to use for caching the hiddens. Defaults to `HF_DATASETS_CACHE`."""

max_examples: tuple[int, int] = (1000, 1000)
"""Maximum number of examples to use from each split of the dataset."""

num_shots: int = 0
"""Number of examples for few-shot prompts. If zero, prompts are zero-shot."""

num_variants: int = -1
"""The number of prompt templates to use for each example. If -1, all available
templates are used."""

seed: int = 42
"""Seed to use for prompt randomization. Defaults to 42."""

layers: tuple[int, ...] = ()
"""Indices of layers to extract hidden states from. We follow the HF convention, so
0 is the embedding, and 1 is the output of the first transformer layer."""

layer_stride: InitVar[int] = 1
"""Shortcut for `layers = (0,) + tuple(range(1, num_layers + 1, stride))`."""

template_path: str | None = None
"""Path to pass into `DatasetTemplates`. By default we use the dataset name."""

token_loc: Literal["first", "last", "mean"] = "last"
"""The location of the token to extract hidden states from."""

use_encoder_states: bool = False
"""Whether to extract hidden states from the encoder instead of the decoder in the
case of encoder-decoder models."""

def __post_init__(self, layer_stride: int):
if len(self.max_examples) > 2:
raise ValueError(
"max_examples should be a list of length 0, 1, or 2,"
f"but got {len(self.max_examples)}"
)
if not self.max_examples:
self.max_examples = (int(1e100), int(1e100))

# Broadcast the dataset name to all data_dirs
if len(self.data_dirs) == 1:
self.data_dirs *= len(self.datasets)
elif self.data_dirs and len(self.data_dirs) != len(self.datasets):
raise ValueError(
"data_dirs should be a list of length 0, 1, or len(datasets),"
f" but got {len(self.data_dirs)}"
)

if self.layers and layer_stride > 1:
raise ValueError(
"Cannot use both --layers and --layer-stride. Please use only one."
@@ -87,14 +125,10 @@

def explode(self) -> list["Extract"]:
"""Explode this config into a list of configs, one for each layer."""
copies = []

for prompt_cfg in self.prompts.explode():
cfg = copy(self)
cfg.prompts = prompt_cfg
copies.append(cfg)

return copies
return [
replace(self, datasets=(ds,), data_dirs=(data_dir,) if data_dir else ())
for ds, data_dir in zip_longest(self.datasets, self.data_dirs)
]
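
A sketch of the new explode() behavior (dataset names are illustrative): each dataset gets its own single-dataset copy of the config, with data_dirs paired up by zip_longest when provided:

    cfg = Extract(model="gpt2", datasets=("super_glue:boolq", "imdb"))
    for sub in cfg.explode():
        print(sub.datasets, sub.data_dirs)
    # ('super_glue:boolq',) ()
    # ('imdb',) ()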


@torch.inference_mode()
@@ -114,8 +148,7 @@
filterwarnings("ignore")
logging.disable(logging.CRITICAL)

p_cfg = cfg.prompts
ds_names = p_cfg.datasets
ds_names = cfg.datasets
assert len(ds_names) == 1, "Can only extract hiddens from one dataset at a time."

model = instantiate_model(
@@ -138,14 +171,15 @@
prompt_ds = load_prompts(
ds_names[0],
split_type=split_type,
template_path=cfg.template_path,
rank=rank,
world_size=world_size,
)

# Add one to the number of layers to account for the embedding layer
layer_indices = cfg.layers or tuple(range(model.config.num_hidden_layers + 1))

global_max_examples = p_cfg.max_examples[0 if split_type == "train" else 1]
global_max_examples = cfg.max_examples[0 if split_type == "train" else 1]
# break `max_examples` among the processes roughly equally
max_examples = global_max_examples // world_size
# the last process gets the remainder (which is usually small)
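# e.g. global_max_examples=1000 with world_size=3 gives ranks 0 and 1
# 333 examples each, while the last rank gets 334 (333 plus the remainder)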
@@ -277,18 +311,22 @@
model_cfg = AutoConfig.from_pretrained(cfg.model)

ds_name, config_name = extract_dataset_name_and_config(
dataset_config_str=cfg.prompts.datasets[0]
dataset_config_str=cfg.datasets[0]
)
info = get_dataset_config_info(ds_name, config_name or None)

prompter = DatasetTemplates(ds_name, config_name)
if not cfg.template_path:
prompter = DatasetTemplates(ds_name, config_name)
else:
prompter = DatasetTemplates(cfg.template_path)

ds_features = assert_type(Features, info.features)
label_col = prompter.label_column or infer_label_column(ds_features)
num_classes = len(prompter.label_choices) or infer_num_classes(
ds_features[label_col]
)

num_variants = cfg.prompts.num_variants
num_variants = cfg.num_variants
if num_variants < 0:
num_dropped = prompter.drop_non_mc_templates()
num_variants = len(prompter.templates)
@@ -340,19 +378,27 @@
info, features = hidden_features(cfg)

devices = select_usable_devices(num_gpus, min_memory=min_gpu_mem)
limit_list = cfg.prompts.max_examples
limits = cfg.max_examples
splits = assert_type(SplitDict, info.splits)

pretty_name = colorize(assert_type(str, cfg.datasets[0]), highlight_color)

Check failure on line 384 in elk/extraction/extraction.py (GitHub Actions / run-tests (3.10, ubuntu-latest)): Argument of type "str" cannot be assigned to parameter "color" of type "Color" in function "colorize". Type "str" cannot be assigned to type "Color": "str" is not assignable to the literals "black", "red", "green", "yellow", "blue", "magenta", "cyan", ... (reportGeneralTypeIssues)

if split_type is None:
train, val = select_train_val_splits(splits)
pretty_name = colorize(assert_type(str, info.builder_name), highlight_color)
print(f"{pretty_name}: using '{train}' for training and '{val}' for validation")

print(f"{pretty_name} using '{train}' for training and '{val}' for validation")
splits = SplitDict({train: splits[train], val: splits[val]})
split_types = ["train", "val"]
else:
# Remove the split we're not using
del limit_list[1 if split_type == "train" else 0]
limits = [limits[0]] if split_type == "train" else limits
split_name = select_split(splits, split_type)
splits = SplitDict({split_name: splits[split_name]})
split_types = [split_type]

if split_type == "train":
print(f"{pretty_name} using '{split_name}' for training")
else:
print(f"{pretty_name} using '{split_name}' for validation")

builders = {
split_name: _GeneratorBuilder(
@@ -362,18 +408,18 @@
split_name=split_name,
split_info=SplitInfo(
name=split_name,
num_examples=min(limit, v.num_examples) * len(cfg.prompts.datasets),
num_examples=min(limit, v.num_examples) * len(cfg.datasets),
dataset_name=v.dataset_name,
),
gen_kwargs=dict(
cfg=[cfg] * len(devices),
device=devices,
rank=list(range(len(devices))),
split_type=[split_name] * len(devices),
split_type=[ty] * len(devices),
world_size=[len(devices)] * len(devices),
),
)
for limit, (split_name, v) in zip(limit_list, splits.items())
for limit, (split_name, v), ty in zip(limits, splits.items(), split_types)
}
import multiprocess as mp

@@ -389,6 +435,6 @@

dataset_dict = DatasetDict(ds)
return DatasetDictWithName(
name=cfg.prompts.datasets[0],
name=cfg.datasets[0],
dataset=dataset_dict,
)