From bbee4890ea180c0e9aedcf9f27fd9f15bcb4a377 Mon Sep 17 00:00:00 2001 From: Nora Belrose Date: Sat, 22 Apr 2023 04:26:28 +0000 Subject: [PATCH 01/17] Initial support for FEVER --- elk/extraction/balanced_sampler.py | 26 ++-- elk/extraction/extraction.py | 24 ++- elk/extraction/prompt_loading.py | 138 ++++++++---------- elk/promptsource/templates.py | 60 ++++---- .../templates/fever/v1.0/templates.yaml | 4 + elk/truncated_eigh.py | 12 +- elk/utils/data_utils.py | 33 ++++- elk/utils/typing.py | 6 +- tests/test_load_prompts.py | 1 + tests/test_samplers.py | 2 +- 10 files changed, 168 insertions(+), 138 deletions(-) diff --git a/elk/extraction/balanced_sampler.py b/elk/extraction/balanced_sampler.py index 3fd8a739..211f318d 100644 --- a/elk/extraction/balanced_sampler.py +++ b/elk/extraction/balanced_sampler.py @@ -1,8 +1,8 @@ from collections import deque -from dataclasses import dataclass, field +from dataclasses import InitVar, dataclass, field from itertools import cycle from random import Random -from typing import Iterable, Iterator, Optional +from typing import Hashable, Iterable, Iterator, Optional from datasets import Features, IterableDataset from torch.utils.data import IterableDataset as TorchIterableDataset @@ -26,25 +26,29 @@ class BalancedSampler(TorchIterableDataset): """ data: Iterable[dict] - num_classes: int + label_choices: InitVar[set[Hashable]] buffer_size: int = 1000 - buffers: dict[int, deque[dict]] = field(default_factory=dict, init=False) + buffers: dict[Hashable, deque[dict]] = field(default_factory=dict, init=False) label_col: str = "label" + strict: bool = True - def __post_init__(self): + def __post_init__(self, label_choices: set[Hashable]): # Initialize empty buffers self.buffers = { - label: deque(maxlen=self.buffer_size) for label in range(self.num_classes) + label: deque(maxlen=self.buffer_size) for label in label_choices } def __iter__(self): for sample in self.data: label = sample[self.label_col] - - # This whole class is a no-op if the label is not an integer - if not isinstance(label, int): - yield sample - continue + if label not in self.buffers: + if self.strict: + raise ValueError( + f"Expected label to be one of {self.buffers}, got {label}" + ) + else: + # Just skip this sample + continue # Add the sample to the buffer for its class label self.buffers[label].append(sample) diff --git a/elk/extraction/extraction.py b/elk/extraction/extraction.py index 39de74ac..872479b5 100644 --- a/elk/extraction/extraction.py +++ b/elk/extraction/extraction.py @@ -30,7 +30,6 @@ assert_type, float32_to_int16, infer_label_column, - infer_num_classes, instantiate_model, instantiate_tokenizer, is_autoregressive, @@ -128,10 +127,7 @@ def extract_hiddens( prompt_ds = load_prompts( ds_names[0], - label_column=p_cfg.label_columns[0] if p_cfg.label_columns else None, - num_classes=p_cfg.num_classes, split_type=split_type, - stream=p_cfg.stream, rank=rank, world_size=world_size, ) @@ -202,15 +198,15 @@ def extract_hiddens( input_ids = torch.cat([input_ids, answer], dim=-1) if max_len := tokenizer.model_max_length: - input_ids = input_ids[..., -max_len:] + cur_len = input_ids.shape[-1] + input_ids = input_ids[..., -min(max_len, cur_len) :] # Make sure we only pass the arguments that the model expects inputs = dict(input_ids=input_ids) if is_enc_dec: inputs["labels"] = answer - with torch.autocast("cuda", enabled=torch.cuda.is_available()): - outputs = model(**inputs, output_hidden_states=True) + outputs = model(**inputs, output_hidden_states=True) # Compute the log probability 
of the answer tokens if available if has_lm_preds: @@ -303,17 +299,17 @@ def get_splits() -> SplitDict: ds_name, _, config_name = cfg.prompts.datasets[0].partition(" ") info = get_dataset_config_info(ds_name, config_name or None) + prompter = DatasetTemplates(ds_name, config_name) ds_features = assert_type(Features, info.features) - label_col = ( - cfg.prompts.label_columns[0] - if cfg.prompts.label_columns - else infer_label_column(ds_features) - ) - num_classes = cfg.prompts.num_classes or infer_num_classes(ds_features[label_col]) + prompter.label_column or infer_label_column(ds_features) + num_classes = 2 # prompter.num_classes or infer_num_classes(ds_features[label_col]) + num_variants = cfg.prompts.num_variants if num_variants < 0: - prompter = DatasetTemplates(ds_name, config_name) + num_dropped = prompter.drop_non_mc_templates() num_variants = len(prompter.templates) + if num_dropped: + print(f"Dropping {num_dropped} non-multiple choice templates") layer_cols = { f"hidden_{layer}": Array3D( diff --git a/elk/extraction/prompt_loading.py b/elk/extraction/prompt_loading.py index 456c13c6..d0055bee 100644 --- a/elk/extraction/prompt_loading.py +++ b/elk/extraction/prompt_loading.py @@ -3,21 +3,15 @@ from dataclasses import dataclass from itertools import zip_longest from random import Random -from typing import Any, Iterator, Literal, Optional +from typing import Any, Iterator, Literal -from datasets import ( - Dataset, - Features, - load_dataset, -) -from datasets.distributed import split_dataset_by_node +from datasets import ClassLabel, Dataset, Value, load_dataset from simple_parsing.helpers import Serializable, field from ..promptsource import DatasetTemplates from ..utils import ( assert_type, infer_label_column, - infer_num_classes, select_train_val_splits, ) from .balanced_sampler import BalancedSampler, FewShotSampler @@ -31,9 +25,6 @@ class PromptConfig(Serializable): `"super_glue boolq"` or `"imdb"`. data_dir: The directory to use for caching the dataset. Defaults to `~/.cache/huggingface/datasets`. - label_column: The column containing the labels. By default, we infer this from - the datatypes of the columns in the dataset; if there is only one column - with a `ClassLabel` datatype, we use that. max_examples: The maximum number of examples to use from the val dataset. If a single number, use at most that many examples for each split. If a list of length 2, use the first element for the train split and the second for @@ -43,18 +34,14 @@ class PromptConfig(Serializable): num_variants: The number of prompt templates to apply to each predicate upon call to __getitem__. Use -1 to apply all available templates. Defaults to 1. seed: The seed to use for prompt randomization. Defaults to 42. - stream: Whether to stream the dataset from the Internet. Defaults to False. 
""" datasets: list[str] = field(positional=True) data_dirs: list[str] = field(default_factory=list) - label_columns: list[str] = field(default_factory=list) max_examples: list[int] = field(default_factory=lambda: [1000, 1000]) - num_classes: int = 0 num_shots: int = 0 num_variants: int = -1 seed: int = 42 - stream: bool = False def __post_init__(self): if len(self.max_examples) > 2: @@ -69,7 +56,7 @@ def __post_init__(self): if len(self.max_examples) == 1: self.max_examples *= 2 - # Broadcast the dataset name to all data_dirs and label_columns + # Broadcast the dataset name to all data_dirs if len(self.data_dirs) == 1: self.data_dirs *= len(self.datasets) elif self.data_dirs and len(self.data_dirs) != len(self.datasets): @@ -78,25 +65,14 @@ def __post_init__(self): f" but got {len(self.data_dirs)}" ) - if len(self.label_columns) == 1: - self.label_columns *= len(self.datasets) - elif self.label_columns and len(self.label_columns) != len(self.datasets): - raise ValueError( - "label_columns should be a list of length 0, 1, or len(datasets)," - f" but got {len(self.label_columns)}" - ) - def explode(self) -> list["PromptConfig"]: """Explode the config into a list of configs, one for each dataset.""" copies = [] - for ds, data_dir, col in zip_longest( - self.datasets, self.data_dirs, self.label_columns - ): + for ds, data_dir in zip_longest(self.datasets, self.data_dirs): copy = deepcopy(self) copy.datasets = [ds] copy.data_dirs = [data_dir] if data_dir else [] - copy.label_columns = [col] if col else [] copies.append(copy) return copies @@ -105,13 +81,10 @@ def explode(self) -> list["PromptConfig"]: def load_prompts( ds_string: str, *, - label_column: Optional[str] = None, - num_classes: int = 0, num_shots: int = 0, num_variants: int = -1, seed: int = 42, split_type: Literal["train", "val"] = "train", - stream: bool = False, rank: int = 0, world_size: int = 1, ) -> Iterator[dict]: @@ -120,15 +93,10 @@ def load_prompts( Args: ds_string: Space-delimited name of the HuggingFace dataset to use, e.g. `"super_glue boolq"` or `"imdb"`. - label_column: The column containing the labels. By default, we infer this from - the datatypes of the columns in the dataset. - num_classes: The number of classes in the dataset. If zero, we infer this from - the datatypes of the columns in the dataset. num_shots: The number of examples to use in few-shot prompts. If zero, prompts are zero-shot. seed: The seed to use for prompt randomization. split_type: Whether to use the train or val split of the dataset. - stream: Whether to stream the dataset from the Internet. Defaults to False. rank: The rank of the current process. Defaults to 0. world_size: The number of processes. Defaults to 1. 
@@ -137,25 +105,18 @@ def load_prompts( """ ds_name, _, config_name = ds_string.partition(" ") prompter = DatasetTemplates(ds_name, config_name) + prompter.drop_non_mc_templates() - ds_dict = assert_type( - dict, load_dataset(ds_name, config_name or None, streaming=stream) - ) + ds_dict = assert_type(dict, load_dataset(ds_name, config_name or None)) train_name, val_name = select_train_val_splits(ds_dict) split_name = val_name if split_type == "val" else train_name ds = ds_dict[split_name].shuffle(seed=seed) train_ds = ds_dict[train_name].shuffle(seed=seed) - if not stream: - ds = assert_type(Dataset, ds) - if world_size > 1: - ds = ds.shard(world_size, rank) - - ds = ds.to_iterable_dataset().cast(ds.features) - elif world_size > 1: - # This prints to stdout which is slightly annoying - ds = split_dataset_by_node(dataset=ds, rank=rank, world_size=world_size) + ds = assert_type(Dataset, ds) + if world_size > 1: + ds = ds.shard(world_size, rank) num_templates = len(prompter.templates) num_variants = ( @@ -165,10 +126,44 @@ def load_prompts( if rank == 0: print(f"Using {num_variants} variants of each prompt") - label_column = label_column or infer_label_column(ds.features) - num_classes = num_classes or infer_num_classes(ds.features[label_column]) - rng = Random(seed) + # Which classes are actually present in this split of the dataset? + # This is shockingly fast since it uses an optimized Apache Arrow primitive. + label_column = prompter.label_column or infer_label_column(ds.features) + observed_labels = set(ds.unique(label_column)) + + # Now sanity check that the observed classes match the expected classes. This can + # sometimes fail if we picked an unlabeled split (e.g. everything is -1) + label_feature = ds.features[label_column] + if isinstance(label_feature, ClassLabel): + label_choices = {label_feature.str2int(label) for label in label_feature.names} + elif isinstance(label_feature, Value) and label_feature.dtype == "bool": + label_choices = {False, True} + else: + # We just have to assume that the observed labels are right + label_choices = observed_labels + + if observed_labels != label_choices: + raise ValueError( + f"Observed labels {observed_labels} in split '{split_name}' do not match " + f"expected labels {label_choices} from the dataset features." + ) + if prompt_choices := prompter.label_choices: + # The observed labels should be a superset of the prompt choices + if not (observed_labels >= set(prompt_choices)): + raise ValueError( + f"Observed labels {observed_labels} in split '{split_name}' do not " + f"match the prompt choices {prompt_choices}." + ) + + sorted_labels = prompt_choices + else: + # Impose a canonical order on the label choices. Theoretically the label column + # may be of a type that doesn't support comparison (so Pylance complains), but + # we'll just let it raise an exception if that happens. 
+ sorted_labels = sorted(label_choices) # type: ignore[arg-type] + + rng = Random(seed) if num_shots > 0: fewshot = FewShotSampler( train_ds, # TODO: not iterator @@ -179,15 +174,17 @@ def load_prompts( else: fewshot_iter = None - # Remove everything except the label column - extra_cols = list(assert_type(Features, ds.features)) - extra_cols.remove(label_column) + ds = ds.to_iterable_dataset() + if rank == 0: + print(f"Label choices: {sorted_labels}") - for example in BalancedSampler(ds, num_classes, label_col=label_column): + for example in BalancedSampler( + ds, set(sorted_labels), label_col=label_column, strict=False + ): yield _convert_to_prompts( example, label_column=label_column, - num_classes=num_classes, + label_choices=sorted_labels, # type: ignore[arg-type] num_variants=num_variants, prompter=prompter, rng=rng, @@ -199,13 +196,12 @@ def _convert_to_prompts( example: dict[str, Any], prompter: DatasetTemplates, label_column: str, - num_classes: int, + label_choices: list[bool | int | str], num_variants: int, rng: Random, - fewshot_iter: Optional[Iterator[list[dict]]] = None, + fewshot_iter: Iterator[list[dict]] | None = None, ) -> dict[str, Any]: """Prompt-generating function to pass to `IterableDataset.map`.""" - labels_are_strings = isinstance(example[label_column], str) prompts = [] templates = list(prompter.templates.values()) if num_variants < len(templates): @@ -218,21 +214,14 @@ def qa_cat(q: str, a: str) -> str: # For sanity checking that prompts are unique prompt_counter = Counter() - label_indices = set() + label = example[label_column] for template in templates: choices = [] - string_choices = template.get_answer_choices_list(example) - - label = example[label_column] - label_indices.add(string_choices.index(label) if labels_are_strings else label) - for answer_idx in range(num_classes): + for pseudo_label in label_choices: fake_example = example.copy() - if labels_are_strings: - fake_example[label_column] = string_choices[answer_idx] - else: - fake_example[label_column] = answer_idx + fake_example[label_column] = pseudo_label q, a = template.apply(fake_example) prompt_counter[(q, a)] += 1 @@ -261,14 +250,11 @@ def qa_cat(q: str, a: str) -> str: if dup_count > 1: raise ValueError(f'Prompt duplicated {dup_count} times! "{maybe_dup}"') - # Sanity check: label should be the same across all variants - if len(label_indices) > 1: - raise ValueError( - f"Label index should be the same all variants, but got {label_indices}" - ) - + # Our reporter training and evaluation code assumes that the labels are integers. + # If they're not, we need to convert them with index(). label_choices is guaranteed + # to be sorted (see above). 
return dict( - label=label_indices.pop(), + label=label_choices.index(label), prompts=prompts, template_names=[template.name for template in templates], ) diff --git a/elk/promptsource/templates.py b/elk/promptsource/templates.py index ea4e9196..920f8ae1 100644 --- a/elk/promptsource/templates.py +++ b/elk/promptsource/templates.py @@ -1,4 +1,3 @@ -import logging import os import random import uuid @@ -382,18 +381,45 @@ class DatasetTemplates: TEMPLATES_KEY = "templates" DATASET_KEY = "dataset" SUBSET_KEY = "subset" + LABEL_COLUMN_KEY = "label_column" + LABEL_CHOICES_KEY = "label_choices" TEMPLATE_FILENAME = "templates.yaml" - def __init__(self, dataset_name: str, subset_name: Optional[str] = None): + label_column: str | None + label_choices: list[str] | None + + def __init__(self, dataset_name: str, subset_name: str | None = None): self.dataset_name = dataset_name self.subset_name = subset_name - # dictionary is keyed by template name. - self.templates: dict = self.read_from_file() + + with open(self.yaml_path, "r") as f: + yaml_dict = yaml.load(f, Loader=yaml.FullLoader) + + # Required field; contains all the templates keyed by ID + self.templates = yaml_dict[self.TEMPLATES_KEY] + + # Optional fields; may be None + self.label_column = yaml_dict.get(self.LABEL_COLUMN_KEY) + self.label_choices = yaml_dict.get(self.LABEL_CHOICES_KEY) # Mapping from template name to template id self.name_to_id_mapping = {} self.sync_mapping() + def drop_non_mc_templates(self) -> int: + """Drop all templates that aren't multiple choice, return the number dropped""" + mc_templates = { + k: v for k, v in self.templates.items() if v.get_fixed_answer_choices_list() + } + if not mc_templates: + raise ValueError("No multiple choice templates found") + + num_dropped = len(self.templates) - len(mc_templates) + self.templates = mc_templates + self.sync_mapping() + + return num_dropped + def sync_mapping(self) -> None: """ Re-compute the name_to_id_mapping to ensure it is in sync with self.templates @@ -420,7 +446,11 @@ def folder_path(self) -> str: @property def yaml_path(self) -> str: - return os.path.join(self.folder_path, self.TEMPLATE_FILENAME) + path = os.path.join(self.folder_path, self.TEMPLATE_FILENAME) + if not os.path.exists(path): + raise ValueError(f"Expected prompt templates to exist at {path}") + + return path def format_for_dump(self) -> dict: """ @@ -434,26 +464,6 @@ def format_for_dump(self) -> dict: formatted_dict[self.SUBSET_KEY] = self.subset_name return formatted_dict - def read_from_file(self) -> dict: - """ - Reads a file containing a prompt collection. - """ - - if not os.path.exists(self.yaml_path): - dataset_name = ( - f"{self.dataset_name} {self.subset_name}" - if self.subset_name - else self.dataset_name - ) - logging.warning( - f"Tried instantiating `DatasetTemplates` for {dataset_name}, but no " - f"prompts found. Please ignore this warning if you are creating new " - f"prompts for this dataset." - ) - return {} - yaml_dict = yaml.load(open(self.yaml_path, "r"), Loader=yaml.FullLoader) - return yaml_dict[self.TEMPLATES_KEY] - def write_to_file(self) -> None: """ Writes to a file with the current prompt collection. 
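
The FEVER template file below supplies the two new optional keys that `DatasetTemplates.__init__` now reads. A rough usage sketch of how they surface after this change (assuming `elk.promptsource` exports `DatasetTemplates`, as the relative imports above suggest):

    # Illustrative sketch (assumed usage), not taken verbatim from the series.
    from elk.promptsource import DatasetTemplates

    prompter = DatasetTemplates("fever", "v1.0")
    print(prompter.label_column)   # "label" (None when the YAML omits the key)
    print(prompter.label_choices)  # ["REFUTES", "SUPPORTS"]

    # Keep only templates with a fixed answer_choices list; raises if none remain.
    dropped = prompter.drop_non_mc_templates()
    print(f"Dropped {dropped} non-multiple-choice templates")
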
diff --git a/elk/promptsource/templates/fever/v1.0/templates.yaml b/elk/promptsource/templates/fever/v1.0/templates.yaml index d25b27a3..b66be449 100644 --- a/elk/promptsource/templates/fever/v1.0/templates.yaml +++ b/elk/promptsource/templates/fever/v1.0/templates.yaml @@ -1,5 +1,9 @@ dataset: fever subset: v1.0 +label_column: label +label_choices: + - REFUTES + - SUPPORTS templates: 0870481e-e5d1-43a1-821e-b11c6bfd2483: !Template answer_choices: Yes|||No|||Not sure diff --git a/elk/truncated_eigh.py b/elk/truncated_eigh.py index 1cca6c83..bf92707a 100644 --- a/elk/truncated_eigh.py +++ b/elk/truncated_eigh.py @@ -16,7 +16,14 @@ class Eigendecomposition(NamedTuple): eigenvectors: Tensor -@torch.autocast("cuda", enabled=torch.cuda.is_available()) +# Do most of the computation in bfloat16 if available; this can be substantially faster +# than fp32. We can't use fp16 because the dynamic range is too small for the Lanczos +# iteration to converge in many cases. +@torch.autocast( + "cuda", + dtype=torch.bfloat16, + enabled=torch.cuda.is_available() and torch.cuda.is_bf16_supported(), +) def truncated_eigh( A: Tensor, k: int = 1, @@ -42,8 +49,7 @@ def truncated_eigh( max_iter (int, optional): The maximum number of iterations to perform. ncv (int, optional): The number of Lanczos vectors generated. Must be greater than k and smaller than n - 1. - tol (float, optional): The tolerance for the residual. Defaults to the machine - precision of `A.dtype` or 1e-4, whichever is larger. + tol (float, optional): The tolerance for the residual. Defaults to 1e-3. seed (int, optional): The random seed to use for the starting vector. which (str, optional): Which k eigenvalues and eigenvectors to compute. Must be one of 'LA', or 'SA'. diff --git a/elk/utils/data_utils.py b/elk/utils/data_utils.py index 34229342..181ed974 100644 --- a/elk/utils/data_utils.py +++ b/elk/utils/data_utils.py @@ -66,22 +66,41 @@ def select_train_val_splits(raw_splits: Iterable[str]) -> tuple[str, str]: def infer_label_column(features: Features) -> str: - """Return the unique `ClassLabel` column in a `Dataset`. + """Return the unique `ClassLabel` column, or "label" if it's of a suitable dtype. Returns: The name of the unique label column. Raises: - ValueError: If there are no `ClassLabel` columns, or if there are multiple. + ValueError: If it's unclear what the label column is. """ label_cols = [ col for col, dtype in features.items() if isinstance(dtype, ClassLabel) ] if not label_cols: - raise ValueError("Dataset has no label column") + # One more heuristic: if there's a column just named "label" with a reasonable + # dtype, use that. + col = features.get("label") + if not col: + raise ValueError( + "None of the columns in the dataset are obviously the label column; " + "please specify label_column in the prompt template yaml file." + ) + + import pyarrow as pa + + if pa.types.is_integer(col.pa_type) or col.dtype in ("bool", "string"): + return "label" + else: + # We don't support floats, timestamps, bytes, containers, etc. + raise ValueError( + f"Column 'label' has unsupported dtype {col.dtype}; please specify " + "a different label_column in the prompt template yaml file." 
+ ) + elif len(label_cols) > 1: raise ValueError( - f"Dataset has multiple label columns {label_cols}; specify " - f"label_column to disambiguate" + f"Dataset has multiple label columns {label_cols}; specify label_column " + "in the prompt template yaml to disambiguate" ) else: return assert_type(str, label_cols[0]) @@ -102,8 +121,8 @@ def infer_num_classes(label_feature: Any) -> int: return 2 else: raise ValueError( - f"Can't infer number of classes from label column " - f"of type {label_feature}" + f"Can't infer number of classes from label column of type {label_feature}. " + f"Please update the num_classes field in the prompt template yaml file." ) diff --git a/elk/utils/typing.py b/elk/utils/typing.py index 1d38040e..5624dbcc 100644 --- a/elk/utils/typing.py +++ b/elk/utils/typing.py @@ -15,7 +15,11 @@ def assert_type(typ: Type[T], obj: Any) -> T: def float32_to_int16(x: torch.Tensor) -> torch.Tensor: """Converts float32 to float16, then reinterprets as int16.""" - return x.type(torch.float16).view(torch.int16) + fp16 = x.type(torch.float16) + if not fp16.isfinite().all(): + raise ValueError("Tensor contains non-finite values!") + + return fp16.view(torch.int16) def int16_to_float32(x: torch.Tensor) -> torch.Tensor: diff --git a/tests/test_load_prompts.py b/tests/test_load_prompts.py index 2e03c379..d8c065cf 100644 --- a/tests/test_load_prompts.py +++ b/tests/test_load_prompts.py @@ -16,6 +16,7 @@ def test_single_split(cfg: PromptConfig, split_type: Literal["train", "val"]): ds_name, _, config_name = ds_string.partition(" ") prompter = DatasetTemplates(ds_name, config_name or None) + prompter.drop_non_mc_templates() limit = cfg.max_examples[0 if split_type == "train" else 1] for record in islice(prompt_ds, limit): diff --git a/tests/test_samplers.py b/tests/test_samplers.py index c5a142c1..9a1f694e 100644 --- a/tests/test_samplers.py +++ b/tests/test_samplers.py @@ -41,7 +41,7 @@ def test_output_is_roughly_balanced(): ) col = infer_label_column(dataset.features) - reservoir = BalancedSampler(dataset, 2) + reservoir = BalancedSampler(dataset, {0, 1}) # Count the number of samples for each label counter = Counter() From 5ba1dddc738da8651474eeaa42bc45d154b1e9fa Mon Sep 17 00:00:00 2001 From: Nora Belrose Date: Sat, 22 Apr 2023 04:52:08 +0000 Subject: [PATCH 02/17] Start saving and fitting a reporter to the input embeddings --- elk/extraction/extraction.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/elk/extraction/extraction.py b/elk/extraction/extraction.py index 39de74ac..7102fb3e 100644 --- a/elk/extraction/extraction.py +++ b/elk/extraction/extraction.py @@ -74,7 +74,9 @@ def __post_init__(self, layer_stride: int): config = assert_type( PretrainedConfig, AutoConfig.from_pretrained(self.model) ) - self.layers = tuple(range(0, config.num_hidden_layers, layer_stride)) + # Note that we always include 0 which is the embedding layer + layer_range = range(0, config.num_hidden_layers, layer_stride) + self.layers = (0,) + tuple(layer_range) def explode(self) -> list["Extract"]: """Explode this config into a list of configs, one for each layer.""" @@ -136,8 +138,8 @@ def extract_hiddens( world_size=world_size, ) - # Iterating over questions - layer_indices = cfg.layers or tuple(range(model.config.num_hidden_layers)) + # Add one to the number of layers to account for the embedding layer + layer_indices = cfg.layers or tuple(range(model.config.num_hidden_layers + 1)) global_max_examples = p_cfg.max_examples[0 if split_type == "train" else 1] # break 
`max_examples` among the processes roughly equally @@ -229,9 +231,6 @@ def extract_hiddens( hiddens = ( outputs.get("decoder_hidden_states") or outputs["hidden_states"] ) - # First element of list is the input embeddings - hiddens = hiddens[1:] - # Throw out layers we don't care about hiddens = [hiddens[i] for i in layer_indices] @@ -320,7 +319,8 @@ def get_splits() -> SplitDict: dtype="int16", shape=(num_variants, num_classes, model_cfg.hidden_size), ) - for layer in cfg.layers or range(model_cfg.num_hidden_layers) + # Add 1 to include the embedding layer + for layer in cfg.layers or range(model_cfg.num_hidden_layers + 1) } other_cols = { "variant_ids": Sequence( From 51ba54f89f775d1af68695f123de41122c5e8593 Mon Sep 17 00:00:00 2001 From: Nora Belrose Date: Sat, 22 Apr 2023 05:36:15 +0000 Subject: [PATCH 03/17] Rename layer 0 to 'input' to make it more clear --- elk/run.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/elk/run.py b/elk/run.py index e246860a..218e708d 100644 --- a/elk/run.py +++ b/elk/run.py @@ -153,6 +153,9 @@ def apply_to_layers( # Make sure the CSV is written even if we crash or get interrupted if df_buf: df = pd.concat(df_buf).sort_values(by="layer") + + # Rename layer 0 to "input" to make it more clear + df.loc[0, "layer"] = "input" df.round(4).to_csv(f, index=False) if self.cfg.debug: save_debug_log(self.datasets, self.out_dir) From 544b485f9956c51ded74e7369ecf54e94caa8a93 Mon Sep 17 00:00:00 2001 From: Nora Belrose Date: Sat, 22 Apr 2023 05:48:21 +0000 Subject: [PATCH 04/17] Actually rename layer 0 correctly --- elk/run.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/elk/run.py b/elk/run.py index 218e708d..34afc6f5 100644 --- a/elk/run.py +++ b/elk/run.py @@ -155,7 +155,7 @@ def apply_to_layers( df = pd.concat(df_buf).sort_values(by="layer") # Rename layer 0 to "input" to make it more clear - df.loc[0, "layer"] = "input" + df["layer"].replace(0, "input", inplace=True) df.round(4).to_csv(f, index=False) if self.cfg.debug: save_debug_log(self.datasets, self.out_dir) From 43da44ef9bcc028e76833f9faac2ccc5d8f77b45 Mon Sep 17 00:00:00 2001 From: Nora Belrose Date: Sat, 22 Apr 2023 05:51:58 +0000 Subject: [PATCH 05/17] Handle layer_stride correctly --- elk/extraction/extraction.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/elk/extraction/extraction.py b/elk/extraction/extraction.py index 7102fb3e..3d2939bb 100644 --- a/elk/extraction/extraction.py +++ b/elk/extraction/extraction.py @@ -75,7 +75,7 @@ def __post_init__(self, layer_stride: int): PretrainedConfig, AutoConfig.from_pretrained(self.model) ) # Note that we always include 0 which is the embedding layer - layer_range = range(0, config.num_hidden_layers, layer_stride) + layer_range = range(1, config.num_hidden_layers + 1, layer_stride) self.layers = (0,) + tuple(layer_range) def explode(self) -> list["Extract"]: From 756fa532c48731d951e9791d8eba761f564043b2 Mon Sep 17 00:00:00 2001 From: Nora Belrose Date: Sat, 22 Apr 2023 09:47:04 +0000 Subject: [PATCH 06/17] label_choices --- elk/extraction/extraction.py | 13 ++++++++----- elk/promptsource/templates.py | 4 ++-- elk/promptsource/templates/glue/mnli/templates.yaml | 1 + 3 files changed, 11 insertions(+), 7 deletions(-) diff --git a/elk/extraction/extraction.py b/elk/extraction/extraction.py index 51ea1bf6..524beee2 100644 --- a/elk/extraction/extraction.py +++ b/elk/extraction/extraction.py @@ -30,6 +30,7 @@ assert_type, float32_to_int16, infer_label_column, + infer_num_classes, instantiate_model, 
instantiate_tokenizer, is_autoregressive, @@ -216,13 +217,13 @@ def extract_hiddens( log_p = outputs.logits[..., -answer_len:, :].log_softmax(dim=-1) tokens = answer[..., None] - lm_logits[i, j] = log_p.gather(-1, tokens).sum() + lm_logits[i, j] = log_p.gather(-1, tokens).mean() elif isinstance(outputs, Seq2SeqLMOutput): # The cross entropy loss is averaged over tokens, so we need to # multiply by the length to get the total log probability. - length = encoding.labels.shape[-1] - lm_logits[i, j] = -assert_type(Tensor, outputs.loss) * length + # length = encoding.labels.shape[-1] + lm_logits[i, j] = -assert_type(Tensor, outputs.loss) # * length hiddens = ( outputs.get("decoder_hidden_states") or outputs["hidden_states"] @@ -300,8 +301,10 @@ def get_splits() -> SplitDict: prompter = DatasetTemplates(ds_name, config_name) ds_features = assert_type(Features, info.features) - prompter.label_column or infer_label_column(ds_features) - num_classes = 2 # prompter.num_classes or infer_num_classes(ds_features[label_col]) + label_col = prompter.label_column or infer_label_column(ds_features) + num_classes = len(prompter.label_choices) or infer_num_classes( + ds_features[label_col] + ) num_variants = cfg.prompts.num_variants if num_variants < 0: diff --git a/elk/promptsource/templates.py b/elk/promptsource/templates.py index 920f8ae1..f70c5520 100644 --- a/elk/promptsource/templates.py +++ b/elk/promptsource/templates.py @@ -386,7 +386,7 @@ class DatasetTemplates: TEMPLATE_FILENAME = "templates.yaml" label_column: str | None - label_choices: list[str] | None + label_choices: list[str] def __init__(self, dataset_name: str, subset_name: str | None = None): self.dataset_name = dataset_name @@ -400,7 +400,7 @@ def __init__(self, dataset_name: str, subset_name: str | None = None): # Optional fields; may be None self.label_column = yaml_dict.get(self.LABEL_COLUMN_KEY) - self.label_choices = yaml_dict.get(self.LABEL_CHOICES_KEY) + self.label_choices = yaml_dict.get(self.LABEL_CHOICES_KEY, []) # Mapping from template name to template id self.name_to_id_mapping = {} diff --git a/elk/promptsource/templates/glue/mnli/templates.yaml b/elk/promptsource/templates/glue/mnli/templates.yaml index 5f110193..be747662 100644 --- a/elk/promptsource/templates/glue/mnli/templates.yaml +++ b/elk/promptsource/templates/glue/mnli/templates.yaml @@ -1,5 +1,6 @@ dataset: glue subset: mnli +label_choices: [0, 2] templates: 02b4c44e-52cb-417b-b069-5d334b1f1a91: !Template answer_choices: Always ||| Sometimes ||| Never From 93b7ae0d4a1a0eee1448d0a84ef908e455663a9f Mon Sep 17 00:00:00 2001 From: Nora Belrose Date: Sat, 22 Apr 2023 18:52:46 +0000 Subject: [PATCH 07/17] Clean up train and eval commands; do transfer in sweep --- elk/evaluation/evaluate.py | 70 +++++------------------------ elk/extraction/extraction.py | 9 ++-- elk/run.py | 71 +++++++++++++++++++---------- elk/training/__init__.py | 3 +- elk/training/eigen_reporter.py | 2 +- elk/training/reporter.py | 21 +-------- elk/training/sweep.py | 38 +++++++++++++++- elk/training/train.py | 81 +++++++++------------------------- elk/truncated_eigh.py | 1 - elk/utils/__init__.py | 2 + elk/utils/pretty.py | 24 ++++++++++ 11 files changed, 151 insertions(+), 171 deletions(-) create mode 100644 elk/utils/pretty.py diff --git a/elk/evaluation/evaluate.py b/elk/evaluation/evaluate.py index 0269053c..3329d184 100644 --- a/elk/evaluation/evaluate.py +++ b/elk/evaluation/evaluate.py @@ -1,70 +1,36 @@ from dataclasses import dataclass -from functools import partial -from pathlib import Path 
-from typing import Callable import pandas as pd import torch -from simple_parsing.helpers import Serializable, field +from simple_parsing.helpers import field -from ..extraction.extraction import Extract from ..files import elk_reporter_dir from ..metrics import evaluate_preds from ..run import Run from ..training import Reporter -from ..utils import select_usable_devices @dataclass -class Eval(Serializable): - """ - Full specification of a reporter evaluation run. - - Args: - data: Config specifying hidden states on which the reporter will be evaluated. - source: The name of the source run directory - which contains the reporters directory. - normalization: The normalization method to use. Defaults to "meanonly". See - `elk.training.preprocessing.normalize()` for details. - num_gpus: The number of GPUs to use. Defaults to -1, which means - "use all available GPUs". - skip_supervised: Whether to skip evaluation of the supervised classifier. - debug: When in debug mode, a useful log file is saved to the memorably-named - output directory. Defaults to False. - """ - - data: Extract - source: str = field(positional=True) - - concatenated_layer_offset: int = 0 - debug: bool = False - min_gpu_mem: int | None = None - num_gpus: int = -1 - out_dir: Path | None = None +class Eval(Run): + """Full specification of a reporter evaluation run.""" + + source: str = field(default="", positional=True) skip_supervised: bool = False - disable_cache: bool = field(default=False, to_dict=False) + def __post_init__(self): + assert self.source, "Must specify a source experiment." - def execute(self): transfer_dir = elk_reporter_dir() / self.source / "transfer_eval" + self.out_dir = transfer_dir / "+".join(self.data.prompts.datasets) - for dataset in self.data.prompts.datasets: - run = Evaluate(cfg=self, out_dir=transfer_dir / dataset) - run.evaluate() - - -@dataclass -class Evaluate(Run): - cfg: Eval - - def evaluate_reporter( - self, layer: int, devices: list[str], world_size: int = 1 + def apply_to_layer( + self, layer: int, devices: list[str], world_size: int ) -> pd.DataFrame: """Evaluate a single reporter on a single layer.""" device = self.get_device(devices, world_size) val_output = self.prepare_data(device, layer, "val") - experiment_dir = elk_reporter_dir() / self.cfg.source + experiment_dir = elk_reporter_dir() / self.source reporter_path = experiment_dir / "reporters" / f"layer_{layer}.pt" reporter: Reporter = torch.load(reporter_path, map_location=device) @@ -81,7 +47,7 @@ def evaluate_reporter( } lr_dir = experiment_dir / "lr_models" - if not self.cfg.skip_supervised and lr_dir.exists(): + if not self.skip_supervised and lr_dir.exists(): with open(lr_dir / f"layer_{layer}.pt", "rb") as f: lr_model = torch.load(f, map_location=device).eval() @@ -91,15 +57,3 @@ def evaluate_reporter( row_buf.append(stats_row) return pd.DataFrame.from_records(row_buf) - - def evaluate(self): - """Evaluate the reporter on all layers.""" - devices = select_usable_devices( - self.cfg.num_gpus, min_memory=self.cfg.min_gpu_mem - ) - - num_devices = len(devices) - func: Callable[[int], pd.DataFrame] = partial( - self.evaluate_reporter, devices=devices, world_size=num_devices - ) - self.apply_to_layers(func=func, num_devices=num_devices) diff --git a/elk/extraction/extraction.py b/elk/extraction/extraction.py index 39de74ac..3eab4fbf 100644 --- a/elk/extraction/extraction.py +++ b/elk/extraction/extraction.py @@ -28,6 +28,7 @@ from ..promptsource import DatasetTemplates from ..utils import ( assert_type, + colorize, 
float32_to_int16, infer_label_column, infer_num_classes, @@ -271,6 +272,7 @@ def extract( cfg: "Extract", *, disable_cache: bool = False, + highlight_color: str = "cyan", num_gpus: int = -1, min_gpu_mem: int | None = None, ) -> DatasetDict: @@ -279,10 +281,11 @@ def extract( def get_splits() -> SplitDict: available_splits = assert_type(SplitDict, info.splits) train_name, val_name = select_train_val_splits(available_splits) + + pretty_name = colorize(assert_type(str, info.builder_name), highlight_color) print( - # Cyan color for dataset name - f"\033[36m{info.builder_name}\033[0m: using '{train_name}' for training and" - f" '{val_name}' for validation" + f"{pretty_name}: using '{train_name}' for training " + f"and '{val_name}' for validation" ) limit_list = cfg.prompts.max_examples diff --git a/elk/run.py b/elk/run.py index e246860a..0003131b 100644 --- a/elk/run.py +++ b/elk/run.py @@ -1,9 +1,10 @@ import os import random -from abc import ABC -from dataclasses import dataclass, field +from abc import ABC, abstractmethod +from dataclasses import dataclass +from functools import partial from pathlib import Path -from typing import TYPE_CHECKING, Callable, Literal, Union +from typing import Callable, Literal import numpy as np import pandas as pd @@ -11,11 +12,12 @@ import torch.multiprocessing as mp import yaml from datasets import DatasetDict +from simple_parsing.helpers import Serializable, field from torch import Tensor from tqdm import tqdm from .debug_logging import save_debug_log -from .extraction import extract +from .extraction import Extract, extract from .files import elk_reporter_dir, memorably_named_dir from .utils import ( assert_type, @@ -23,35 +25,44 @@ get_layers, int16_to_float32, select_train_val_splits, + select_usable_devices, ) -if TYPE_CHECKING: - from .evaluation.evaluate import Eval - from .training.train import Elicit - @dataclass -class Run(ABC): - cfg: Union["Elicit", "Eval"] +class Run(ABC, Serializable): + data: Extract + out_dir: Path | None = None + """Directory to save results to. 
If None, a directory will be created + automatically.""" + + datasets: list[DatasetDict] = field(default_factory=list, init=False) + """Datasets containing hidden states and labels for each layer.""" + + concatenated_layer_offset: int = 0 + debug: bool = False + min_gpu_mem: int | None = None + num_gpus: int = -1 out_dir: Path | None = None - datasets: list[DatasetDict] = field(init=False) + disable_cache: bool = field(default=False, to_dict=False) - def __post_init__(self): + def execute(self, highlight_color: str = "cyan"): self.datasets = [ extract( cfg, - disable_cache=self.cfg.disable_cache, - num_gpus=self.cfg.num_gpus, - min_gpu_mem=self.cfg.min_gpu_mem, + disable_cache=self.disable_cache, + highlight_color=highlight_color, + num_gpus=self.num_gpus, + min_gpu_mem=self.min_gpu_mem, ) - for cfg in self.cfg.data.explode() + for cfg in self.data.explode() ] if self.out_dir is None: # Save in a memorably-named directory inside of # ELK_REPORTER_DIR// - ds_name = ", ".join(self.cfg.data.prompts.datasets) - root = elk_reporter_dir() / self.cfg.data.model / ds_name + ds_name = ", ".join(self.data.prompts.datasets) + root = elk_reporter_dir() / self.data.model / ds_name self.out_dir = memorably_named_dir(root) @@ -61,7 +72,7 @@ def __post_init__(self): path = self.out_dir / "cfg.yaml" with open(path, "w") as f: - self.cfg.dump_yaml(f) + self.dump_yaml(f) path = self.out_dir / "fingerprints.yaml" with open(path, "w") as meta_f: @@ -75,6 +86,19 @@ def __post_init__(self): meta_f, ) + devices = select_usable_devices(self.num_gpus, min_memory=self.min_gpu_mem) + num_devices = len(devices) + func: Callable[[int], pd.DataFrame] = partial( + self.apply_to_layer, devices=devices, world_size=num_devices + ) + self.apply_to_layers(func=func, num_devices=num_devices) + + @abstractmethod + def apply_to_layer( + self, layer: int, devices: list[str], world_size: int + ) -> pd.DataFrame: + """Train or eval a reporter on a single layer.""" + def make_reproducible(self, seed: int): """Make the run reproducible by setting the random seed.""" @@ -114,8 +138,8 @@ def prepare_data( def concatenate(self, layers): """Concatenate hidden states from a previous layer.""" - for layer in range(self.cfg.concatenated_layer_offset, len(layers)): - layers[layer] += [layers[layer][0] - self.cfg.concatenated_layer_offset] + for layer in range(self.concatenated_layer_offset, len(layers)): + layers[layer] += [layers[layer][0] - self.concatenated_layer_offset] return layers @@ -137,10 +161,9 @@ def apply_to_layers( layers, *rest = [get_layers(ds) for ds in self.datasets] assert all(x == layers for x in rest), "All datasets must have the same layers" - if self.cfg.concatenated_layer_offset > 0: + if self.concatenated_layer_offset > 0: layers = self.concatenate(layers) - # Should we write to different CSV files for elicit vs eval? 
ctx = mp.get_context("spawn") with ctx.Pool(num_devices) as pool, open(self.out_dir / "eval.csv", "w") as f: mapper = pool.imap_unordered if num_devices > 1 else map @@ -154,5 +177,5 @@ def apply_to_layers( if df_buf: df = pd.concat(df_buf).sort_values(by="layer") df.round(4).to_csv(f, index=False) - if self.cfg.debug: + if self.debug: save_debug_log(self.datasets, self.out_dir) diff --git a/elk/training/__init__.py b/elk/training/__init__.py index ce6e4d48..635644c1 100644 --- a/elk/training/__init__.py +++ b/elk/training/__init__.py @@ -2,7 +2,7 @@ from .classifier import Classifier from .eigen_reporter import EigenReporter, EigenReporterConfig from .normalizer import Normalizer -from .reporter import OptimConfig, Reporter, ReporterConfig +from .reporter import Reporter, ReporterConfig __all__ = [ "CcsReporter", @@ -11,7 +11,6 @@ "EigenReporter", "EigenReporterConfig", "Normalizer", - "OptimConfig", "Reporter", "ReporterConfig", ] diff --git a/elk/training/eigen_reporter.py b/elk/training/eigen_reporter.py index 9e1b4f37..fc122b45 100644 --- a/elk/training/eigen_reporter.py +++ b/elk/training/eigen_reporter.py @@ -213,7 +213,7 @@ def fit_streaming(self, truncated: bool = False) -> float: ) if truncated: - L, Q = truncated_eigh(A, k=self.config.num_heads) + L, Q = truncated_eigh(A, k=self.config.num_heads, seed=self.config.seed) else: try: L, Q = torch.linalg.eigh(A) diff --git a/elk/training/reporter.py b/elk/training/reporter.py index c10b9562..e6e84f96 100644 --- a/elk/training/reporter.py +++ b/elk/training/reporter.py @@ -3,7 +3,7 @@ from abc import ABC, abstractmethod from dataclasses import dataclass from pathlib import Path -from typing import Literal, Optional +from typing import Optional import torch import torch.nn as nn @@ -21,25 +21,6 @@ class ReporterConfig(Serializable): seed: int = 42 -@dataclass -class OptimConfig(Serializable): - """ - Args: - lr: The learning rate to use. Ignored when `optimizer` is `"lbfgs"`. - Defaults to 1e-2. - num_epochs: The number of epochs to train for. Defaults to 1000. - num_tries: The number of times to try training the reporter. Defaults to 10. - optimizer: The optimizer to use. Defaults to "adam". - weight_decay: The weight decay or L2 penalty to use. Defaults to 0.01. - """ - - lr: float = 1e-2 - num_epochs: int = 1000 - num_tries: int = 10 - optimizer: Literal["adam", "lbfgs"] = "lbfgs" - weight_decay: float = 0.01 - - class Reporter(nn.Module, ABC): """An ELK reporter network.""" diff --git a/elk/training/sweep.py b/elk/training/sweep.py index 17250def..5c2350d1 100644 --- a/elk/training/sweep.py +++ b/elk/training/sweep.py @@ -1,8 +1,10 @@ from copy import deepcopy from dataclasses import InitVar, dataclass +from ..evaluation.evaluate import Eval from ..extraction import Extract, PromptConfig from ..files import elk_reporter_dir, memorably_named_dir +from ..utils import colorize from .train import Elicit @@ -47,6 +49,19 @@ def execute(self): sweep_dir = root_dir / self.name if self.name else memorably_named_dir(root_dir) print(f"Saving sweep results to \033[1m{sweep_dir}\033[0m") # bold + # Each dataset string can contain multiple datasets, delimited by plus; this + # indicates that the component datasets will be pooled together for training. + # For example, we might be sweeping over ["amazon_polarity", "imdb+sst2"]. For + # transfer eval, we want to split "imdb+sst2" into ["imdb", "sst2"] and then + # flatten the list, yielding ["amazon_polarity", "imdb", "sst2"]. 
+ eval_datasets = sorted( + { + ds.strip() + for dataset_str in self.datasets + for ds in dataset_str.split("+") + } + ) + for i, model_str in enumerate(self.models): # Magenta color for the model name print(f"\n\033[35m===== {model_str} ({i + 1} of {M}) =====\033[0m") @@ -57,10 +72,29 @@ def execute(self): # Allow for multiple datasets to be specified in a single string with # plus signs. This means we can pool datasets together inside of a # single sweep. - datasets = [ds.strip() for ds in dataset_str.split("+")] + train_datasets = [ds.strip() for ds in dataset_str.split("+")] run = deepcopy(self.run_template) run.data.model = model_str - run.data.prompts.datasets = datasets + run.data.prompts.datasets = train_datasets run.out_dir = out_dir run.execute() + + if len(eval_datasets) > 1: + print(colorize("== Transfer eval ==", "green")) + + # Now evaluate the reporter on the other datasets + for eval_dataset in eval_datasets: + # We already evaluated on this one during training + if eval_dataset in train_datasets: + continue + + eval = Eval( + data=Extract( + model=model_str, + prompts=PromptConfig(datasets=[eval_dataset]), + ), + source=str(run.out_dir), + out_dir=out_dir, + ) + eval.execute(highlight_color="green") diff --git a/elk/training/train.py b/elk/training/train.py index 9a16ad01..8119cc28 100644 --- a/elk/training/train.py +++ b/elk/training/train.py @@ -1,67 +1,36 @@ """Main training loop.""" from dataclasses import dataclass -from functools import partial from pathlib import Path -from typing import Callable, Literal +from typing import Literal import pandas as pd import torch from einops import rearrange, repeat -from simple_parsing import Serializable, field, subgroups +from simple_parsing import subgroups -from ..extraction.extraction import Extract from ..metrics import evaluate_preds, to_one_hot from ..run import Run from ..training.supervised import train_supervised -from ..utils import select_usable_devices from ..utils.typing import assert_type from .ccs_reporter import CcsReporter, CcsReporterConfig from .eigen_reporter import EigenReporter, EigenReporterConfig -from .reporter import OptimConfig, ReporterConfig +from .reporter import ReporterConfig @dataclass -class Elicit(Serializable): - """Full specification of a reporter training run. - - Args: - data: Config specifying hidden states on which the reporter will be trained. - net: Config for building the reporter network. - optim: Config for the `.fit()` loop. - num_gpus: The number of GPUs to use. Defaults to -1, which means - "use all available GPUs". - normalization: The normalization method to use. Defaults to "meanonly". See - `elk.training.preprocessing.normalize()` for details. - supervised: Whether to train a supervised classifier, and if so, whether to - use cross-validation. Defaults to "single", which means to train a single - classifier on the training data. "cv" means to use cross-validation. - debug: When in debug mode, a useful log file is saved to the memorably-named - output directory. Defaults to False. 
- """ - - data: Extract +class Elicit(Run): + """Full specification of a reporter training run.""" + net: ReporterConfig = subgroups( {"ccs": CcsReporterConfig, "eigen": EigenReporterConfig}, default="eigen" ) - optim: OptimConfig = field(default_factory=OptimConfig) + """Config for building the reporter network.""" - concatenated_layer_offset: int = 0 - debug: bool = False - min_gpu_mem: int | None = None - num_gpus: int = -1 - out_dir: Path | None = None supervised: Literal["none", "single", "cv"] = "single" - - disable_cache: bool = field(default=False, to_dict=False) - - def execute(self): - Train(cfg=self, out_dir=self.out_dir).train() - - -@dataclass -class Train(Run): - cfg: Elicit + """Whether to train a supervised classifier, and if so, whether to use + cross-validation. Defaults to "single", which means to train a single classifier + on the training data. "cv" means to use cross-validation.""" def create_models_dir(self, out_dir: Path): lr_dir = None @@ -73,14 +42,15 @@ def create_models_dir(self, out_dir: Path): return reporter_dir, lr_dir - def train_reporter( + def apply_to_layer( self, layer: int, devices: list[str], - world_size: int = 1, + world_size: int, ) -> pd.DataFrame: """Train a single reporter on a single layer.""" - self.make_reproducible(seed=self.cfg.net.seed + layer) + + self.make_reproducible(seed=self.net.seed + layer) device = self.get_device(devices, world_size) train_dict = self.prepare_data(device, layer, "train") @@ -92,10 +62,10 @@ def train_reporter( raise ValueError("All datasets must have the same hidden state size") reporter_dir, lr_dir = self.create_models_dir(assert_type(Path, self.out_dir)) - if isinstance(self.cfg.net, CcsReporterConfig): + if isinstance(self.net, CcsReporterConfig): assert len(train_dict) == 1, "CCS only supports single-task training" - reporter = CcsReporter(self.cfg.net, d, device=device) + reporter = CcsReporter(self.net, d, device=device) train_loss = reporter.fit(first_train_h, train_labels) (val_h, val_gt, _) = next(iter(val_dict.values())) @@ -106,11 +76,11 @@ def train_reporter( val_pair=(reporter.neg_norm(val_x0), reporter.pos_norm(val_x1)), ) - elif isinstance(self.cfg.net, EigenReporterConfig): + elif isinstance(self.net, EigenReporterConfig): # We set num_classes to None to enable training on datasets with different # numbers of classes. Under the hood, this causes the covariance statistics # to be simply averaged across all batches passed to update(). 
- reporter = EigenReporter(self.cfg.net, d, num_classes=None, device=device) + reporter = EigenReporter(self.net, d, num_classes=None, device=device) hidden_list, label_list = [], [] for ds_name, (train_h, train_labels, _) in train_dict.items(): @@ -131,16 +101,16 @@ def train_reporter( torch.cat(hidden_list), ) else: - raise ValueError(f"Unknown reporter config type: {type(self.cfg.net)}") + raise ValueError(f"Unknown reporter config type: {type(self.net)}") # Save reporter checkpoint to disk with open(reporter_dir / f"layer_{layer}.pt", "wb") as file: torch.save(reporter, file) # Fit supervised logistic regression model - if self.cfg.supervised != "none": + if self.supervised != "none": lr_model = train_supervised( - train_dict, device=device, cv=self.cfg.supervised == "cv" + train_dict, device=device, cv=self.supervised == "cv" ) with open(lr_dir / f"layer_{layer}.pt", "wb") as file: torch.save(lr_model, file) @@ -169,12 +139,3 @@ def train_reporter( row_buf.append(row) return pd.DataFrame.from_records(row_buf) - - def train(self): - """Train a reporter on each layer of the network.""" - devices = select_usable_devices(self.cfg.num_gpus) - num_devices = len(devices) - func: Callable[[int], pd.DataFrame] = partial( - self.train_reporter, devices=devices, world_size=num_devices - ) - self.apply_to_layers(func=func, num_devices=num_devices) diff --git a/elk/truncated_eigh.py b/elk/truncated_eigh.py index 1cca6c83..fe229837 100644 --- a/elk/truncated_eigh.py +++ b/elk/truncated_eigh.py @@ -16,7 +16,6 @@ class Eigendecomposition(NamedTuple): eigenvectors: Tensor -@torch.autocast("cuda", enabled=torch.cuda.is_available()) def truncated_eigh( A: Tensor, k: int = 1, diff --git a/elk/utils/__init__.py b/elk/utils/__init__.py index d9079147..ede58bea 100644 --- a/elk/utils/__init__.py +++ b/elk/utils/__init__.py @@ -11,6 +11,7 @@ from .gpu_utils import select_usable_devices from .hf_utils import instantiate_model, instantiate_tokenizer, is_autoregressive from .math_util import batch_cov, cov_mean_fused, stochastic_round_constrained +from .pretty import colorize from .tree_utils import pytree_map from .typing import assert_type, float32_to_int16, int16_to_float32 @@ -18,6 +19,7 @@ "assert_type", "batch_cov", "binarize", + "colorize", "cov_mean_fused", "float32_to_int16", "get_columns_all_equal", diff --git a/elk/utils/pretty.py b/elk/utils/pretty.py new file mode 100644 index 00000000..6552dc1b --- /dev/null +++ b/elk/utils/pretty.py @@ -0,0 +1,24 @@ +# Kind of kickass that this file has no imports + +# ANSI color codes for use in terminal output. +COLOR_CODES = { + "black": 30, + "red": 31, + "green": 32, + "yellow": 33, + "blue": 34, + "magenta": 35, + "cyan": 36, + "white": 37, +} + + +def colorize(message: str, color: str) -> str: + """Colorize a message for terminal output.""" + # Get the ANSI color code based on the human-readable color name. + code = COLOR_CODES.get(color.lower()) + if code is None: + raise ValueError(f"Invalid color name: {color}") + + # Construct and return the colored message. 
+ return f"\033[{code}m{message}\033[0m" From 57d0b8b754d7c856794b4e76ae1e6542fe0c2102 Mon Sep 17 00:00:00 2001 From: Nora Belrose Date: Sat, 22 Apr 2023 23:23:27 +0000 Subject: [PATCH 08/17] Support INLP and split eval output into multiple CSVs --- elk/evaluation/evaluate.py | 34 +++++++++++---------- elk/run.py | 24 ++++++++------- elk/training/classifier.py | 60 ++++++++++++++++++++++++++++++++++++-- elk/training/supervised.py | 19 ++++++++---- elk/training/train.py | 48 +++++++++++++++++------------- 5 files changed, 129 insertions(+), 56 deletions(-) diff --git a/elk/evaluation/evaluate.py b/elk/evaluation/evaluate.py index 3329d184..216f9797 100644 --- a/elk/evaluation/evaluate.py +++ b/elk/evaluation/evaluate.py @@ -1,3 +1,4 @@ +from collections import defaultdict from dataclasses import dataclass import pandas as pd @@ -25,7 +26,7 @@ def __post_init__(self): def apply_to_layer( self, layer: int, devices: list[str], world_size: int - ) -> pd.DataFrame: + ) -> dict[str, pd.DataFrame]: """Evaluate a single reporter on a single layer.""" device = self.get_device(devices, world_size) val_output = self.prepare_data(device, layer, "val") @@ -36,24 +37,25 @@ def apply_to_layer( reporter: Reporter = torch.load(reporter_path, map_location=device) reporter.eval() - row_buf = [] + row_bufs = defaultdict(list) for ds_name, (val_h, val_gt, _) in val_output.items(): - val_result = evaluate_preds(val_gt, reporter(val_h)) + meta = {"dataset": ds_name, "layer": layer} - stats_row = { - "dataset": ds_name, - "layer": layer, - **val_result.to_dict(), - } + val_result = evaluate_preds(val_gt, reporter(val_h)) + row_bufs["eval"].append({**meta, **val_result.to_dict()}) lr_dir = experiment_dir / "lr_models" if not self.skip_supervised and lr_dir.exists(): with open(lr_dir / f"layer_{layer}.pt", "rb") as f: - lr_model = torch.load(f, map_location=device).eval() - - lr_result = evaluate_preds(val_gt, lr_model(val_h)) - stats_row.update(lr_result.to_dict(prefix="lr_")) - - row_buf.append(stats_row) - - return pd.DataFrame.from_records(row_buf) + lr_models = torch.load(f, map_location=device) + if not isinstance(lr_models, list): # backward compatibility + lr_models = [lr_models] + + for i, model in enumerate(lr_models): + model.eval() + lr_result = evaluate_preds(val_gt, model(val_h)) + row_bufs["lr_eval"].append( + {"inlp_iter": i, **meta, **lr_result.to_dict()} + ) + + return {k: pd.DataFrame(v) for k, v in row_bufs.items()} diff --git a/elk/run.py b/elk/run.py index 0003131b..7cf130b7 100644 --- a/elk/run.py +++ b/elk/run.py @@ -1,6 +1,7 @@ import os import random from abc import ABC, abstractmethod +from collections import defaultdict from dataclasses import dataclass from functools import partial from pathlib import Path @@ -88,7 +89,7 @@ def execute(self, highlight_color: str = "cyan"): devices = select_usable_devices(self.num_gpus, min_memory=self.min_gpu_mem) num_devices = len(devices) - func: Callable[[int], pd.DataFrame] = partial( + func: Callable[[int], dict[str, pd.DataFrame]] = partial( self.apply_to_layer, devices=devices, world_size=num_devices ) self.apply_to_layers(func=func, num_devices=num_devices) @@ -96,7 +97,7 @@ def execute(self, highlight_color: str = "cyan"): @abstractmethod def apply_to_layer( self, layer: int, devices: list[str], world_size: int - ) -> pd.DataFrame: + ) -> dict[str, pd.DataFrame]: """Train or eval a reporter on a single layer.""" def make_reproducible(self, seed: int): @@ -145,7 +146,7 @@ def concatenate(self, layers): def apply_to_layers( self, - func: 
Callable[[int], pd.DataFrame], + func: Callable[[int], dict[str, pd.DataFrame]], num_devices: int, ): """Apply a function to each layer of the datasets in parallel @@ -165,17 +166,18 @@ def apply_to_layers( layers = self.concatenate(layers) ctx = mp.get_context("spawn") - with ctx.Pool(num_devices) as pool, open(self.out_dir / "eval.csv", "w") as f: + with ctx.Pool(num_devices) as pool: mapper = pool.imap_unordered if num_devices > 1 else map - df_buf = [] + df_buffers = defaultdict(list) try: - for df in tqdm(mapper(func, layers), total=len(layers)): - df_buf.append(df) + for df_dict in tqdm(mapper(func, layers), total=len(layers)): + for k, v in df_dict.items(): + df_buffers[k].append(v) finally: - # Make sure the CSV is written even if we crash or get interrupted - if df_buf: - df = pd.concat(df_buf).sort_values(by="layer") - df.round(4).to_csv(f, index=False) + # Make sure the CSVs are written even if we crash or get interrupted + for name, dfs in df_buffers.items(): + df = pd.concat(dfs).sort_values(by="layer") + df.round(4).to_csv(self.out_dir / f"{name}.csv", index=False) if self.debug: save_debug_log(self.datasets, self.out_dir) diff --git a/elk/training/classifier.py b/elk/training/classifier.py index b92d0f7e..148da939 100644 --- a/elk/training/classifier.py +++ b/elk/training/classifier.py @@ -1,4 +1,4 @@ -from dataclasses import dataclass +from dataclasses import dataclass, field import torch from torch import Tensor @@ -10,6 +10,14 @@ ) +@dataclass +class InlpResult: + """Result of Iterative Nullspace Projection (NLP).""" + + losses: list[float] = field(default_factory=list) + classifiers: list["Classifier"] = field(default_factory=list) + + @dataclass class RegularizationPath: """Result of cross-validation.""" @@ -175,10 +183,58 @@ def fit_cv( self.fit(x, y, l2_penalty=best_penalty, max_iter=max_iter) return RegularizationPath(l2_penalties, mean_losses.tolist()) + @classmethod + def inlp( + cls, x: Tensor, y: Tensor, max_iter: int | None = None, tol: float = 0.01 + ) -> InlpResult: + """Iterative Nullspace Projection (INLP) . + + Args: + x: Input tensor of shape (N, D), where N is the number of samples and D is + the input dimension. + y: Target tensor of shape (N,) for binary classification or (N, C) for + multiclass classification, where C is the number of classes. + max_iter: Maximum number of iterations to run. If `None`, run for the full + dimension of the input. + tol: Tolerance for the loss function. The algorithm will stop when the loss + is within `tol` of the entropy of the labels. + + Returns: + `InlpResult` containing the classifiers and losses achieved at each + iteration. 
+ """ + + d = x.shape[-1] + loss = 0.0 + + # Compute entropy of the labels + p = y.float().mean() + H = -p * torch.log(p) - (1 - p) * torch.log(1 - p) + + if max_iter is not None: + d = min(d, max_iter) + + # Iterate until the loss is within `tol` of the entropy of the labels + result = InlpResult() + for _ in range(d): + clf = cls(d, device=x.device, dtype=x.dtype) + loss = clf.fit(x, y) + result.classifiers.append(clf) + result.losses.append(loss) + + if loss >= (1.0 - tol) * H: + break + + # Project the data onto the nullspace of the classifier + x = clf.nullspace_project(x) + + return result + def nullspace_project(self, x: Tensor) -> Tensor: """Project the given data onto the nullspace of the classifier.""" # https://en.wikipedia.org/wiki/Projection_(linear_algebra) A = self.linear.weight.data.T P = A @ torch.linalg.solve(A.mT @ A, A.mT) - return x - P @ x + return x - x @ P diff --git a/elk/training/supervised.py b/elk/training/supervised.py index 86531106..d2eef5f7 100644 --- a/elk/training/supervised.py +++ b/elk/training/supervised.py @@ -5,7 +5,9 @@ from .classifier import Classifier -def train_supervised(data: dict[str, tuple], device: str, cv: bool) -> Classifier: +def train_supervised( + data: dict[str, tuple], device: str, mode: str +) -> list[Classifier]: Xs, train_labels = [], [] for train_h, labels, _ in data.values(): @@ -19,10 +21,15 @@ def train_supervised(data: dict[str, tuple], device: str, cv: bool) -> Classifie train_labels.append(labels) X, train_labels = torch.cat(Xs), torch.cat(train_labels) - lr_model = Classifier(X.shape[-1], device=device) - if cv: + if mode == "cv": + lr_model = Classifier(X.shape[-1], device=device) lr_model.fit_cv(X, train_labels) - else: + return [lr_model] + elif mode == "inlp": + return Classifier.inlp(X, train_labels).classifiers + elif mode == "single": + lr_model = Classifier(X.shape[-1], device=device) lr_model.fit(X, train_labels) - - return lr_model + return [lr_model] + else: + raise ValueError(f"Unknown mode: {mode}") diff --git a/elk/training/train.py b/elk/training/train.py index 8119cc28..ddecc06f 100644 --- a/elk/training/train.py +++ b/elk/training/train.py @@ -1,5 +1,6 @@ """Main training loop.""" +from collections import defaultdict from dataclasses import dataclass from pathlib import Path from typing import Literal @@ -27,7 +28,7 @@ class Elicit(Run): ) """Config for building the reporter network.""" - supervised: Literal["none", "single", "cv"] = "single" + supervised: Literal["none", "single", "inlp", "cv"] = "single" """Whether to train a supervised classifier, and if so, whether to use cross-validation. Defaults to "single", which means to train a single classifier on the training data. "inlp" means to train a sequence of classifiers with iterative nullspace projection (INLP).
"cv" means to use cross-validation.""" @@ -47,7 +48,7 @@ def apply_to_layer( layer: int, devices: list[str], world_size: int, - ) -> pd.DataFrame: + ) -> dict[str, pd.DataFrame]: """Train a single reporter on a single layer.""" self.make_reproducible(seed=self.net.seed + layer) @@ -109,33 +110,38 @@ def apply_to_layer( # Fit supervised logistic regression model if self.supervised != "none": - lr_model = train_supervised( - train_dict, device=device, cv=self.supervised == "cv" + lr_models = train_supervised( + train_dict, + device=device, + mode=self.supervised, ) with open(lr_dir / f"layer_{layer}.pt", "wb") as file: - torch.save(lr_model, file) + torch.save(lr_models, file) else: - lr_model = None + lr_models = [] - row_buf = [] + row_bufs = defaultdict(list) for ds_name, (val_h, val_gt, val_lm_preds) in val_dict.items(): + meta = {"dataset": ds_name, "layer": layer} + val_result = evaluate_preds(val_gt, reporter(val_h)) - row = { - "dataset": ds_name, - "layer": layer, - "pseudo_auroc": pseudo_auroc, - "train_loss": train_loss, - **val_result.to_dict(), - } + row_bufs["eval"].append( + { + **meta, + "pseudo_auroc": pseudo_auroc, + "train_loss": train_loss, + **val_result.to_dict(), + } + ) if val_lm_preds is not None: lm_result = evaluate_preds(val_gt, val_lm_preds) - row.update(lm_result.to_dict(prefix="lm_")) + row_bufs["lm_eval"].append({**meta, **lm_result.to_dict()}) - if lr_model is not None: - lr_result = evaluate_preds(val_gt, lr_model(val_h)) - row.update(lr_result.to_dict(prefix="lr_")) - - row_buf.append(row) + for i, model in enumerate(lr_models): + lr_result = evaluate_preds(val_gt, model(val_h)) + row_bufs["lr_eval"].append( + {"inlp_iter": i, **meta, **lr_result.to_dict()} + ) - return pd.DataFrame.from_records(row_buf) + return {k: pd.DataFrame(v) for k, v in row_bufs.items()} From b086f0b31df478395b6bf3819ee2883a2d9be11f Mon Sep 17 00:00:00 2001 From: Nora Belrose Date: Tue, 25 Apr 2023 04:22:50 +0000 Subject: [PATCH 09/17] Merge branch 'inlp' into template-filtering --- elk/evaluation/evaluate.py | 3 ++ elk/extraction/extraction.py | 78 ++++++++++++++++++------------------ elk/extraction/generator.py | 7 ---- elk/run.py | 7 +++- elk/utils/__init__.py | 2 + elk/utils/data_utils.py | 25 ++++++++---- 6 files changed, 68 insertions(+), 54 deletions(-) diff --git a/elk/evaluation/evaluate.py b/elk/evaluation/evaluate.py index 216f9797..52b7f3d4 100644 --- a/elk/evaluation/evaluate.py +++ b/elk/evaluation/evaluate.py @@ -24,6 +24,9 @@ def __post_init__(self): transfer_dir = elk_reporter_dir() / self.source / "transfer_eval" self.out_dir = transfer_dir / "+".join(self.data.prompts.datasets) + def execute(self, highlight_color: str = "cyan"): + return super().execute(highlight_color, split_type="val") + def apply_to_layer( self, layer: int, devices: list[str], world_size: int ) -> dict[str, pd.DataFrame]: diff --git a/elk/extraction/extraction.py b/elk/extraction/extraction.py index 81c9297f..becb7b35 100644 --- a/elk/extraction/extraction.py +++ b/elk/extraction/extraction.py @@ -12,6 +12,7 @@ Array2D, Array3D, DatasetDict, + DatasetInfo, DownloadMode, Features, Sequence, @@ -35,6 +36,7 @@ instantiate_model, instantiate_tokenizer, is_autoregressive, + select_split, select_train_val_splits, select_usable_devices, ) @@ -264,39 +266,8 @@ def _extraction_worker(**kwargs): yield from extract_hiddens(**{k: v[0] for k, v in kwargs.items()}) -def extract( - cfg: "Extract", - *, - disable_cache: bool = False, - highlight_color: str = "cyan", - num_gpus: int = -1, - min_gpu_mem: int | 
None = None, -) -> DatasetDict: - """Extract hidden states from a model and return a `DatasetDict` containing them.""" - - def get_splits() -> SplitDict: - available_splits = assert_type(SplitDict, info.splits) - train_name, val_name = select_train_val_splits(available_splits) - - pretty_name = colorize(assert_type(str, info.builder_name), highlight_color) - print( - f"{pretty_name}: using '{train_name}' for training " - f"and '{val_name}' for validation" - ) - limit_list = cfg.prompts.max_examples - - return SplitDict( - { - k: SplitInfo( - name=k, - num_examples=min(limit, v.num_examples) * len(cfg.prompts.datasets), - dataset_name=v.dataset_name, - ) - for limit, (k, v) in zip(limit_list, available_splits.items()) - }, - dataset_name=available_splits.dataset_name, - ) - +def hidden_features(cfg: Extract) -> tuple[DatasetInfo, Features]: + """Return the HuggingFace `Features` corresponding to an `Extract` config.""" model_cfg = AutoConfig.from_pretrained(cfg.model) ds_name, _, config_name = cfg.prompts.datasets[0].partition(" ") @@ -345,16 +316,47 @@ def get_splits() -> SplitDict: dtype="float32", ) + return info, Features({**layer_cols, **other_cols}) + + +def extract( + cfg: "Extract", + *, + disable_cache: bool = False, + highlight_color: str = "cyan", + num_gpus: int = -1, + min_gpu_mem: int | None = None, + split_type: Literal["train", "val", None] = None, +) -> DatasetDict: + """Extract hidden states from a model and return a `DatasetDict` containing them.""" + info, features = hidden_features(cfg) + devices = select_usable_devices(num_gpus, min_memory=min_gpu_mem) + limit_list = cfg.prompts.max_examples + splits = assert_type(SplitDict, info.splits) + + if split_type is None: + train, val = select_train_val_splits(splits) + pretty_name = colorize(assert_type(str, info.builder_name), highlight_color) + print(f"{pretty_name}: using '{train}' for training and '{val}' for validation") + splits = SplitDict({train: splits[train], val: splits[val]}) + else: + # Remove the split we're not using + del limit_list[1 if split_type == "train" else 0] + split_name = select_split(splits, split_type) + splits = SplitDict({split_name: splits[split_name]}) + builders = { split_name: _GeneratorBuilder( - builder_name=info.builder_name, - config_name=info.config_name, cache_dir=None, - features=Features({**layer_cols, **other_cols}), + features=features, generator=_extraction_worker, split_name=split_name, - split_info=split_info, + split_info=SplitInfo( + name=split_name, + num_examples=min(limit, v.num_examples) * len(cfg.prompts.datasets), + dataset_name=v.dataset_name, + ), gen_kwargs=dict( cfg=[cfg] * len(devices), device=devices, @@ -363,7 +365,7 @@ def get_splits() -> SplitDict: world_size=[len(devices)] * len(devices), ), ) - for (split_name, split_info) in get_splits().items() + for limit, (split_name, v) in zip(limit_list, splits.items()) } import multiprocess as mp diff --git a/elk/extraction/generator.py b/elk/extraction/generator.py index 86e65e08..84818c83 100644 --- a/elk/extraction/generator.py +++ b/elk/extraction/generator.py @@ -58,8 +58,6 @@ class _GeneratorBuilder(GeneratorBasedBuilder): def __init__( self, - builder_name: str | None, - config_name: str | None, split_name: str, split_info: SplitInfo, **kwargs, @@ -69,11 +67,6 @@ def __init__( super().__init__(**kwargs) - # Weirdly we need to set DatasetInfo.builder_name and DatasetInfo.config_name - # here, not in _info, because super().__init__ modifies them - self.info.builder_name = builder_name - self.info.config_name = 
config_name - def _info(self): # Use the same builder and config name as the original builder return DatasetInfo(features=self.config.features) diff --git a/elk/run.py b/elk/run.py index 2e39e401..fcb0dc24 100644 --- a/elk/run.py +++ b/elk/run.py @@ -47,7 +47,11 @@ class Run(ABC, Serializable): out_dir: Path | None = None disable_cache: bool = field(default=False, to_dict=False) - def execute(self, highlight_color: str = "cyan"): + def execute( + self, + highlight_color: str = "cyan", + split_type: Literal["train", "val", None] = None, + ): self.datasets = [ extract( cfg, @@ -55,6 +59,7 @@ def execute(self, highlight_color: str = "cyan"): highlight_color=highlight_color, num_gpus=self.num_gpus, min_gpu_mem=self.min_gpu_mem, + split_type=split_type, ) for cfg in self.data.explode() ] diff --git a/elk/utils/__init__.py b/elk/utils/__init__.py index ede58bea..c2cf7e7a 100644 --- a/elk/utils/__init__.py +++ b/elk/utils/__init__.py @@ -6,6 +6,7 @@ has_multiple_configs, infer_label_column, infer_num_classes, + select_split, select_train_val_splits, ) from .gpu_utils import select_usable_devices @@ -33,6 +34,7 @@ "int16_to_float32", "is_autoregressive", "pytree_map", + "select_split", "select_train_val_splits", "select_usable_devices", "stochastic_round_constrained", diff --git a/elk/utils/data_utils.py b/elk/utils/data_utils.py index 181ed974..af907694 100644 --- a/elk/utils/data_utils.py +++ b/elk/utils/data_utils.py @@ -1,7 +1,7 @@ import copy from functools import cache from random import Random -from typing import Any, Iterable +from typing import Any, Iterable, Literal from datasets import ( ClassLabel, @@ -15,6 +15,13 @@ from ..promptsource.templates import Template from .typing import assert_type +# Lower values are more "train-like" and higher values are more "test-like". 
+PRIORITIES = { + Split.TRAIN: 0, + Split.VALIDATION: 1, + Split.TEST: 2, +} + def get_columns_all_equal(dataset: DatasetDict) -> list[str]: """Get columns of a `DatasetDict`, asserting all splits have the same columns.""" @@ -50,16 +57,18 @@ def has_multiple_configs(ds_name: str) -> bool: return len(get_dataset_config_names(ds_name)) > 1 +def select_split(raw_splits: Iterable[str], split_type: Literal["train", "val"]) -> str: + """Return the train or validation split to use, given an Iterable of splits.""" + assert split_type in ("train", "val") + + reduce_fn = min if split_type == "train" else max + return reduce_fn(raw_splits, key=lambda k: PRIORITIES.get(k, 100)) # type: ignore + + def select_train_val_splits(raw_splits: Iterable[str]) -> tuple[str, str]: """Return splits to use for train and validation, given an Iterable of splits.""" - priorities = { - Split.TRAIN: 0, - Split.VALIDATION: 1, - Split.TEST: 2, - } - - splits = sorted(raw_splits, key=lambda k: priorities.get(k, 100)) # type: ignore + splits = sorted(raw_splits, key=lambda k: PRIORITIES.get(k, 100)) # type: ignore assert len(splits) >= 2, "Must have at least two of train, val, and test splits" return tuple(splits[:2]) From 934cd54e9123fcaa6de36c42957e0229ea10e600 Mon Sep 17 00:00:00 2001 From: Nora Belrose Date: Wed, 26 Apr 2023 10:19:33 +0000 Subject: [PATCH 10/17] Log ensembled metrics --- elk/evaluation/evaluate.py | 39 +++++++++++++++++++------------ elk/metrics/eval.py | 36 +++++++++++++++++++++-------- elk/run.py | 2 +- elk/training/train.py | 47 +++++++++++++++++++++++--------------- 4 files changed, 82 insertions(+), 42 deletions(-) diff --git a/elk/evaluation/evaluate.py b/elk/evaluation/evaluate.py index 57ee6d03..2b9e8509 100644 --- a/elk/evaluation/evaluate.py +++ b/elk/evaluation/evaluate.py @@ -45,21 +45,32 @@ def apply_to_layer( for ds_name, (val_h, val_gt, _) in val_output.items(): meta = {"dataset": ds_name, "layer": layer} - val_result = evaluate_preds(val_gt, reporter(val_h)) - row_bufs["eval"].append({**meta, **val_result.to_dict()}) + val_credences = reporter(val_h) + for mode in ("none", "partial", "full"): + row_bufs["eval"].append( + { + **meta, + "ensembling": mode, + **evaluate_preds(val_gt, val_credences, mode).to_dict(), + } + ) - lr_dir = experiment_dir / "lr_models" - if not self.skip_supervised and lr_dir.exists(): - with open(lr_dir / f"layer_{layer}.pt", "rb") as f: - lr_models = torch.load(f, map_location=device) - if not isinstance(lr_models, list): # backward compatibility - lr_models = [lr_models] + lr_dir = experiment_dir / "lr_models" + if not self.skip_supervised and lr_dir.exists(): + with open(lr_dir / f"layer_{layer}.pt", "rb") as f: + lr_models = torch.load(f, map_location=device) + if not isinstance(lr_models, list): # backward compatibility + lr_models = [lr_models] - for i, model in enumerate(lr_models): - model.eval() - lr_result = evaluate_preds(val_gt, model(val_h)) - row_bufs["lr_eval"].append( - {"inlp_iter": i, **meta, **lr_result.to_dict()} - ) + for i, model in enumerate(lr_models): + model.eval() + row_bufs["lr_eval"].append( + { + "ensembling": mode, + "inlp_iter": i, + **meta, + **evaluate_preds(val_gt, model(val_h), mode).to_dict(), + } + ) return {k: pd.DataFrame(v) for k, v in row_bufs.items()} diff --git a/elk/metrics/eval.py b/elk/metrics/eval.py index dcc5ce35..653beae5 100644 --- a/elk/metrics/eval.py +++ b/elk/metrics/eval.py @@ -1,4 +1,5 @@ from dataclasses import asdict, dataclass +from typing import Literal import torch from einops import repeat @@ 
-37,16 +38,20 @@ def to_dict(self, prefix: str = "") -> dict[str, float]: else {} ) auroc_dict = {f"{prefix}auroc_{k}": v for k, v in asdict(self.roc_auc).items()} - return {**acc_dict, **cal_acc_dict, **cal_dict, **auroc_dict} + return {**auroc_dict, **cal_acc_dict, **acc_dict, **cal_dict} -def evaluate_preds(y_true: Tensor, y_logits: Tensor) -> EvalResult: +def evaluate_preds( + y_true: Tensor, + y_logits: Tensor, + ensembling: Literal["none", "partial", "full"] = "none", +) -> EvalResult: """ Evaluate the performance of a classification model. Args: y_true: Ground truth tensor of shape (N,). - y_pred: Predicted class tensor of shape (N, variants, n_classes). + y_logits: Predicted class tensor of shape (N, variants, n_classes). Returns: dict: A dictionary containing the accuracy, AUROC, and ECE. @@ -54,16 +59,29 @@ def evaluate_preds(y_true: Tensor, y_logits: Tensor) -> EvalResult: (n, v, c) = y_logits.shape assert y_true.shape == (n,) - # Clustered bootstrap confidence intervals for AUROC - y_true = repeat(y_true, "n -> n v", v=v) - auroc = roc_auc_ci(to_one_hot(y_true, c).long().flatten(1), y_logits.flatten(1)) - acc = accuracy_ci(y_true, y_logits.argmax(dim=-1)) - + if ensembling == "full": + y_logits = y_logits.mean(dim=1) + else: + y_true = repeat(y_true, "n -> n v", v=v) + + y_pred = y_logits.argmax(dim=-1) + if ensembling == "none": + auroc = roc_auc_ci(to_one_hot(y_true, c).long().flatten(1), y_logits.flatten(1)) + elif ensembling in ("partial", "full"): + # Pool together the negative and positive class logits + if c == 2: + auroc = roc_auc_ci(y_true, y_logits[..., 1] - y_logits[..., 0]) + else: + auroc = roc_auc_ci(to_one_hot(y_true, c).long(), y_logits) + else: + raise ValueError(f"Unknown mode: {ensembling}") + + acc = accuracy_ci(y_true, y_pred) cal_acc = None cal_err = None if c == 2: - pos_probs = y_logits.softmax(-1)[..., 1] + pos_probs = torch.sigmoid(y_logits[..., 1] - y_logits[..., 0]) # Calibrated accuracy cal_thresh = pos_probs.float().quantile(y_true.float().mean()) diff --git a/elk/run.py b/elk/run.py index 838c228f..d7fa549e 100644 --- a/elk/run.py +++ b/elk/run.py @@ -173,7 +173,7 @@ def apply_to_layers( finally: # Make sure the CSVs are written even if we crash or get interrupted for name, dfs in df_buffers.items(): - df = pd.concat(dfs).sort_values(by="layer") + df = pd.concat(dfs).sort_values(by=["layer", "ensembling"]) df.round(4).to_csv(self.out_dir / f"{name}.csv", index=False) if self.debug: save_debug_log(self.datasets, self.out_dir) diff --git a/elk/training/train.py b/elk/training/train.py index ddecc06f..ad5a799a 100644 --- a/elk/training/train.py +++ b/elk/training/train.py @@ -124,24 +124,35 @@ def apply_to_layer( for ds_name, (val_h, val_gt, val_lm_preds) in val_dict.items(): meta = {"dataset": ds_name, "layer": layer} - val_result = evaluate_preds(val_gt, reporter(val_h)) - row_bufs["eval"].append( - { - **meta, - "pseudo_auroc": pseudo_auroc, - "train_loss": train_loss, - **val_result.to_dict(), - } - ) - - if val_lm_preds is not None: - lm_result = evaluate_preds(val_gt, val_lm_preds) - row_bufs["lm_eval"].append({**meta, **lm_result.to_dict()}) - - for i, model in enumerate(lr_models): - lr_result = evaluate_preds(val_gt, model(val_h)) - row_bufs["lr_eval"].append( - {"inlp_iter": i, **meta, **lr_result.to_dict()} + val_credences = reporter(val_h) + for mode in ("none", "partial", "full"): + row_bufs["eval"].append( + { + **meta, + "ensembling": mode, + **evaluate_preds(val_gt, val_credences, mode).to_dict(), + "pseudo_auroc": pseudo_auroc, + 
"train_loss": train_loss, + } ) + if val_lm_preds is not None: + row_bufs["lm_eval"].append( + { + **meta, + "ensembling": mode, + **evaluate_preds(val_gt, val_lm_preds, mode).to_dict(), + } + ) + + for i, model in enumerate(lr_models): + row_bufs["lr_eval"].append( + { + **meta, + "ensembling": mode, + "inlp_iter": i, + **evaluate_preds(val_gt, model(val_h), mode).to_dict(), + } + ) + return {k: pd.DataFrame(v) for k, v in row_bufs.items()} From dff69bf7184c4ea2d53043ed0009c7ebaf658f52 Mon Sep 17 00:00:00 2001 From: Nora Belrose Date: Wed, 26 Apr 2023 10:42:04 +0000 Subject: [PATCH 11/17] Fixing pyright version --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 6575e57a..1a787fba 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,7 +41,7 @@ dev = [ "hypothesis", "pre-commit", "pytest", - "pyright", + "pyright==1.1.304", "scikit-learn", ] From 69c2d557ddda039a0db364503f65425d2ffa7126 Mon Sep 17 00:00:00 2001 From: Nora Belrose Date: Thu, 27 Apr 2023 01:45:19 +0000 Subject: [PATCH 12/17] Tons of stuff, preparing for sciq_binary experiment --- elk/evaluation/evaluate.py | 7 +- elk/extraction/__init__.py | 3 +- elk/extraction/dataset_name.py | 2 +- elk/extraction/extraction.py | 122 +++++++++---- elk/extraction/prompt_loading.py | 163 +++++------------- elk/files.py | 2 +- elk/promptsource/templates.py | 21 ++- .../templates/sciq_binary/templates.yaml | 160 +++++++++++++++++ elk/run.py | 11 +- elk/training/sweep.py | 26 ++- elk/utils/__init__.py | 4 +- elk/utils/data_utils.py | 53 +++--- elk/utils/gpu_utils.py | 8 + elk/utils/pretty.py | 5 +- tests/dbpedia_prompts.yaml | 3 +- tests/super_glue_prompts.yaml | 3 +- tests/test_load_prompts.py | 10 +- tests/test_smoke_elicit.py | 7 +- tests/test_smoke_eval.py | 20 ++- 19 files changed, 400 insertions(+), 230 deletions(-) create mode 100644 elk/promptsource/templates/sciq_binary/templates.yaml diff --git a/elk/evaluation/evaluate.py b/elk/evaluation/evaluate.py index 46ac8911..cf0508df 100644 --- a/elk/evaluation/evaluate.py +++ b/elk/evaluation/evaluate.py @@ -1,5 +1,6 @@ from collections import defaultdict from dataclasses import dataclass +from pathlib import Path import pandas as pd import torch @@ -15,7 +16,10 @@ class Eval(Run): """Full specification of a reporter evaluation run.""" - source: str = field(default="", positional=True) + # Using None as a default here is a hack; we actually raise an error if it's not + # specified in __post_init__. TODO: Maybe this is an indication we should be using + # composition and not inheritance here? + source: Path | None = field(default=None, positional=True) skip_supervised: bool = False def __post_init__(self): @@ -38,6 +42,7 @@ def apply_to_layer( device = self.get_device(devices, world_size) val_output = self.prepare_data(device, layer, "val") + assert self.source, "Must specify a source experiment." 
experiment_dir = elk_reporter_dir() / self.source reporter_path = experiment_dir / "reporters" / f"layer_{layer}.pt" diff --git a/elk/extraction/__init__.py b/elk/extraction/__init__.py index fac876fa..3fab2e32 100644 --- a/elk/extraction/__init__.py +++ b/elk/extraction/__init__.py @@ -1,7 +1,7 @@ from .balanced_sampler import BalancedSampler, FewShotSampler from .extraction import Extract, extract, extract_hiddens from .generator import _GeneratorBuilder, _GeneratorConfig -from .prompt_loading import PromptConfig, load_prompts +from .prompt_loading import load_prompts __all__ = [ "BalancedSampler", @@ -11,6 +11,5 @@ "extract", "_GeneratorConfig", "_GeneratorBuilder", - "PromptConfig", "load_prompts", ] diff --git a/elk/extraction/dataset_name.py b/elk/extraction/dataset_name.py index ec850d01..7ea05e21 100644 --- a/elk/extraction/dataset_name.py +++ b/elk/extraction/dataset_name.py @@ -5,7 +5,7 @@ def extract_dataset_name_and_config(dataset_config_str: str) -> tuple[str, str]: """Extract the dataset name and config name from the dataset prompt.""" - ds_name, _, config_name = dataset_config_str.partition(" ") + ds_name, _, config_name = dataset_config_str.partition(":") return ds_name, config_name diff --git a/elk/extraction/extraction.py b/elk/extraction/extraction.py index eca30d30..631bb5df 100644 --- a/elk/extraction/extraction.py +++ b/elk/extraction/extraction.py @@ -1,9 +1,8 @@ """Functions for extracting the hidden states of a model.""" import logging import os -from copy import copy -from dataclasses import InitVar, dataclass -from itertools import islice +from dataclasses import InitVar, dataclass, replace +from itertools import islice, zip_longest from typing import Any, Iterable, Literal from warnings import filterwarnings @@ -45,31 +44,70 @@ extract_dataset_name_and_config, ) from .generator import _GeneratorBuilder -from .prompt_loading import PromptConfig, load_prompts +from .prompt_loading import load_prompts @dataclass class Extract(Serializable): - """ - Args: - model: HuggingFace model string identifying the language model to extract - hidden states from. - prompts: The configuration for the prompt prompts. - layers: The layers to extract hidden states from. - layer_stride: Shortcut for setting `layers` to `range(0, num_layers, stride)`. - token_loc: The location of the token to extract hidden states from. Can be - either "first", "last", or "mean". Defaults to "last". - """ - - prompts: PromptConfig + """Config for extracting hidden states from a language model.""" + model: str = field(positional=True) + """HF model string identifying the language model to extract hidden states from.""" + + datasets: tuple[str, ...] = field(positional=True) + """Names of HF datasets to use, e.g. `"super_glue:boolq"` or `"imdb"`""" + + data_dirs: tuple[str, ...] = () + """Directory to use for caching the hiddens. Defaults to `HF_DATASETS_CACHE`.""" + + max_examples: tuple[int, int] = (1000, 1000) + """Maximum number of examples to use from each split of the dataset.""" + + num_shots: int = 0 + """Number of examples for few-shot prompts. If zero, prompts are zero-shot.""" + + num_variants: int = -1 + """The number of prompt templates to use for each example. If -1, all available + templates are used.""" + + seed: int = 42 + """Seed to use for prompt randomization. Defaults to 42.""" layers: tuple[int, ...] = () + """Indices of layers to extract hidden states from. 
We follow the HF convention, so + 0 is the embedding, and 1 is the output of the first transformer layer.""" + layer_stride: InitVar[int] = 1 + """Shortcut for `layers = (0,) + tuple(range(1, num_layers + 1, stride))`.""" + + template_path: str | None = None + """Path to pass into `DatasetTemplates`. By default we use the dataset name.""" + token_loc: Literal["first", "last", "mean"] = "last" + """The location of the token to extract hidden states from.""" + use_encoder_states: bool = False + """Whether to extract hidden states from the encoder instead of the decoder in the + case of encoder-decoder models.""" def __post_init__(self, layer_stride: int): + if len(self.max_examples) > 2: + raise ValueError( + "max_examples should be a list of length 0, 1, or 2," + f"but got {len(self.max_examples)}" + ) + if not self.max_examples: + self.max_examples = (int(1e100), int(1e100)) + + # Broadcast the dataset name to all data_dirs + if len(self.data_dirs) == 1: + self.data_dirs *= len(self.datasets) + elif self.data_dirs and len(self.data_dirs) != len(self.datasets): + raise ValueError( + "data_dirs should be a list of length 0, 1, or len(datasets)," + f" but got {len(self.data_dirs)}" + ) + if self.layers and layer_stride > 1: raise ValueError( "Cannot use both --layers and --layer-stride. Please use only one." @@ -87,14 +125,10 @@ def __post_init__(self, layer_stride: int): def explode(self) -> list["Extract"]: """Explode this config into a list of configs, one for each layer.""" - copies = [] - - for prompt_cfg in self.prompts.explode(): - cfg = copy(self) - cfg.prompts = prompt_cfg - copies.append(cfg) - - return copies + return [ + replace(self, datasets=(ds,), data_dirs=(data_dir,) if data_dir else ()) + for ds, data_dir in zip_longest(self.datasets, self.data_dirs) + ] @torch.inference_mode() @@ -114,8 +148,7 @@ def extract_hiddens( filterwarnings("ignore") logging.disable(logging.CRITICAL) - p_cfg = cfg.prompts - ds_names = p_cfg.datasets + ds_names = cfg.datasets assert len(ds_names) == 1, "Can only extract hiddens from one dataset at a time." 
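    # Each worker sees a single-dataset config here: after `Extract.explode()` the
    # field looks like e.g. `datasets=("super_glue:boolq",)` (illustrative value).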
model = instantiate_model( @@ -138,6 +171,7 @@ def extract_hiddens( prompt_ds = load_prompts( ds_names[0], split_type=split_type, + template_path=cfg.template_path, rank=rank, world_size=world_size, ) @@ -145,7 +179,7 @@ def extract_hiddens( # Add one to the number of layers to account for the embedding layer layer_indices = cfg.layers or tuple(range(model.config.num_hidden_layers + 1)) - global_max_examples = p_cfg.max_examples[0 if split_type == "train" else 1] + global_max_examples = cfg.max_examples[0 if split_type == "train" else 1] # break `max_examples` among the processes roughly equally max_examples = global_max_examples // world_size # the last process gets the remainder (which is usually small) @@ -277,18 +311,22 @@ def hidden_features(cfg: Extract) -> tuple[DatasetInfo, Features]: model_cfg = AutoConfig.from_pretrained(cfg.model) ds_name, config_name = extract_dataset_name_and_config( - dataset_config_str=cfg.prompts.datasets[0] + dataset_config_str=cfg.datasets[0] ) info = get_dataset_config_info(ds_name, config_name or None) - prompter = DatasetTemplates(ds_name, config_name) + if not cfg.template_path: + prompter = DatasetTemplates(ds_name, config_name) + else: + prompter = DatasetTemplates(cfg.template_path) + ds_features = assert_type(Features, info.features) label_col = prompter.label_column or infer_label_column(ds_features) num_classes = len(prompter.label_choices) or infer_num_classes( ds_features[label_col] ) - num_variants = cfg.prompts.num_variants + num_variants = cfg.num_variants if num_variants < 0: num_dropped = prompter.drop_non_mc_templates() num_variants = len(prompter.templates) @@ -340,19 +378,27 @@ def extract( info, features = hidden_features(cfg) devices = select_usable_devices(num_gpus, min_memory=min_gpu_mem) - limit_list = cfg.prompts.max_examples + limits = cfg.max_examples splits = assert_type(SplitDict, info.splits) + pretty_name = colorize(assert_type(str, cfg.datasets[0]), highlight_color) if split_type is None: train, val = select_train_val_splits(splits) - pretty_name = colorize(assert_type(str, info.builder_name), highlight_color) - print(f"{pretty_name}: using '{train}' for training and '{val}' for validation") + + print(f"{pretty_name} using '{train}' for training and '{val}' for validation") splits = SplitDict({train: splits[train], val: splits[val]}) + split_types = ["train", "val"] else: # Remove the split we're not using - del limit_list[1 if split_type == "train" else 0] + limits = [limits[0]] if split_type == "train" else limits split_name = select_split(splits, split_type) splits = SplitDict({split_name: splits[split_name]}) + split_types = [split_type] + + if split_type == "train": + print(f"{pretty_name} using '{split_name}' for training") + else: + print(f"{pretty_name} using '{split_name}' for validation") builders = { split_name: _GeneratorBuilder( @@ -362,18 +408,18 @@ def extract( split_name=split_name, split_info=SplitInfo( name=split_name, - num_examples=min(limit, v.num_examples) * len(cfg.prompts.datasets), + num_examples=min(limit, v.num_examples) * len(cfg.datasets), dataset_name=v.dataset_name, ), gen_kwargs=dict( cfg=[cfg] * len(devices), device=devices, rank=list(range(len(devices))), - split_type=[split_name] * len(devices), + split_type=[ty] * len(devices), world_size=[len(devices)] * len(devices), ), ) - for limit, (split_name, v) in zip(limit_list, splits.items()) + for limit, (split_name, v), ty in zip(limits, splits.items(), split_types) } import multiprocess as mp @@ -389,6 +435,6 @@ def extract( dataset_dict = 
DatasetDict(ds) return DatasetDictWithName( - name=cfg.prompts.datasets[0], + name=cfg.datasets[0], dataset=dataset_dict, ) diff --git a/elk/extraction/prompt_loading.py b/elk/extraction/prompt_loading.py index d0055bee..2d17bdf4 100644 --- a/elk/extraction/prompt_loading.py +++ b/elk/extraction/prompt_loading.py @@ -1,83 +1,18 @@ from collections import Counter -from copy import deepcopy -from dataclasses import dataclass -from itertools import zip_longest from random import Random from typing import Any, Iterator, Literal from datasets import ClassLabel, Dataset, Value, load_dataset -from simple_parsing.helpers import Serializable, field from ..promptsource import DatasetTemplates from ..utils import ( assert_type, infer_label_column, - select_train_val_splits, + select_split, ) from .balanced_sampler import BalancedSampler, FewShotSampler -@dataclass -class PromptConfig(Serializable): - """ - Args: - dataset: List of space-delimited names of the HuggingFace dataset to use, e.g. - `"super_glue boolq"` or `"imdb"`. - data_dir: The directory to use for caching the dataset. Defaults to - `~/.cache/huggingface/datasets`. - max_examples: The maximum number of examples to use from the val dataset. - If a single number, use at most that many examples for each split. If a list - of length 2, use the first element for the train split and the second for - the val split. If empty, use all examples. Defaults to empty. - num_shots: The number of examples to use in few-shot prompts. If zero, prompts - are zero-shot. Defaults to 0. - num_variants: The number of prompt templates to apply to each predicate upon - call to __getitem__. Use -1 to apply all available templates. Defaults to 1. - seed: The seed to use for prompt randomization. Defaults to 42. - """ - - datasets: list[str] = field(positional=True) - data_dirs: list[str] = field(default_factory=list) - max_examples: list[int] = field(default_factory=lambda: [1000, 1000]) - num_shots: int = 0 - num_variants: int = -1 - seed: int = 42 - - def __post_init__(self): - if len(self.max_examples) > 2: - raise ValueError( - "max_examples should be a list of length 0, 1, or 2," - f"but got {len(self.max_examples)}" - ) - if not self.max_examples: - self.max_examples = [int(1e100)] - - # Broadcast the limit to all splits - if len(self.max_examples) == 1: - self.max_examples *= 2 - - # Broadcast the dataset name to all data_dirs - if len(self.data_dirs) == 1: - self.data_dirs *= len(self.datasets) - elif self.data_dirs and len(self.data_dirs) != len(self.datasets): - raise ValueError( - "data_dirs should be a list of length 0, 1, or len(datasets)," - f" but got {len(self.data_dirs)}" - ) - - def explode(self) -> list["PromptConfig"]: - """Explode the config into a list of configs, one for each dataset.""" - copies = [] - - for ds, data_dir in zip_longest(self.datasets, self.data_dirs): - copy = deepcopy(self) - copy.datasets = [ds] - copy.data_dirs = [data_dir] if data_dir else [] - copies.append(copy) - - return copies - - def load_prompts( ds_string: str, *, @@ -85,39 +20,40 @@ def load_prompts( num_variants: int = -1, seed: int = 42, split_type: Literal["train", "val"] = "train", + template_path: str | None = None, rank: int = 0, world_size: int = 1, ) -> Iterator[dict]: """Load a dataset full of prompts generated from the specified dataset. Args: - ds_string: Space-delimited name of the HuggingFace dataset to use, - e.g. `"super_glue boolq"` or `"imdb"`. + ds_string: Name of HF dataset to use, e.g. `"super_glue:boolq"` or `"imdb"`. 
num_shots: The number of examples to use in few-shot prompts. If zero, prompts are zero-shot. seed: The seed to use for prompt randomization. split_type: Whether to use the train or val split of the dataset. + template_path: Path to feed into `DatasetTemplates` for loading templates. rank: The rank of the current process. Defaults to 0. world_size: The number of processes. Defaults to 1. Returns: An iterable of prompt dictionaries. """ - ds_name, _, config_name = ds_string.partition(" ") - prompter = DatasetTemplates(ds_name, config_name) - prompter.drop_non_mc_templates() + ds_name, _, config_name = ds_string.partition(":") ds_dict = assert_type(dict, load_dataset(ds_name, config_name or None)) - train_name, val_name = select_train_val_splits(ds_dict) - split_name = val_name if split_type == "val" else train_name + split_name = select_split(ds_dict, split_type) - ds = ds_dict[split_name].shuffle(seed=seed) - train_ds = ds_dict[train_name].shuffle(seed=seed) - - ds = assert_type(Dataset, ds) + ds = assert_type(Dataset, ds_dict[split_name].shuffle(seed=seed)) if world_size > 1: ds = ds.shard(world_size, rank) + if template_path is None: + prompter = DatasetTemplates(ds_name, config_name) + else: + prompter = DatasetTemplates(template_path) + + prompter.drop_non_mc_templates() num_templates = len(prompter.templates) num_variants = ( num_templates if num_variants == -1 else min(num_variants, num_templates) @@ -126,47 +62,27 @@ def load_prompts( if rank == 0: print(f"Using {num_variants} variants of each prompt") - # Which classes are actually present in this split of the dataset? - # This is shockingly fast since it uses an optimized Apache Arrow primitive. label_column = prompter.label_column or infer_label_column(ds.features) - observed_labels = set(ds.unique(label_column)) - - # Now sanity check that the observed classes match the expected classes. This can - # sometimes fail if we picked an unlabeled split (e.g. everything is -1) - label_feature = ds.features[label_column] - if isinstance(label_feature, ClassLabel): - label_choices = {label_feature.str2int(label) for label in label_feature.names} - elif isinstance(label_feature, Value) and label_feature.dtype == "bool": - label_choices = {False, True} - else: - # We just have to assume that the observed labels are right - label_choices = observed_labels - - if observed_labels != label_choices: - raise ValueError( - f"Observed labels {observed_labels} in split '{split_name}' do not match " - f"expected labels {label_choices} from the dataset features." - ) - - if prompt_choices := prompter.label_choices: - # The observed labels should be a superset of the prompt choices - if not (observed_labels >= set(prompt_choices)): - raise ValueError( - f"Observed labels {observed_labels} in split '{split_name}' do not " - f"match the prompt choices {prompt_choices}." - ) - - sorted_labels = prompt_choices - else: - # Impose a canonical order on the label choices. Theoretically the label column - # may be of a type that doesn't support comparison (so Pylance complains), but - # we'll just let it raise an exception if that happens. 
- sorted_labels = sorted(label_choices) # type: ignore[arg-type] + label_choices = prompter.label_choices + + if not label_choices: + label_feature = ds.features[label_column] + if isinstance(label_feature, ClassLabel): + label_choices = [ + label_feature.str2int(label) for label in label_feature.names + ] + elif isinstance(label_feature, Value) and label_feature.dtype == "bool": + label_choices = [False, True] + else: + # Which classes are actually present in this split of the dataset? + # This is shockingly fast since it uses an optimized Apache Arrow primitive. + label_choices = sorted(ds.unique(label_column)) rng = Random(seed) if num_shots > 0: + train_name = select_split(ds_dict, "train") fewshot = FewShotSampler( - train_ds, # TODO: not iterator + ds_dict[train_name].shuffle(seed=seed), # TODO: not iterator num_shots=num_shots, rng=rng, ) @@ -174,17 +90,26 @@ def load_prompts( else: fewshot_iter = None - ds = ds.to_iterable_dataset() + if label_column in ds.features: + ds = BalancedSampler( + ds.to_iterable_dataset(), + set(label_choices), + label_col=label_column, + strict=False, + ) + else: + if rank == 0: + print("No label column found, not balancing") + ds = ds.to_iterable_dataset() + if rank == 0: - print(f"Label choices: {sorted_labels}") + print(f"Label choices: {label_choices}") - for example in BalancedSampler( - ds, set(sorted_labels), label_col=label_column, strict=False - ): + for example in ds: yield _convert_to_prompts( example, label_column=label_column, - label_choices=sorted_labels, # type: ignore[arg-type] + label_choices=label_choices, # type: ignore[arg-type] num_variants=num_variants, prompter=prompter, rng=rng, diff --git a/elk/files.py b/elk/files.py index e8d0a517..e7e917e0 100644 --- a/elk/files.py +++ b/elk/files.py @@ -40,6 +40,6 @@ def memorably_named_dir(parent: Path): return out_dir -def transfer_eval_directory(source: str) -> Path: +def transfer_eval_directory(source: Path) -> Path: """Return the directory where transfer evals are stored.""" return elk_reporter_dir() / source / "transfer_eval" diff --git a/elk/promptsource/templates.py b/elk/promptsource/templates.py index f70c5520..e4baf703 100644 --- a/elk/promptsource/templates.py +++ b/elk/promptsource/templates.py @@ -19,7 +19,7 @@ env = Environment(loader=BaseLoader) # type: ignore # Allow the python function zip() -env.globals.update(zip=zip) +env.globals.update(enumerate=enumerate, zip=zip) # These are users whose datasets should be included in the results returned by # filter_english_datasets (regardless of their metadata) @@ -34,6 +34,18 @@ def choice(choices): return random.choice(choices) +def permutation(n): + return random.sample(range(n), n) + + +def reorder(arr, permutation): + return [arr[i] for i in permutation] + + +def to_letter(n): + return chr(n + ord("A")) + + def most_frequent(items): """Returns the set of items which appear most frequently in the input""" if not items: @@ -47,6 +59,9 @@ def most_frequent(items): env.filters["highlight"] = highlight env.filters["choice"] = choice env.filters["most_frequent"] = most_frequent +env.filters["permutation"] = permutation +env.filters["reorder"] = reorder +env.filters["to_letter"] = to_letter class Template(yaml.YAMLObject): @@ -386,7 +401,7 @@ class DatasetTemplates: TEMPLATE_FILENAME = "templates.yaml" label_column: str | None - label_choices: list[str] + label_choices: list[int | str] def __init__(self, dataset_name: str, subset_name: str | None = None): self.dataset_name = dataset_name @@ -409,7 +424,7 @@ def __init__(self, 
dataset_name: str, subset_name: str | None = None): def drop_non_mc_templates(self) -> int: """Drop all templates that aren't multiple choice, return the number dropped""" mc_templates = { - k: v for k, v in self.templates.items() if v.get_fixed_answer_choices_list() + k: v for k, v in self.templates.items() if v.answer_choices is not None } if not mc_templates: raise ValueError("No multiple choice templates found") diff --git a/elk/promptsource/templates/sciq_binary/templates.yaml b/elk/promptsource/templates/sciq_binary/templates.yaml new file mode 100644 index 00000000..e83f5a71 --- /dev/null +++ b/elk/promptsource/templates/sciq_binary/templates.yaml @@ -0,0 +1,160 @@ +# Meta-template for binarizing a multiple choice dataset +dataset: sciq_binary +label_choices: + - 0 + - 1 +# This column doesn't actually exist in the sciq dataset, but that's okay. We end up +# creating this column in the output dataset, and never directly read it from the +# input dataset. +label_column: label +templates: + a1: !Template + answer_choices: '{{distractor1}} ||| {{distractor2}} ||| {{distractor3}} ||| {{correct_answer}}' + id: a1 + jinja: "\ + {% set order = 4 | permutation %} + {% set response_idx = order.index(3 if label else (range(3) | choice)) %} + {% set shuffled_choices = answer_choices | reorder(order) %} + You will see a question followed by a response which may be incorrect. + You will be asked to figure out whether the answer is correct or not. + + + \"Q: {{question}} + + {% for i, c in enumerate(shuffled_choices) %} + {{i | to_letter}}) {{c}} + + {% endfor %} + Response: {{response_idx | to_letter}}, {{shuffled_choices[response_idx]}}\" + + + Is the response correct, yes or no?|||{{'yes' if label else 'no'}}" + metadata: !TemplateMetadata + choices_in_prompt: true + languages: + - en + metrics: + - Accuracy + original_task: true + name: yes_no + reference: '' + a2: !Template + answer_choices: '{{distractor1}} ||| {{distractor2}} ||| {{distractor3}} ||| {{correct_answer}}' + id: a2 + jinja: "\ + {% set order = 4 | permutation %} + {% set response_idx = order.index(3 if label else (range(3) | choice)) %} + {% set shuffled_choices = answer_choices | reorder(order) %} + The following quote contains a question followed by a response which may be + incorrect. Your task is to determine whether the answer is correct or incorrect. + + + \"Q: {{question}} + + {% for i, c in enumerate(shuffled_choices) %} + {{i | to_letter}}) {{c}} + + {% endfor %} + Response: {{response_idx | to_letter}}, {{shuffled_choices[response_idx]}}\" + + + Was the response correct or incorrect?|||{{'correct' if label else 'incorrect'}}" + metadata: !TemplateMetadata + choices_in_prompt: true + languages: + - en + metrics: + - Accuracy + original_task: true + name: correct_incorrect + reference: '' + a3: !Template + answer_choices: '{{distractor1}} ||| {{distractor2}} ||| {{distractor3}} ||| {{correct_answer}}' + id: a3 + jinja: "\ + {% set order = 4 | permutation %} + {% set response_idx = order.index(3 if label else (range(3) | choice)) %} + {% set shuffled_choices = answer_choices | reorder(order) %} + The following quote is an exam question followed by a student's response. + You will determine whether student chose the right answer. 
+ + + \"Q: {{question}} + + {% for i, c in enumerate(shuffled_choices) %} + {{i | to_letter}}) {{c}} + + {% endfor %} + Student's response: {{response_idx | to_letter}}, {{shuffled_choices[response_idx]}}\" + + + Did the student choose the right answer, yes or no?|||{{'yes' if label else 'no'}}" + metadata: !TemplateMetadata + choices_in_prompt: true + languages: + - en + metrics: + - Accuracy + original_task: true + name: student_yes_no + reference: '' + a4: !Template + answer_choices: '{{distractor1}} ||| {{distractor2}} ||| {{distractor3}} ||| {{correct_answer}}' + id: a4 + jinja: "\ + {% set order = 4 | permutation %} + {% set response_idx = order.index(3 if label else (range(3) | choice)) %} + {% set shuffled_choices = answer_choices | reorder(order) %} + The paragraph below is a quiz question followed by a student's response. + You will determine if the response was right or wrong. + + + \"Q: {{question}} + + {% for i, c in enumerate(shuffled_choices) %} + {{i | to_letter}}) {{c}} + + {% endfor %} + Student's response: {{response_idx | to_letter}}, {{shuffled_choices[response_idx]}}\" + + + Was the response right or wrong?|||{{'right' if label else 'wrong'}}" + metadata: !TemplateMetadata + choices_in_prompt: true + languages: + - en + metrics: + - Accuracy + original_task: true + name: student_correct_incorrect + reference: '' + a5: !Template + answer_choices: '{{distractor1}} ||| {{distractor2}} ||| {{distractor3}} ||| {{correct_answer}}' + id: a5 + jinja: "\ + {% set order = 4 | permutation %} + {% set response_idx = order.index(3 if label else (range(3) | choice)) %} + {% set shuffled_choices = answer_choices | reorder(order) %} + The following is an exam question followed by a student's response. + You will guess if they received full credit for the answer. 
+ + + \"Q: {{question}} + + {% for i, c in enumerate(shuffled_choices) %} + {{i | to_letter}}) {{c}} + + {% endfor %} + Student's response: {{response_idx | to_letter}}, {{shuffled_choices[response_idx]}}\" + + + Did the student get full credit for this answer (yes or no)?|||{{'yes' if label else 'no'}}" + metadata: !TemplateMetadata + choices_in_prompt: true + languages: + - en + metrics: + - Accuracy + original_task: true + name: student_full_credit + reference: '' diff --git a/elk/run.py b/elk/run.py index c9a9fef8..b9e88995 100644 --- a/elk/run.py +++ b/elk/run.py @@ -22,9 +22,9 @@ from .files import elk_reporter_dir, memorably_named_dir from .utils import ( assert_type, - get_layers, + get_layer_indices, int16_to_float32, - select_train_val_splits, + select_split, select_usable_devices, ) @@ -66,7 +66,7 @@ def execute( if self.out_dir is None: # Save in a memorably-named directory inside of # ELK_REPORTER_DIR// - ds_name = ", ".join(self.data.prompts.datasets) + ds_name = ", ".join(self.data.datasets) root = elk_reporter_dir() / self.data.model / ds_name self.out_dir = memorably_named_dir(root) @@ -123,8 +123,7 @@ def prepare_data( out = {} for ds_name, ds in self.datasets: - train_name, val_name = select_train_val_splits(ds) - key = train_name if split_type == "train" else val_name + key = select_split(ds, split_type) split = ds[key].with_format("torch", device=device, dtype=torch.int16) labels = assert_type(Tensor, split["label"]) @@ -160,7 +159,7 @@ def apply_to_layers( """ self.out_dir = assert_type(Path, self.out_dir) - layers, *rest = [get_layers(ds) for _, ds in self.datasets] + layers, *rest = [get_layer_indices(ds) for _, ds in self.datasets] assert all(x == layers for x in rest), "All datasets must have the same layers" if self.concatenated_layer_offset > 0: diff --git a/elk/training/sweep.py b/elk/training/sweep.py index cfe7d4e4..8b49cdf4 100644 --- a/elk/training/sweep.py +++ b/elk/training/sweep.py @@ -1,8 +1,8 @@ from copy import deepcopy -from dataclasses import InitVar, dataclass +from dataclasses import InitVar, dataclass, replace from ..evaluation.evaluate import Eval -from ..extraction import Extract, PromptConfig +from ..extraction import Extract from ..files import elk_reporter_dir, memorably_named_dir from ..utils import colorize from .train import Elicit @@ -25,7 +25,7 @@ class Sweep: run_template: Elicit = Elicit( data=Extract( model="", - prompts=PromptConfig(datasets=[""]), + datasets=("",), ) ) @@ -63,8 +63,7 @@ def execute(self): ) for i, model_str in enumerate(self.models): - # Magenta color for the model name - print(f"\n\033[35m===== {model_str} ({i + 1} of {M}) =====\033[0m") + print(colorize(f"===== {model_str} ({i + 1} of {M}) =====", "magenta")) for dataset_str in self.datasets: out_dir = sweep_dir / model_str / dataset_str @@ -72,11 +71,10 @@ def execute(self): # Allow for multiple datasets to be specified in a single string with # plus signs. This means we can pool datasets together inside of a # single sweep. 
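                # e.g. "imdb+super_glue:boolq" -> ("imdb", "super_glue:boolq")
                # (dataset names shown here are only illustrative)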
- train_datasets = [ds.strip() for ds in dataset_str.split("+")] + train_datasets = tuple(ds.strip() for ds in dataset_str.split("+")) run = deepcopy(self.run_template) - run.data.model = model_str - run.data.prompts.datasets = train_datasets + run.data = replace(run.data, model=model_str, datasets=train_datasets) run.out_dir = out_dir run.execute() @@ -89,13 +87,13 @@ def execute(self): if eval_dataset in train_datasets: continue - data = deepcopy(run.data) - data.model = model_str - data.prompts.datasets = [eval_dataset] - eval = Eval( - data=data, - source=str(run.out_dir), + data=replace( + run.data, model=model_str, datasets=(eval_dataset,) + ), + source=run.out_dir, out_dir=out_dir, + num_gpus=run.num_gpus, + min_gpu_mem=run.min_gpu_mem, ) eval.execute(highlight_color="green") diff --git a/elk/utils/__init__.py b/elk/utils/__init__.py index fe911ff4..bc7b7d15 100644 --- a/elk/utils/__init__.py +++ b/elk/utils/__init__.py @@ -1,7 +1,7 @@ from .data_utils import ( binarize, get_columns_all_equal, - get_layers, + get_layer_indices, has_multiple_configs, infer_label_column, infer_num_classes, @@ -23,7 +23,7 @@ "cov_mean_fused", "float32_to_int16", "get_columns_all_equal", - "get_layers", + "get_layer_indices", "has_multiple_configs", "infer_label_column", "infer_num_classes", diff --git a/elk/utils/data_utils.py b/elk/utils/data_utils.py index 6e3c218d..b574d12e 100644 --- a/elk/utils/data_utils.py +++ b/elk/utils/data_utils.py @@ -7,7 +7,6 @@ ClassLabel, DatasetDict, Features, - Split, Value, get_dataset_config_names, ) @@ -15,13 +14,6 @@ from ..promptsource.templates import Template from .typing import assert_type -# Lower values are more "train-like" and higher values are more "test-like". -PRIORITIES = { - Split.TRAIN: 0, - Split.VALIDATION: 1, - Split.TEST: 2, -} - def get_columns_all_equal(dataset: DatasetDict) -> list[str]: """Get columns of a `DatasetDict`, asserting all splits have the same columns.""" @@ -33,6 +25,18 @@ def get_columns_all_equal(dataset: DatasetDict) -> list[str]: return pivot +def get_split_priority(split: str) -> int: + """Return an integer indicating how "test-like" a split is given its name.""" + if split.startswith("train"): + return 0 + elif split.startswith("val"): + return 1 + elif split.startswith("test"): + return 2 + + return 3 + + @cache def has_multiple_configs(ds_name: str) -> bool: """Return whether a dataset has multiple configs.""" @@ -41,16 +45,22 @@ def has_multiple_configs(ds_name: str) -> bool: def select_split(raw_splits: Iterable[str], split_type: Literal["train", "val"]) -> str: """Return the train or validation split to use, given an Iterable of splits.""" - assert split_type in ("train", "val") - - reduce_fn = min if split_type == "train" else max - return reduce_fn(raw_splits, key=lambda k: PRIORITIES.get(k, 100)) # type: ignore + assert split_type in ("train", "val"), f"Invalid split type: {split_type}" + + # Note we use the alphabetical order of the splits as a tiebreaker. 
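    # e.g. with raw_splits = ["test", "validation", "train"] (an assumed example), the
    # sort below gives ["train", "validation", "test"], so "train" is returned for
    # split_type="train" and "validation" for split_type="val".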
+ sorted_splits = sorted(raw_splits, key=lambda k: (get_split_priority(k), k)) + if not sorted_splits: + raise ValueError("No splits found!") + elif len(sorted_splits) == 1: + return sorted_splits[0] + else: + return sorted_splits[0] if split_type == "train" else sorted_splits[1] def select_train_val_splits(raw_splits: Iterable[str]) -> tuple[str, str]: """Return splits to use for train and validation, given an Iterable of splits.""" - splits = sorted(raw_splits, key=lambda k: PRIORITIES.get(k, 100)) # type: ignore + splits = sorted(raw_splits, key=lambda k: (get_split_priority(k), k)) assert len(splits) >= 2, "Must have at least two of train, val, and test splits" return tuple(splits[:2]) @@ -117,15 +127,14 @@ def infer_num_classes(label_feature: Any) -> int: ) -def get_layers(ds: DatasetDict) -> list[int]: - """Get a list of indices of hidden layers given a `DatasetDict`.""" - train, _ = select_train_val_splits(ds.keys()) - layers = [ - int(feat[len("hidden_") :]) - for feat in ds[train].features - if feat.startswith("hidden_") - ] - return layers +def get_layer_indices(ds: DatasetDict) -> list[int]: + """Return the indices of the layers from which the hiddens have been extracted.""" + # Dataset has a bunch of columns of the form "hidden_0", "hidden_1", etc. + # str.removeprefix() is a no-op if the prefix isn't present + suffixes = (col.removeprefix("hidden_") for col in get_columns_all_equal(ds)) + + # Convert to the suffixes that are integral to ints, then sort them + return sorted(int(suffix) for suffix in suffixes if suffix.isdigit()) def binarize(template: Template, label: int, new_label: int, rng: Random) -> Template: diff --git a/elk/utils/gpu_utils.py b/elk/utils/gpu_utils.py index e35d04de..a4294298 100644 --- a/elk/utils/gpu_utils.py +++ b/elk/utils/gpu_utils.py @@ -3,6 +3,7 @@ import os import time import warnings +from functools import cache import pynvml import torch @@ -10,6 +11,13 @@ from .typing import assert_type +# We cache the results primarily so that we don't display "Using N of M GPUs..." +# multiple times during the same run. This does sort of assume that once we identify +# a GPU as being available, it will remain available for the duration of the run. +# This seems to be a reasonable assumption because PyTorch tends to hold onto VRAM +# for later use once it's been allocated. Calling torch.cuda.empty_cache() might break +# this assumption, but we never do that. +@cache def select_usable_devices( num_gpus: int = -1, *, min_memory: int | None = None ) -> list[str]: diff --git a/elk/utils/pretty.py b/elk/utils/pretty.py index 6552dc1b..b502cdcf 100644 --- a/elk/utils/pretty.py +++ b/elk/utils/pretty.py @@ -1,4 +1,4 @@ -# Kind of kickass that this file has no imports +from typing import Literal # ANSI color codes for use in terminal output. COLOR_CODES = { @@ -11,9 +11,10 @@ "cyan": 36, "white": 37, } +Color = Literal["black", "red", "green", "yellow", "blue", "magenta", "cyan", "white"] -def colorize(message: str, color: str) -> str: +def colorize(message: str, color: Color) -> str: """Colorize a message for terminal output.""" # Get the ANSI color code based on the human-readable color name. 
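    # e.g. colorize("done", "green") -> "\033[32mdone\033[0m", given the table above.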
code = COLOR_CODES.get(color.lower()) diff --git a/tests/dbpedia_prompts.yaml b/tests/dbpedia_prompts.yaml index 76bc18b5..4968289f 100644 --- a/tests/dbpedia_prompts.yaml +++ b/tests/dbpedia_prompts.yaml @@ -1,6 +1,7 @@ -balance: true datasets: - "dbpedia_14" +model: + - gpt2 label_column: null max_examples: - 5 diff --git a/tests/super_glue_prompts.yaml b/tests/super_glue_prompts.yaml index 196267af..effbeb6c 100644 --- a/tests/super_glue_prompts.yaml +++ b/tests/super_glue_prompts.yaml @@ -1,7 +1,8 @@ -balance: true datasets: - "super_glue boolq" - "super_glue copa" +model: + - gpt2 label_column: null max_examples: - 5 diff --git a/tests/test_load_prompts.py b/tests/test_load_prompts.py index d8c065cf..0d309cf5 100644 --- a/tests/test_load_prompts.py +++ b/tests/test_load_prompts.py @@ -3,18 +3,18 @@ import pytest -from elk.extraction import PromptConfig, load_prompts +from elk.extraction import Extract, load_prompts from elk.promptsource.templates import DatasetTemplates @pytest.mark.filterwarnings("ignore:Unable to find a decoding function") def test_load_prompts(): - def test_single_split(cfg: PromptConfig, split_type: Literal["train", "val"]): + def test_single_split(cfg: Extract, split_type: Literal["train", "val"]): for cfg in cfg.explode(): ds_string = cfg.datasets[0] prompt_ds = load_prompts(ds_string, split_type=split_type) - ds_name, _, config_name = ds_string.partition(" ") + ds_name, _, config_name = ds_string.partition(":") prompter = DatasetTemplates(ds_name, config_name or None) prompter.drop_non_mc_templates() @@ -30,11 +30,11 @@ def test_single_split(cfg: PromptConfig, split_type: Literal["train", "val"]): # the case where the dataset has 2 classes # this dataset is small - cfg = PromptConfig.load_yaml("tests/super_glue_prompts.yaml") + cfg = Extract.load_yaml("tests/super_glue_prompts.yaml") test_single_split(cfg, "train") test_single_split(cfg, "val") # the case where the dataset has more than 2 classes - cfg = PromptConfig.load_yaml("tests/dbpedia_prompts.yaml") + cfg = Extract.load_yaml("tests/dbpedia_prompts.yaml") test_single_split(cfg, "train") test_single_split(cfg, "val") diff --git a/tests/test_smoke_elicit.py b/tests/test_smoke_elicit.py index e61568ba..7cf0e8c9 100644 --- a/tests/test_smoke_elicit.py +++ b/tests/test_smoke_elicit.py @@ -1,7 +1,6 @@ from pathlib import Path from elk import Extract -from elk.extraction import PromptConfig from elk.training import CcsReporterConfig, EigenReporterConfig from elk.training.train import Elicit @@ -13,7 +12,8 @@ def test_smoke_elicit_run_tiny_gpt2_ccs(tmp_path: Path): elicit = Elicit( data=Extract( model=model_path, - prompts=PromptConfig(datasets=[dataset_name], max_examples=[10]), + datasets=(dataset_name,), + max_examples=(10, 10), # run on all layers, tiny-gpt only has 2 layers ), num_gpus=2, @@ -43,7 +43,8 @@ def test_smoke_elicit_run_tiny_gpt2_eigen(tmp_path: Path): elicit = Elicit( data=Extract( model=model_path, - prompts=PromptConfig(datasets=[dataset_name], max_examples=[10]), + datasets=(dataset_name,), + max_examples=(10, 10), # run on all layers, tiny-gpt only has 2 layers ), num_gpus=2, diff --git a/tests/test_smoke_eval.py b/tests/test_smoke_eval.py index e10609a8..cf1827f3 100644 --- a/tests/test_smoke_eval.py +++ b/tests/test_smoke_eval.py @@ -1,11 +1,9 @@ from pathlib import Path -from typing import Sequence import pandas as pd from elk import Extract from elk.evaluation import Eval -from elk.extraction import PromptConfig from elk.extraction.dataset_name import extract_dataset_name_and_config from 
elk.files import transfer_eval_directory from elk.training import CcsReporterConfig, EigenReporterConfig @@ -32,7 +30,8 @@ def setup_elicit( elicit = Elicit( data=Extract( model=model_path, - prompts=PromptConfig(datasets=[dataset_name], max_examples=[10]), + datasets=(dataset_name,), + max_examples=(10, 10), # run on all layers, tiny-gpt only has 2 layers ), num_gpus=2, @@ -52,27 +51,30 @@ def check_contains_files(dir: Path, expected_files: list[str]): assert file in created_file_names -def eval_run(elicit: Elicit, transfer_datasets: Sequence[str] = []) -> int: +def eval_run(elicit: Elicit, transfer_datasets: tuple[str, ...] = ()) -> int: """A single eval run; act and assert that expected files were created. Returns a reference time (in seconds) for file modification checking. """ tmp_path = elicit.out_dir extract = elicit.data + assert tmp_path is not None # record elicit modification time as reference. start_time_sec = (tmp_path / "eval.csv").stat().st_mtime if transfer_datasets: # update datasets to a different dataset - extract.prompts.datasets = transfer_datasets + extract.datasets = transfer_datasets eval = Eval(data=extract, source=tmp_path) eval.execute() return start_time_sec -def eval_assert_files_created(elicit: Elicit, transfer_datasets: Sequence[str] = []): +def eval_assert_files_created(elicit: Elicit, transfer_datasets: tuple[str, ...] = ()): tmp_path = elicit.out_dir + assert tmp_path is not None + eval_dir = transfer_eval_directory(source=tmp_path) assert eval_dir.exists(), f"transfer eval dir {eval_dir} does not exist" check_contains_files(eval_dir, EVAL_EXPECTED_FILES) @@ -92,20 +94,20 @@ def eval_assert_files_created(elicit: Elicit, transfer_datasets: Sequence[str] = def test_smoke_tfr_eval_run_tiny_gpt2_ccs(tmp_path: Path): elicit = setup_elicit(tmp_path) - transfer_datasets = ["christykoh/imdb_pt"] + transfer_datasets = ("christykoh/imdb_pt",) eval_run(elicit, transfer_datasets=transfer_datasets) eval_assert_files_created(elicit, transfer_datasets=transfer_datasets) def test_smoke_eval_run_tiny_gpt2_eigen(tmp_path: Path): elicit = setup_elicit(tmp_path, is_ccs=False) - transfer_datasets = ["christykoh/imdb_pt"] + transfer_datasets = ("christykoh/imdb_pt",) eval_run(elicit, transfer_datasets=transfer_datasets) eval_assert_files_created(elicit, transfer_datasets=transfer_datasets) def test_smoke_multi_eval_run_tiny_gpt2_ccs(tmp_path: Path): elicit = setup_elicit(tmp_path) - transfer_datasets = ["christykoh/imdb_pt", "super_glue boolq"] + transfer_datasets = ("christykoh/imdb_pt", "super_glue:boolq") eval_run(elicit, transfer_datasets=transfer_datasets) eval_assert_files_created(elicit, transfer_datasets=transfer_datasets) From 960ff01e9536757f0b35c1341ece5e70789f95ac Mon Sep 17 00:00:00 2001 From: Nora Belrose Date: Thu, 27 Apr 2023 07:01:41 +0000 Subject: [PATCH 13/17] Support --binarize again --- elk/evaluation/evaluate.py | 11 +- elk/extraction/extraction.py | 19 +- elk/extraction/prompt_loading.py | 13 ++ elk/files.py | 5 - elk/promptsource/templates.py | 330 +++++-------------------------- elk/run.py | 2 +- elk/training/__init__.py | 2 + elk/training/sweep.py | 24 +-- elk/utils/__init__.py | 5 +- elk/utils/data_utils.py | 33 ---- tests/test_smoke_eval.py | 3 +- 11 files changed, 92 insertions(+), 355 deletions(-) diff --git a/elk/evaluation/evaluate.py b/elk/evaluation/evaluate.py index cf0508df..cbed7cab 100644 --- a/elk/evaluation/evaluate.py +++ b/elk/evaluation/evaluate.py @@ -6,7 +6,7 @@ import torch from simple_parsing.helpers import field -from ..files 
import elk_reporter_dir, transfer_eval_directory +from ..files import elk_reporter_dir from ..metrics import evaluate_preds from ..run import Run from ..training import Reporter @@ -25,16 +25,13 @@ class Eval(Run): def __post_init__(self): assert self.source, "Must specify a source experiment." - # Set the output directory to the transfer directory if it's not specified - self.out_dir = ( - transfer_eval_directory(self.source) - if self.out_dir is None - else self.out_dir - ) + if not self.out_dir: + self.out_dir = self.source / "transfer" / "+".join(self.data.datasets) def execute(self, highlight_color: str = "cyan"): return super().execute(highlight_color, split_type="val") + @torch.inference_mode() def apply_to_layer( self, layer: int, devices: list[str], world_size: int ) -> dict[str, pd.DataFrame]: diff --git a/elk/extraction/extraction.py b/elk/extraction/extraction.py index 631bb5df..3ecfed30 100644 --- a/elk/extraction/extraction.py +++ b/elk/extraction/extraction.py @@ -27,6 +27,7 @@ from ..promptsource import DatasetTemplates from ..utils import ( + Color, assert_type, colorize, float32_to_int16, @@ -60,6 +61,9 @@ class Extract(Serializable): data_dirs: tuple[str, ...] = () """Directory to use for caching the hiddens. Defaults to `HF_DATASETS_CACHE`.""" + binarize: bool = False + """Whether to binarize the dataset labels for multi-class datasets.""" + max_examples: tuple[int, int] = (1000, 1000) """Maximum number of examples to use from each split of the dataset.""" @@ -70,9 +74,6 @@ class Extract(Serializable): """The number of prompt templates to use for each example. If -1, all available templates are used.""" - seed: int = 42 - """Seed to use for prompt randomization. Defaults to 42.""" - layers: tuple[int, ...] = () """Indices of layers to extract hidden states from. We follow the HF convention, so 0 is the embedding, and 1 is the output of the first transformer layer.""" @@ -80,6 +81,9 @@ class Extract(Serializable): layer_stride: InitVar[int] = 1 """Shortcut for `layers = (0,) + tuple(range(1, num_layers + 1, stride))`.""" + seed: int = 42 + """Seed to use for prompt randomization. Defaults to 42.""" + template_path: str | None = None """Path to pass into `DatasetTemplates`. 
By default we use the dataset name.""" @@ -170,6 +174,7 @@ def extract_hiddens( prompt_ds = load_prompts( ds_names[0], + binarize=cfg.binarize, split_type=split_type, template_path=cfg.template_path, rank=rank, @@ -322,8 +327,10 @@ def hidden_features(cfg: Extract) -> tuple[DatasetInfo, Features]: ds_features = assert_type(Features, info.features) label_col = prompter.label_column or infer_label_column(ds_features) - num_classes = len(prompter.label_choices) or infer_num_classes( - ds_features[label_col] + num_classes = ( + 2 + if cfg.binarize + else (len(prompter.label_choices) or infer_num_classes(ds_features[label_col])) ) num_variants = cfg.num_variants @@ -369,7 +376,7 @@ def extract( cfg: "Extract", *, disable_cache: bool = False, - highlight_color: str = "cyan", + highlight_color: Color = "cyan", num_gpus: int = -1, min_gpu_mem: int | None = None, split_type: Literal["train", "val", None] = None, diff --git a/elk/extraction/prompt_loading.py b/elk/extraction/prompt_loading.py index 2d17bdf4..9424bbcf 100644 --- a/elk/extraction/prompt_loading.py +++ b/elk/extraction/prompt_loading.py @@ -16,6 +16,7 @@ def load_prompts( ds_string: str, *, + binarize: bool = False, num_shots: int = 0, num_variants: int = -1, seed: int = 42, @@ -28,6 +29,7 @@ def load_prompts( Args: ds_string: Name of HF dataset to use, e.g. `"super_glue:boolq"` or `"imdb"`. + binarize: Whether to binarize the dataset labels for multi-class datasets. num_shots: The number of examples to use in few-shot prompts. If zero, prompts are zero-shot. seed: The seed to use for prompt randomization. @@ -108,6 +110,7 @@ def load_prompts( for example in ds: yield _convert_to_prompts( example, + binarize=binarize, label_column=label_column, label_choices=label_choices, # type: ignore[arg-type] num_variants=num_variants, @@ -120,6 +123,7 @@ def load_prompts( def _convert_to_prompts( example: dict[str, Any], prompter: DatasetTemplates, + binarize: bool, label_column: str, label_choices: list[bool | int | str], num_variants: int, @@ -140,6 +144,15 @@ def qa_cat(q: str, a: str) -> str: # For sanity checking that prompts are unique prompt_counter = Counter() label = example[label_column] + if binarize: + # Replace the full list of possibilities with a randomly sampled false label + # and the correct label, as done in the DLK paper. Note that this does add some + # "supervision" by stacking the deck in favor of the correct answer. 
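+        # For example, with label_choices == [0, 1, 2, 3] and label == 2, the
+        # binarized choices below become [2, d] or [d, 2] for a single randomly
+        # sampled distractor d drawn from {0, 1, 3}.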
+ label_choices = [ + rng.choice([c for c in label_choices if c != label]), + label, + ] + rng.shuffle(label_choices) for template in templates: choices = [] diff --git a/elk/files.py b/elk/files.py index e7e917e0..4435dee1 100644 --- a/elk/files.py +++ b/elk/files.py @@ -38,8 +38,3 @@ def memorably_named_dir(parent: Path): out_dir = parent / sub_dir out_dir.mkdir(parents=True, exist_ok=True) return out_dir - - -def transfer_eval_directory(source: Path) -> Path: - """Return the directory where transfer evals are stored.""" - return elk_reporter_dir() / source / "transfer_eval" diff --git a/elk/promptsource/templates.py b/elk/promptsource/templates.py index e4baf703..e32e3651 100644 --- a/elk/promptsource/templates.py +++ b/elk/promptsource/templates.py @@ -1,10 +1,10 @@ import os import random import uuid -from collections import Counter, defaultdict +from collections import Counter +from dataclasses import dataclass from pathlib import Path -from shutil import rmtree -from typing import Optional +from typing import Any, ClassVar import yaml from jinja2 import BaseLoader, Environment, meta @@ -30,10 +30,6 @@ def highlight(input): return "" + input + "" -def choice(choices): - return random.choice(choices) - - def permutation(n): return random.sample(range(n), n) @@ -57,7 +53,7 @@ def most_frequent(items): env.filters["highlight"] = highlight -env.filters["choice"] = choice +env.filters["choice"] = random.choice env.filters["most_frequent"] = most_frequent env.filters["permutation"] = permutation env.filters["reorder"] = reorder @@ -100,45 +96,13 @@ def __init__(self, name, jinja, reference, metadata=None, answer_choices=None): self.metadata = metadata if metadata is not None else Template.Metadata() self.answer_choices = answer_choices - def get_id(self): - """ - Returns the id of the template - - :return: unique id for template - """ - return self.id - - def get_name(self): - """ - Returns the name of the template - - :return: unique (per dataset) name for template - """ - return self.name - - def get_reference(self): - """ - Returns the bibliographic reference (or author) for the template - - :return: reference as a string - """ - return self.reference - - def get_answer_choices_expr(self): - """ - Returns a Jinja expression for computing the answer choices from an example. 
- - :return: String, or None if no answer choices - """ - return self.answer_choices - def get_answer_choices_list(self, example): """ Returns a list of answer choices for a given example :return: list of strings, or None if get_answer_choices_expr is None """ - jinja = self.get_answer_choices_expr() + jinja = self.answer_choices if jinja is None: return None @@ -155,7 +119,7 @@ def get_fixed_answer_choices_list(self): Returns a list of answer choices that is static across examples, if possible :return: list of strings, or None if no static list exists """ - jinja = self.get_answer_choices_expr() + jinja = self.answer_choices if jinja is None: return None @@ -190,8 +154,8 @@ def apply(self, example, truncate=True, highlight_variables=False): # Highlights text that was substituted for variables, if requested if highlight_variables: jinja = jinja.replace("}}", " | highlight }}") - rtemplate = env.from_string(jinja) + rtemplate = env.from_string(jinja) protected_example = self._escape_pipe(example) # Adds in answer_choices variable @@ -210,6 +174,29 @@ def apply(self, example, truncate=True, highlight_variables=False): for part in rendered_example.split("|||") ] + def contrast_set( + self, example: dict[str, Any], label_key: str, pseudo_labels: list + ) -> tuple[str, list[str]]: + # Record the RNG state so that any non-deterministic filters in the template + # will not affect the contrast set + rng_state = random.getstate() + + answers = [] + questions = set() + + for pseudo_label in pseudo_labels: + pseudo_example = example.copy() + pseudo_example[label_key] = pseudo_label + + random.setstate(rng_state) + q, a = self.apply(pseudo_example) + answers.append(a) + questions.add(q) + + breakpoint() + assert len(questions) == 1, "Contrast set questions must be identical" + return questions.pop(), answers + @staticmethod def _strip_spaces(string): """Same functionality as str.strip(), but ignores newlines""" @@ -256,135 +243,25 @@ def _unescape_pipe(cls, string): # replaces back any occurrences of the separator in a string return string.replace(cls.pipe_protector, "|||") + @dataclass class Metadata(yaml.YAMLObject): - """ - Metadata for a prompt template. - """ - - yaml_tag = "!TemplateMetadata" - - def __init__( - self, - original_task: Optional[bool] = None, - choices_in_prompt: Optional[bool] = None, - metrics: Optional[list[str]] = None, - languages: Optional[list[str]] = None, - ): - """ - Initializes template metadata. - - In the following, trivial choices are defined as Yes/No, True/False, - etc. and nontrivial choices are other types of choices denoted in - the answer_choices field. - - :param original_task: If True, this prompt asks a model to perform the - original task designed for this dataset. - :param choices_in_prompt: If True, the answer choices are included in the - templates such that models see those choices in the input. Only - applicable to classification tasks. - :param metrics: list of strings denoting metrics to use for evaluation - :param metrics: list of strings denoting languages used in the prompt - (not the associated dataset!) 
- """ - self.original_task = original_task - self.choices_in_prompt = choices_in_prompt - self.metrics = metrics - self.languages = languages - - -class TemplateCollection: - """ - This helper class wraps the DatasetTemplates class - - Initialized the DatasetTemplates for all existing template folder - - Give access to each DatasetTemplates - - Provides aggregated counts over all DatasetTemplates - """ - - def __init__(self): - # dict of all the DatasetTemplates, key is the tuple (dataset_name, subset_name) - self.datasets_templates = self._collect_datasets() - - @property - def keys(self): - return list(self.datasets_templates.keys()) + """Metadata for a prompt template.""" - def __len__(self) -> int: - return len(self.datasets_templates) + yaml_tag: ClassVar[str] = "!TemplateMetadata" - def remove(self, dataset_name: str, subset_name: Optional[str] = None) -> None: - del self.datasets_templates[dataset_name, subset_name] + original_task: bool | None = None + """If True, this prompt asks a model to perform the original task designed for + this dataset.""" - def _collect_datasets(self) -> dict[tuple[str, Optional[str]], "DatasetTemplates"]: - """ - Initialize a DatasetTemplates object for each templates.yaml detected in the - templates folder + choices_in_prompt: bool | None = None + """If True, the answer choices are included in the templates such that models + see those choices in the input. Only applicable to classification tasks.""" - Returns: a dict with key=(dataset_name, subset_name) - """ - dataset_folders = os.listdir(TEMPLATES_FOLDER_PATH) - dataset_folders = [ - folder for folder in dataset_folders if not folder.startswith(".") - ] + metrics: list[str] | None = None + """Strings denoting metrics to use for evaluation""" - output = {} # format is {(dataset_name, subset_name): DatasetsTemplates} - for dataset in dataset_folders: - if dataset in INCLUDED_USERS: - for filename in os.listdir( - os.path.join(TEMPLATES_FOLDER_PATH, dataset) - ): - output = { - **output, - **self._collect_dataset(dataset + "/" + filename), - } - else: - output = {**output, **self._collect_dataset(dataset)} - - return output - - def _collect_dataset(self, dataset): - output = {} # format is {(dataset_name, subset_name): DatasetsTemplates} - for filename in os.listdir(os.path.join(TEMPLATES_FOLDER_PATH, dataset)): - if filename.endswith(".yaml"): - # If there is no sub-folder, there is no subset for this dataset - output[(dataset, None)] = DatasetTemplates(dataset) - else: - # This is a subfolder, and its name corresponds to the subset name - output[(dataset, filename)] = DatasetTemplates( - dataset_name=dataset, subset_name=filename - ) - return output - - def get_dataset( - self, dataset_name: str, subset_name: Optional[str] = None - ) -> "DatasetTemplates": - """ - Return the DatasetTemplates object corresponding to the dataset name - - :param dataset_name: name of the dataset to get - :param subset_name: name of the subset - """ - # if the dataset does not exist, we add it - if dataset_name not in self.keys: - self.datasets_templates[(dataset_name, subset_name)] = DatasetTemplates( - dataset_name, subset_name - ) - - return self.datasets_templates[(dataset_name, subset_name)] - - def get_templates_count(self) -> dict: - """ - Return the overall number count over all datasets - - NB: we don't breakdown datasets into subsets for the count, i.e subsets count - are included into the dataset count - """ - - count_dict = defaultdict(int) - for k, v in self.datasets_templates.items(): - # Subsets count 
towards dataset count - count_dict[k[0]] += len(v) - # converting to regular dict - return dict(count_dict) + languages: list[str] | None = None + """Strings denoting languages used in the prompt""" class DatasetTemplates: @@ -402,6 +279,7 @@ class DatasetTemplates: label_column: str | None label_choices: list[int | str] + templates: dict[str, Template] def __init__(self, dataset_name: str, subset_name: str | None = None): self.dataset_name = dataset_name @@ -417,10 +295,6 @@ def __init__(self, dataset_name: str, subset_name: str | None = None): self.label_column = yaml_dict.get(self.LABEL_COLUMN_KEY) self.label_choices = yaml_dict.get(self.LABEL_CHOICES_KEY, []) - # Mapping from template name to template id - self.name_to_id_mapping = {} - self.sync_mapping() - def drop_non_mc_templates(self) -> int: """Drop all templates that aren't multiple choice, return the number dropped""" mc_templates = { @@ -431,18 +305,9 @@ def drop_non_mc_templates(self) -> int: num_dropped = len(self.templates) - len(mc_templates) self.templates = mc_templates - self.sync_mapping() return num_dropped - def sync_mapping(self) -> None: - """ - Re-compute the name_to_id_mapping to ensure it is in sync with self.templates - """ - self.name_to_id_mapping = { - template.name: template.id for template in self.templates.values() - } - @property def all_template_names(self) -> list[str]: """ @@ -466,110 +331,3 @@ def yaml_path(self) -> str: raise ValueError(f"Expected prompt templates to exist at {path}") return path - - def format_for_dump(self) -> dict: - """ - Create a formatted dictionary for the class attributes - """ - formatted_dict = { - self.DATASET_KEY: self.dataset_name, - self.TEMPLATES_KEY: self.templates, - } - if self.subset_name: - formatted_dict[self.SUBSET_KEY] = self.subset_name - return formatted_dict - - def write_to_file(self) -> None: - """ - Writes to a file with the current prompt collection. - """ - # Sync the mapping - self.sync_mapping() - - # We only create the folder if a template is written - if not os.path.exists(self.folder_path): - os.makedirs(self.folder_path) - yaml.dump(self.format_for_dump(), open(self.yaml_path, "w")) - - def add_template(self, template: "Template") -> None: - """ - Adds a new template for the dataset - - :param template: template - """ - self.templates[template.get_id()] = template - - self.write_to_file() - - def remove_template(self, template_name: str) -> None: - """ - Deletes a template - - :param template_name: name of template to remove - """ - - # Even if we have an ID, we want to check for duplicate names - if template_name not in self.all_template_names: - raise ValueError( - f"No template with name {template_name} for dataset " - f"{self.dataset_name} exists." 
- ) - - del self.templates[self.name_to_id_mapping[template_name]] - - if len(self.templates) == 0: - # There is no remaining template, we can remove the entire folder - self.delete_folder() - else: - # We just update the file - self.write_to_file() - - def update_template( - self, - current_template_name: str, - new_template_name: str, - jinja: str, - reference: str, - metadata: Template.Metadata, - answer_choices: str, - ) -> None: - """ - Updates a pre-existing template and writes changes - - :param current_template_name: current name of the template stored in - self.templates - :param new_template_name: new name for the template - :param jinja: new jinja entry - :param reference: new reference entry - :param metadata: a Metadata object with template annotations - :param answer_choices: new answer_choices string - """ - template_id = self.name_to_id_mapping[current_template_name] - self.templates[template_id].name = new_template_name - self.templates[template_id].jinja = jinja - self.templates[template_id].reference = reference - self.templates[template_id].metadata = metadata - self.templates[template_id].answer_choices = answer_choices - - self.write_to_file() - - def delete_folder(self) -> None: - """ - Delete the folder corresponding to self.folder_path - """ - self.sync_mapping() - - rmtree(self.folder_path) - - # If it is a subset, we have to check whether to remove the dataset folder - if self.subset_name: - # have to check for other folders - base_folder = os.path.join(TEMPLATES_FOLDER_PATH, self.dataset_name) - if len(os.listdir(base_folder)) == 0: - rmtree(base_folder) - - def __getitem__(self, template_key: str) -> "Template": - return self.templates[self.name_to_id_mapping[template_key]] - - def __len__(self) -> int: - return len(self.templates) diff --git a/elk/run.py b/elk/run.py index b9e88995..70f0f29b 100644 --- a/elk/run.py +++ b/elk/run.py @@ -66,7 +66,7 @@ def execute( if self.out_dir is None: # Save in a memorably-named directory inside of # ELK_REPORTER_DIR// - ds_name = ", ".join(self.data.datasets) + ds_name = "+".join(self.data.datasets) root = elk_reporter_dir() / self.data.model / ds_name self.out_dir = memorably_named_dir(root) diff --git a/elk/training/__init__.py b/elk/training/__init__.py index 635644c1..390d3aa3 100644 --- a/elk/training/__init__.py +++ b/elk/training/__init__.py @@ -3,6 +3,7 @@ from .eigen_reporter import EigenReporter, EigenReporterConfig from .normalizer import Normalizer from .reporter import Reporter, ReporterConfig +from .train import Elicit __all__ = [ "CcsReporter", @@ -10,6 +11,7 @@ "Classifier", "EigenReporter", "EigenReporterConfig", + "Elicit", "Normalizer", "Reporter", "ReporterConfig", diff --git a/elk/training/sweep.py b/elk/training/sweep.py index 8b49cdf4..a112cd63 100644 --- a/elk/training/sweep.py +++ b/elk/training/sweep.py @@ -1,7 +1,6 @@ -from copy import deepcopy from dataclasses import InitVar, dataclass, replace -from ..evaluation.evaluate import Eval +from ..evaluation import Eval from ..extraction import Extract from ..files import elk_reporter_dir, memorably_named_dir from ..utils import colorize @@ -12,10 +11,12 @@ class Sweep: models: list[str] """List of Huggingface model strings to sweep over.""" + datasets: list[str] """List of dataset strings to sweep over. Each dataset string can contain multiple datasets, separated by plus signs. 
For example, "sst2+imdb" will pool SST-2 and IMDB together.""" + add_pooled: InitVar[bool] = False """Whether to add a dataset that pools all of the other datasets together.""" @@ -62,20 +63,21 @@ def execute(self): } ) - for i, model_str in enumerate(self.models): - print(colorize(f"===== {model_str} ({i + 1} of {M}) =====", "magenta")) + for i, model in enumerate(self.models): + print(colorize(f"===== {model} ({i + 1} of {M}) =====", "magenta")) for dataset_str in self.datasets: - out_dir = sweep_dir / model_str / dataset_str + out_dir = sweep_dir / model / dataset_str # Allow for multiple datasets to be specified in a single string with # plus signs. This means we can pool datasets together inside of a # single sweep. train_datasets = tuple(ds.strip() for ds in dataset_str.split("+")) - run = deepcopy(self.run_template) - run.data = replace(run.data, model=model_str, datasets=train_datasets) - run.out_dir = out_dir + data = replace( + self.run_template.data, model=model, datasets=train_datasets + ) + run = replace(self.run_template, data=data, out_dir=out_dir) run.execute() if len(eval_datasets) > 1: @@ -88,11 +90,9 @@ def execute(self): continue eval = Eval( - data=replace( - run.data, model=model_str, datasets=(eval_dataset,) - ), + data=replace(run.data, model=model, datasets=(eval_dataset,)), source=run.out_dir, - out_dir=out_dir, + out_dir=out_dir / "transfer" / eval_dataset, num_gpus=run.num_gpus, min_gpu_mem=run.min_gpu_mem, ) diff --git a/elk/utils/__init__.py b/elk/utils/__init__.py index bc7b7d15..db9d1de4 100644 --- a/elk/utils/__init__.py +++ b/elk/utils/__init__.py @@ -1,5 +1,4 @@ from .data_utils import ( - binarize, get_columns_all_equal, get_layer_indices, has_multiple_configs, @@ -11,14 +10,14 @@ from .gpu_utils import select_usable_devices from .hf_utils import instantiate_model, instantiate_tokenizer, is_autoregressive from .math_util import batch_cov, cov_mean_fused, stochastic_round_constrained -from .pretty import colorize +from .pretty import Color, colorize from .tree_utils import pytree_map from .typing import assert_type, float32_to_int16, int16_to_float32 __all__ = [ "assert_type", "batch_cov", - "binarize", + "Color", "colorize", "cov_mean_fused", "float32_to_int16", diff --git a/elk/utils/data_utils.py b/elk/utils/data_utils.py index b574d12e..a468f4c3 100644 --- a/elk/utils/data_utils.py +++ b/elk/utils/data_utils.py @@ -1,6 +1,4 @@ -import copy from functools import cache -from random import Random from typing import Any, Iterable, Literal from datasets import ( @@ -11,7 +9,6 @@ get_dataset_config_names, ) -from ..promptsource.templates import Template from .typing import assert_type @@ -135,33 +132,3 @@ def get_layer_indices(ds: DatasetDict) -> list[int]: # Convert to the suffixes that are integral to ints, then sort them return sorted(int(suffix) for suffix in suffixes if suffix.isdigit()) - - -def binarize(template: Template, label: int, new_label: int, rng: Random) -> Template: - """Binarize a template with >2 answer choices, returning a new template and label. - Returns: - `new_template`: - A deepcopy of the original template with with 2 answer choices, one of - which is the true answer and the other is a random false answer. - `new_label`: - the index of the true answer into `new_template.answer_choices` - """ - - # TODO: it would be nice in the future to binarize exhaustively so we're not - # cheating here (since this step requires a label). e.g. 
this function would - # also take a candidate answer and the template would ask whether the candidate - # answer is true or false. This would require rewriting the jinja templates though. - answer_choices = assert_type(str, template.answer_choices).split(" ||| ") - assert len(answer_choices) > 2 - - true = answer_choices[label] - false = rng.choice([c for c in answer_choices if c != true]) - - assert new_label in (0, 1) - - new_template = copy.deepcopy(template) - new_template.answer_choices = ( - f"{false} ||| {true}" if new_label else f"{true} ||| {false}" - ) - - return new_template diff --git a/tests/test_smoke_eval.py b/tests/test_smoke_eval.py index cf1827f3..8669b7ea 100644 --- a/tests/test_smoke_eval.py +++ b/tests/test_smoke_eval.py @@ -5,7 +5,6 @@ from elk import Extract from elk.evaluation import Eval from elk.extraction.dataset_name import extract_dataset_name_and_config -from elk.files import transfer_eval_directory from elk.training import CcsReporterConfig, EigenReporterConfig from elk.training.train import Elicit @@ -75,7 +74,7 @@ def eval_assert_files_created(elicit: Elicit, transfer_datasets: tuple[str, ...] tmp_path = elicit.out_dir assert tmp_path is not None - eval_dir = transfer_eval_directory(source=tmp_path) + eval_dir = tmp_path / "transfer" / "+".join(transfer_datasets) assert eval_dir.exists(), f"transfer eval dir {eval_dir} does not exist" check_contains_files(eval_dir, EVAL_EXPECTED_FILES) # read "eval.csv" into a df From c9e62ea22138b2e30f0cf31f249258fbcca830b5 Mon Sep 17 00:00:00 2001 From: Nora Belrose Date: Thu, 27 Apr 2023 18:32:17 +0000 Subject: [PATCH 14/17] Partial support for truthful_qa --- elk/evaluation/evaluate.py | 13 +- elk/extraction/extraction.py | 2 +- elk/extraction/prompt_loading.py | 10 +- elk/promptsource/templates.py | 18 +-- .../norabelrose/truthful_qa/templates.yaml | 118 ++++++++++++++++++ elk/run.py | 4 +- elk/utils/data_utils.py | 5 +- elk/utils/typing.py | 8 +- 8 files changed, 144 insertions(+), 34 deletions(-) create mode 100644 elk/promptsource/templates/norabelrose/truthful_qa/templates.yaml diff --git a/elk/evaluation/evaluate.py b/elk/evaluation/evaluate.py index cbed7cab..75d4a6bf 100644 --- a/elk/evaluation/evaluate.py +++ b/elk/evaluation/evaluate.py @@ -10,25 +10,21 @@ from ..metrics import evaluate_preds from ..run import Run from ..training import Reporter +from ..utils import Color -@dataclass +@dataclass(kw_only=True) class Eval(Run): """Full specification of a reporter evaluation run.""" - # Using None as a default here is a hack; we actually raise an error if it's not - # specified in __post_init__. TODO: Maybe this is an indication we should be using - # composition and not inheritance here? - source: Path | None = field(default=None, positional=True) + source: Path = field(positional=True) skip_supervised: bool = False def __post_init__(self): - assert self.source, "Must specify a source experiment." - if not self.out_dir: self.out_dir = self.source / "transfer" / "+".join(self.data.datasets) - def execute(self, highlight_color: str = "cyan"): + def execute(self, highlight_color: Color = "cyan"): return super().execute(highlight_color, split_type="val") @torch.inference_mode() @@ -39,7 +35,6 @@ def apply_to_layer( device = self.get_device(devices, world_size) val_output = self.prepare_data(device, layer, "val") - assert self.source, "Must specify a source experiment." 
experiment_dir = elk_reporter_dir() / self.source reporter_path = experiment_dir / "reporters" / f"layer_{layer}.pt" diff --git a/elk/extraction/extraction.py b/elk/extraction/extraction.py index 3ecfed30..1d6554e8 100644 --- a/elk/extraction/extraction.py +++ b/elk/extraction/extraction.py @@ -253,7 +253,7 @@ def extract_hiddens( input_ids = input_ids[..., -min(cur_len, max_len) :] # Make sure we only pass the arguments that the model expects - inputs = dict(input_ids=input_ids) + inputs = dict(input_ids=input_ids.long()) if is_enc_dec: inputs["labels"] = answer diff --git a/elk/extraction/prompt_loading.py b/elk/extraction/prompt_loading.py index 9424bbcf..9cf99d63 100644 --- a/elk/extraction/prompt_loading.py +++ b/elk/extraction/prompt_loading.py @@ -104,13 +104,11 @@ def load_prompts( print("No label column found, not balancing") ds = ds.to_iterable_dataset() - if rank == 0: - print(f"Label choices: {label_choices}") - for example in ds: yield _convert_to_prompts( example, binarize=binarize, + choices_column=prompter.choices_column, label_column=label_column, label_choices=label_choices, # type: ignore[arg-type] num_variants=num_variants, @@ -124,6 +122,7 @@ def _convert_to_prompts( example: dict[str, Any], prompter: DatasetTemplates, binarize: bool, + choices_column: str | None, label_column: str, label_choices: list[bool | int | str], num_variants: int, @@ -144,6 +143,11 @@ def qa_cat(q: str, a: str) -> str: # For sanity checking that prompts are unique prompt_counter = Counter() label = example[label_column] + if choices_column: + label_choices = example[choices_column] + if isinstance(label, int): + label_choices = list(range(len(label_choices))) + if binarize: # Replace the full list of possibilities with a randomly sampled false label # and the correct label, as done in the DLK paper. 
Note that this does add some diff --git a/elk/promptsource/templates.py b/elk/promptsource/templates.py index e32e3651..6ccf9ea8 100644 --- a/elk/promptsource/templates.py +++ b/elk/promptsource/templates.py @@ -270,13 +270,7 @@ class DatasetTemplates: helper functions necessary to read/write to the yaml file """ - TEMPLATES_KEY = "templates" - DATASET_KEY = "dataset" - SUBSET_KEY = "subset" - LABEL_COLUMN_KEY = "label_column" - LABEL_CHOICES_KEY = "label_choices" - TEMPLATE_FILENAME = "templates.yaml" - + choices_column: str | None label_column: str | None label_choices: list[int | str] templates: dict[str, Template] @@ -289,11 +283,11 @@ def __init__(self, dataset_name: str, subset_name: str | None = None): yaml_dict = yaml.load(f, Loader=yaml.FullLoader) # Required field; contains all the templates keyed by ID - self.templates = yaml_dict[self.TEMPLATES_KEY] + self.templates = yaml_dict["templates"] - # Optional fields; may be None - self.label_column = yaml_dict.get(self.LABEL_COLUMN_KEY) - self.label_choices = yaml_dict.get(self.LABEL_CHOICES_KEY, []) + self.choices_column = yaml_dict.get("choices_column") + self.label_column = yaml_dict.get("label_column") + self.label_choices = yaml_dict.get("label_choices", []) def drop_non_mc_templates(self) -> int: """Drop all templates that aren't multiple choice, return the number dropped""" @@ -326,7 +320,7 @@ def folder_path(self) -> str: @property def yaml_path(self) -> str: - path = os.path.join(self.folder_path, self.TEMPLATE_FILENAME) + path = os.path.join(self.folder_path, "templates.yaml") if not os.path.exists(path): raise ValueError(f"Expected prompt templates to exist at {path}") diff --git a/elk/promptsource/templates/norabelrose/truthful_qa/templates.yaml b/elk/promptsource/templates/norabelrose/truthful_qa/templates.yaml new file mode 100644 index 00000000..68dbab0d --- /dev/null +++ b/elk/promptsource/templates/norabelrose/truthful_qa/templates.yaml @@ -0,0 +1,118 @@ +# Adapted from the ARC-Easy templates +dataset: norabelrose/truthful_qa_mc +choices_column: choices +templates: + a1: !Template + answer_choices: 'A|||B|||C|||D|||E|||F|||G|||H|||I|||J|||K|||L|||M|||N' + id: a1 + jinja: 'Pick the most correct option to answer the following question. + + + {{question}} + + + Options: + + {% for letter, t in zip(answer_choices, choices) %} + + - {{letter}}: {{t}} + + {% endfor %} |||{{answer_choices[label]}}' + metadata: !TemplateMetadata + choices_in_prompt: true + languages: + - en + metrics: + - Accuracy + original_task: true + name: pick_the_most_correct_option + reference: '' + a2: !Template + answer_choices: 'A|||B|||C|||D|||E|||F|||G|||H|||I|||J|||K|||L|||M|||N' + id: a2 + jinja: '{{question}} + + + Options: + + {% for letter, t in zip(answer_choices, choices) %} + + - {{letter}}: {{t}} + + {% endfor %} + + A:|||{{answer_choices[label]}}' + metadata: !TemplateMetadata + choices_in_prompt: true + languages: + - en + metrics: + - Accuracy + original_task: true + name: qa_options + reference: '' + a3: !Template + answer_choices: 'A|||B|||C|||D|||E|||F|||G|||H|||I|||J|||K|||L|||M|||N' + id: a3 + jinja: 'I am hesitating between 4 options to answer the following question, which + option should I choose? 
+ + Question: {{question}} + + Possibilities: + + {% for letter, t in zip(answer_choices, choices) %} + + - {{letter}}: {{t}} + + {% endfor %} + + A:|||{{answer_choices[label]}}' + metadata: !TemplateMetadata + choices_in_prompt: true + languages: + - en + metrics: + - Accuracy + original_task: true + name: i_am_hesitating + reference: '' + a4: !Template + answer_choices: 'A|||B|||C|||D|||E|||F|||G|||H|||I|||J|||K|||L|||M|||N' + id: a4 + jinja: 'I gave my students this multiple choice question: {{question}} + + + Only one answer is correct among these 4 choices: + + {% for letter, t in zip(answer_choices, choices) %} + + - {{letter}}: {{t}} + + {% endfor %} + + A:|||{{answer_choices[label]}}' + metadata: !TemplateMetadata + choices_in_prompt: true + languages: + - en + metrics: + - Accuracy + original_task: true + name: multiple_choice + reference: '' + a6: !Template + answer_choices: 'A|||B|||C|||D|||E|||F|||G|||H|||I|||J|||K|||L|||M|||N' + id: a6 + jinja: "Here's a problem to solve: {{question}}\n\nAmong the 4 following options,\ + \ which is the correct answer?\n{% for letter, t in zip(answer_choices, choices)\ + \ %}\n- {{letter}}: {{t}}\n {% endfor %}A:|||{{answer_choices[label]}}" + metadata: !TemplateMetadata + choices_in_prompt: true + languages: + - en + metrics: + - Accuracy + original_task: true + name: heres_a_problem + reference: '' diff --git a/elk/run.py b/elk/run.py index 70f0f29b..6b24b64d 100644 --- a/elk/run.py +++ b/elk/run.py @@ -21,6 +21,7 @@ from .extraction.dataset_name import DatasetDictWithName from .files import elk_reporter_dir, memorably_named_dir from .utils import ( + Color, assert_type, get_layer_indices, int16_to_float32, @@ -48,7 +49,7 @@ class Run(ABC, Serializable): def execute( self, - highlight_color: str = "cyan", + highlight_color: Color = "cyan", split_type: Literal["train", "val", None] = None, ): self.datasets = [ @@ -127,6 +128,7 @@ def prepare_data( split = ds[key].with_format("torch", device=device, dtype=torch.int16) labels = assert_type(Tensor, split["label"]) + breakpoint() val_h = int16_to_float32(assert_type(Tensor, split[f"hidden_{layer}"])) with split.formatted_as("torch", device=device): diff --git a/elk/utils/data_utils.py b/elk/utils/data_utils.py index a468f4c3..58180ecc 100644 --- a/elk/utils/data_utils.py +++ b/elk/utils/data_utils.py @@ -118,10 +118,7 @@ def infer_num_classes(label_feature: Any) -> int: elif isinstance(label_feature, Value) and label_feature.dtype == "bool": return 2 else: - raise ValueError( - f"Can't infer number of classes from label column of type {label_feature}. " - f"Please update the num_classes field in the prompt template yaml file." 
- ) + return -1 def get_layer_indices(ds: DatasetDict) -> list[int]: diff --git a/elk/utils/typing.py b/elk/utils/typing.py index f0b10d52..b3880a3c 100644 --- a/elk/utils/typing.py +++ b/elk/utils/typing.py @@ -14,8 +14,8 @@ def assert_type(typ: Type[T], obj: Any) -> T: def float32_to_int16(x: torch.Tensor) -> torch.Tensor: - """Converts float32 to float16, then reinterprets as int16.""" - downcast = x.type(torch.float16) + """Converts float32 to bfloat16, then reinterprets as int16.""" + downcast = x.type(torch.bfloat16) if not downcast.isfinite().all(): raise ValueError("Cannot convert to 16 bit: values are not finite") @@ -23,5 +23,5 @@ def float32_to_int16(x: torch.Tensor) -> torch.Tensor: def int16_to_float32(x: torch.Tensor) -> torch.Tensor: - """Converts int16 to float16, then reinterprets as float32.""" - return x.view(torch.float16).type(torch.float32) + """Converts int16 to bfloat16, then reinterprets as float32.""" + return x.view(torch.bfloat16).type(torch.float32) From c648ff0e2e4de8e043e6a3b6e90acf0112683260 Mon Sep 17 00:00:00 2001 From: Nora Belrose Date: Sat, 29 Apr 2023 23:18:40 +0000 Subject: [PATCH 15/17] Remove crap --- elk/promptsource/templates.py | 25 +------------------------ elk/run.py | 1 - 2 files changed, 1 insertion(+), 25 deletions(-) diff --git a/elk/promptsource/templates.py b/elk/promptsource/templates.py index 6ccf9ea8..f25ec81e 100644 --- a/elk/promptsource/templates.py +++ b/elk/promptsource/templates.py @@ -4,7 +4,7 @@ from collections import Counter from dataclasses import dataclass from pathlib import Path -from typing import Any, ClassVar +from typing import ClassVar import yaml from jinja2 import BaseLoader, Environment, meta @@ -174,29 +174,6 @@ def apply(self, example, truncate=True, highlight_variables=False): for part in rendered_example.split("|||") ] - def contrast_set( - self, example: dict[str, Any], label_key: str, pseudo_labels: list - ) -> tuple[str, list[str]]: - # Record the RNG state so that any non-deterministic filters in the template - # will not affect the contrast set - rng_state = random.getstate() - - answers = [] - questions = set() - - for pseudo_label in pseudo_labels: - pseudo_example = example.copy() - pseudo_example[label_key] = pseudo_label - - random.setstate(rng_state) - q, a = self.apply(pseudo_example) - answers.append(a) - questions.add(q) - - breakpoint() - assert len(questions) == 1, "Contrast set questions must be identical" - return questions.pop(), answers - @staticmethod def _strip_spaces(string): """Same functionality as str.strip(), but ignores newlines""" diff --git a/elk/run.py b/elk/run.py index 6b24b64d..d9f1d9af 100644 --- a/elk/run.py +++ b/elk/run.py @@ -128,7 +128,6 @@ def prepare_data( split = ds[key].with_format("torch", device=device, dtype=torch.int16) labels = assert_type(Tensor, split["label"]) - breakpoint() val_h = int16_to_float32(assert_type(Tensor, split[f"hidden_{layer}"])) with split.formatted_as("torch", device=device): From ef12130b5ab1e0e859626c7ba382256656e7571d Mon Sep 17 00:00:00 2001 From: Nora Belrose Date: Sat, 29 Apr 2023 23:21:26 +0000 Subject: [PATCH 16/17] EleutherAI/truthful_qa_mc --- elk/extraction/prompt_loading.py | 6 --- elk/metrics/eval.py | 4 +- elk/promptsource/templates.py | 2 - .../truthful_qa_mc}/templates.yaml | 45 +++++++++++-------- elk/run.py | 2 +- elk/training/sweep.py | 1 + elk/utils/data_utils.py | 4 +- 7 files changed, 32 insertions(+), 32 deletions(-) rename elk/promptsource/templates/{norabelrose/truthful_qa => EleutherAI/truthful_qa_mc}/templates.yaml 
(69%) diff --git a/elk/extraction/prompt_loading.py b/elk/extraction/prompt_loading.py index 9cf99d63..18a092c5 100644 --- a/elk/extraction/prompt_loading.py +++ b/elk/extraction/prompt_loading.py @@ -108,7 +108,6 @@ def load_prompts( yield _convert_to_prompts( example, binarize=binarize, - choices_column=prompter.choices_column, label_column=label_column, label_choices=label_choices, # type: ignore[arg-type] num_variants=num_variants, @@ -122,7 +121,6 @@ def _convert_to_prompts( example: dict[str, Any], prompter: DatasetTemplates, binarize: bool, - choices_column: str | None, label_column: str, label_choices: list[bool | int | str], num_variants: int, @@ -143,10 +141,6 @@ def qa_cat(q: str, a: str) -> str: # For sanity checking that prompts are unique prompt_counter = Counter() label = example[label_column] - if choices_column: - label_choices = example[choices_column] - if isinstance(label, int): - label_choices = list(range(len(label_choices))) if binarize: # Replace the full list of possibilities with a randomly sampled false label diff --git a/elk/metrics/eval.py b/elk/metrics/eval.py index 653beae5..c7f7869f 100644 --- a/elk/metrics/eval.py +++ b/elk/metrics/eval.py @@ -72,7 +72,9 @@ def evaluate_preds( if c == 2: auroc = roc_auc_ci(y_true, y_logits[..., 1] - y_logits[..., 0]) else: - auroc = roc_auc_ci(to_one_hot(y_true, c).long(), y_logits) + auroc = roc_auc_ci( + to_one_hot(y_true, c).long().flatten(1), y_logits.flatten(1) + ) else: raise ValueError(f"Unknown mode: {ensembling}") diff --git a/elk/promptsource/templates.py b/elk/promptsource/templates.py index f25ec81e..dca68877 100644 --- a/elk/promptsource/templates.py +++ b/elk/promptsource/templates.py @@ -247,7 +247,6 @@ class DatasetTemplates: helper functions necessary to read/write to the yaml file """ - choices_column: str | None label_column: str | None label_choices: list[int | str] templates: dict[str, Template] @@ -262,7 +261,6 @@ def __init__(self, dataset_name: str, subset_name: str | None = None): # Required field; contains all the templates keyed by ID self.templates = yaml_dict["templates"] - self.choices_column = yaml_dict.get("choices_column") self.label_column = yaml_dict.get("label_column") self.label_choices = yaml_dict.get("label_choices", []) diff --git a/elk/promptsource/templates/norabelrose/truthful_qa/templates.yaml b/elk/promptsource/templates/EleutherAI/truthful_qa_mc/templates.yaml similarity index 69% rename from elk/promptsource/templates/norabelrose/truthful_qa/templates.yaml rename to elk/promptsource/templates/EleutherAI/truthful_qa_mc/templates.yaml index 68dbab0d..86a05118 100644 --- a/elk/promptsource/templates/norabelrose/truthful_qa/templates.yaml +++ b/elk/promptsource/templates/EleutherAI/truthful_qa_mc/templates.yaml @@ -1,9 +1,8 @@ # Adapted from the ARC-Easy templates -dataset: norabelrose/truthful_qa_mc -choices_column: choices +dataset: EleutherAI/truthful_qa_mc templates: a1: !Template - answer_choices: 'A|||B|||C|||D|||E|||F|||G|||H|||I|||J|||K|||L|||M|||N' + answer_choices: 'A|||B|||C|||D' id: a1 jinja: 'Pick the most correct option to answer the following question. 
@@ -14,10 +13,11 @@ templates: Options: {% for letter, t in zip(answer_choices, choices) %} - - {{letter}}: {{t}} - {% endfor %} |||{{answer_choices[label]}}' + {% endfor %} + + |||{{answer_choices[label]}}' metadata: !TemplateMetadata choices_in_prompt: true languages: @@ -28,7 +28,7 @@ templates: name: pick_the_most_correct_option reference: '' a2: !Template - answer_choices: 'A|||B|||C|||D|||E|||F|||G|||H|||I|||J|||K|||L|||M|||N' + answer_choices: 'A|||B|||C|||D' id: a2 jinja: '{{question}} @@ -36,12 +36,11 @@ templates: Options: {% for letter, t in zip(answer_choices, choices) %} - - {{letter}}: {{t}} {% endfor %} - A:|||{{answer_choices[label]}}' + Answer:|||{{answer_choices[label]}}' metadata: !TemplateMetadata choices_in_prompt: true languages: @@ -52,22 +51,23 @@ templates: name: qa_options reference: '' a3: !Template - answer_choices: 'A|||B|||C|||D|||E|||F|||G|||H|||I|||J|||K|||L|||M|||N' + answer_choices: 'A|||B|||C|||D' id: a3 jinja: 'I am hesitating between 4 options to answer the following question, which option should I choose? + Question: {{question}} + Possibilities: {% for letter, t in zip(answer_choices, choices) %} - - {{letter}}: {{t}} {% endfor %} - A:|||{{answer_choices[label]}}' + Answer:|||{{answer_choices[label]}}' metadata: !TemplateMetadata choices_in_prompt: true languages: @@ -78,7 +78,7 @@ templates: name: i_am_hesitating reference: '' a4: !Template - answer_choices: 'A|||B|||C|||D|||E|||F|||G|||H|||I|||J|||K|||L|||M|||N' + answer_choices: 'A|||B|||C|||D' id: a4 jinja: 'I gave my students this multiple choice question: {{question}} @@ -86,12 +86,11 @@ templates: Only one answer is correct among these 4 choices: {% for letter, t in zip(answer_choices, choices) %} - - {{letter}}: {{t}} {% endfor %} - A:|||{{answer_choices[label]}}' + Answer:|||{{answer_choices[label]}}' metadata: !TemplateMetadata choices_in_prompt: true languages: @@ -101,12 +100,20 @@ templates: original_task: true name: multiple_choice reference: '' - a6: !Template - answer_choices: 'A|||B|||C|||D|||E|||F|||G|||H|||I|||J|||K|||L|||M|||N' + a5: !Template + answer_choices: 'A|||B|||C|||D' id: a6 - jinja: "Here's a problem to solve: {{question}}\n\nAmong the 4 following options,\ - \ which is the correct answer?\n{% for letter, t in zip(answer_choices, choices)\ - \ %}\n- {{letter}}: {{t}}\n {% endfor %}A:|||{{answer_choices[label]}}" + jinja: "Here's a problem to solve: {{question}} + + + Among the 4 following options, which is the correct answer? 
+ + {% for letter, t in zip(answer_choices, choices) %} + {{letter}}: {{t}} + + {% endfor %} + + Answer:|||{{answer_choices[label]}}" metadata: !TemplateMetadata choices_in_prompt: true languages: diff --git a/elk/run.py b/elk/run.py index d9f1d9af..85e31993 100644 --- a/elk/run.py +++ b/elk/run.py @@ -125,8 +125,8 @@ def prepare_data( for ds_name, ds in self.datasets: key = select_split(ds, split_type) - split = ds[key].with_format("torch", device=device, dtype=torch.int16) + labels = assert_type(Tensor, split["label"]) val_h = int16_to_float32(assert_type(Tensor, split[f"hidden_{layer}"])) diff --git a/elk/training/sweep.py b/elk/training/sweep.py index a112cd63..49fa99e8 100644 --- a/elk/training/sweep.py +++ b/elk/training/sweep.py @@ -89,6 +89,7 @@ def execute(self): if eval_dataset in train_datasets: continue + assert run.out_dir is not None eval = Eval( data=replace(run.data, model=model, datasets=(eval_dataset,)), source=run.out_dir, diff --git a/elk/utils/data_utils.py b/elk/utils/data_utils.py index 9032c44e..21e9deb0 100644 --- a/elk/utils/data_utils.py +++ b/elk/utils/data_utils.py @@ -123,9 +123,7 @@ def infer_num_classes(label_feature: Any) -> int: """Return the number of classes in a `Dataset`. Returns: - The number of classes. - Raises: - ValueError: If the label column is not a `ClassLabel` or `Value('bool')`. + The number of classes, or -1 if it's unclear. """ if isinstance(label_feature, ClassLabel): # We piggyback on the ClassLabel feature type to get the number of classes From 5d60ebd064b627d37b0ab5a8f7b5f2cfeab00385 Mon Sep 17 00:00:00 2001 From: Nora Belrose Date: Sun, 30 Apr 2023 00:47:25 +0000 Subject: [PATCH 17/17] Update templates --- .../fever/v1.0/templates.yaml | 44 +++--- .../fever/v2.0/templates.yaml | 0 .../truthful_qa_binary/templates.yaml | 125 ++++++++++++++++ .../templates/binarization/templates.yaml | 133 ++++++++++++++++++ 4 files changed, 278 insertions(+), 24 deletions(-) rename elk/promptsource/templates/{ => EleutherAI}/fever/v1.0/templates.yaml (59%) rename elk/promptsource/templates/{ => EleutherAI}/fever/v2.0/templates.yaml (100%) create mode 100644 elk/promptsource/templates/EleutherAI/truthful_qa_binary/templates.yaml create mode 100644 elk/promptsource/templates/binarization/templates.yaml diff --git a/elk/promptsource/templates/fever/v1.0/templates.yaml b/elk/promptsource/templates/EleutherAI/fever/v1.0/templates.yaml similarity index 59% rename from elk/promptsource/templates/fever/v1.0/templates.yaml rename to elk/promptsource/templates/EleutherAI/fever/v1.0/templates.yaml index b66be449..e9e275dc 100644 --- a/elk/promptsource/templates/fever/v1.0/templates.yaml +++ b/elk/promptsource/templates/EleutherAI/fever/v1.0/templates.yaml @@ -1,16 +1,12 @@ dataset: fever subset: v1.0 -label_column: label -label_choices: - - REFUTES - - SUPPORTS templates: 0870481e-e5d1-43a1-821e-b11c6bfd2483: !Template - answer_choices: Yes|||No|||Not sure + answer_choices: No|||Yes id: 0870481e-e5d1-43a1-821e-b11c6bfd2483 - jinja: "{{claim}} Is this true?\n|||\n{% if label != \"\" %}\n{{\n{\"SUPPORTS\"\ - : \"Yes\",\n \"REFUTES\": \"No\",\n\"NOT ENOUGH INFO\": \"Not sure\"\n}[label]\n\ - }}\n{% endif %}" + jinja: "{{claim}} Is this true? 
+ + |||{{answer_choices[label]}}" metadata: !TemplateMetadata choices_in_prompt: false languages: @@ -21,11 +17,11 @@ templates: name: cbqa_fever_postprompt reference: CBQA fever, prompt after claim 51c55af8-1996-4cb2-88a1-ca7ddb8f9e11: !Template - answer_choices: Yes|||No|||Not Sure + answer_choices: No|||Yes id: 51c55af8-1996-4cb2-88a1-ca7ddb8f9e11 - jinja: "I've heard that {{claim}} Is this correct? Yes, No or Not Sure?\n|||\n\ - {% if label != \"\" %}\n{{\n{\"SUPPORTS\": \"Yes\",\n \"REFUTES\": \"No\",\n\ - \"NOT ENOUGH INFO\": \"Not Sure\"\n}[label]\n}}\n{% endif %}" + jinja: "I've heard that {{claim}} Is this correct? Yes, No or Not Sure? + + |||{{answer_choices[label]}}" metadata: !TemplateMetadata choices_in_prompt: true languages: @@ -37,11 +33,11 @@ templates: reference: CBQA fever, like a conversation, with prompts surrounding claim, all class included. 6cc8f145-3fb4-43a9-aaf1-8c25dd6e2cdf: !Template - answer_choices: Yes|||No|||Unsure + answer_choices: No|||Yes id: 6cc8f145-3fb4-43a9-aaf1-8c25dd6e2cdf - jinja: "Is this statement correct? {{claim}} ||| \n{% if label != \"\" %}\n{{\n\ - {\"SUPPORTS\": \"Yes\",\n \"REFUTES\": \"No\",\n\"NOT ENOUGH INFO\": \"Unsure\"\ - \n}[label]\n}}\n{% endif %}" + jinja: "Is this statement correct? {{claim}} + + |||{{answer_choices[label]}}" metadata: !TemplateMetadata choices_in_prompt: false languages: @@ -52,11 +48,11 @@ templates: name: cbqa_fever_preprompt reference: Closed-book QA from only the claim, prompt before the content 948f41ab-e6bb-4de6-af3e-7f0b5d5f39a8: !Template - answer_choices: Yes|||No|||Maybe + answer_choices: No|||Yes id: 948f41ab-e6bb-4de6-af3e-7f0b5d5f39a8 - jinja: "\"{{claim}}\" Yes, no, maybe?\n|||\n{% if label != \"\" %}\n{{\n{\"SUPPORTS\"\ - : \"Yes\",\n \"REFUTES\": \"No\",\n\"NOT ENOUGH INFO\": \"Maybe\"\n}[label]\n\ - }}\n{% endif %}\n" + jinja: "\"{{claim}}\" Yes or no? + + |||{{answer_choices[label]}}" metadata: !TemplateMetadata choices_in_prompt: false languages: @@ -67,11 +63,11 @@ templates: name: cbqa_fever_short reference: CBQA fever, minimal b1d8f035-c3af-41a8-b0b8-1604f9dc00ff: !Template - answer_choices: Yes|||No|||Not Sure + answer_choices: No|||Yes id: b1d8f035-c3af-41a8-b0b8-1604f9dc00ff - jinja: "\"{{claim}}\", I have heard. Is this Correct? Yes, No or Not Sure?\n|||\n\ - {% if label != \"\" %}\n{{\n{\"SUPPORTS\": \"Yes\",\n \"REFUTES\": \"No\",\n\ - \"NOT ENOUGH INFO\": \"Not Sure\"\n}[label]\n}}\n{% endif %}" + jinja: "\"{{claim}}\", I have heard. Is this Correct? Yes or No? + + |||{{answer_choices[label]}}" metadata: !TemplateMetadata choices_in_prompt: true languages: diff --git a/elk/promptsource/templates/fever/v2.0/templates.yaml b/elk/promptsource/templates/EleutherAI/fever/v2.0/templates.yaml similarity index 100% rename from elk/promptsource/templates/fever/v2.0/templates.yaml rename to elk/promptsource/templates/EleutherAI/fever/v2.0/templates.yaml diff --git a/elk/promptsource/templates/EleutherAI/truthful_qa_binary/templates.yaml b/elk/promptsource/templates/EleutherAI/truthful_qa_binary/templates.yaml new file mode 100644 index 00000000..19e8c300 --- /dev/null +++ b/elk/promptsource/templates/EleutherAI/truthful_qa_binary/templates.yaml @@ -0,0 +1,125 @@ +# Adapted from the ARC-Easy templates +dataset: EleutherAI/truthful_qa_binary +templates: + a1: !Template + answer_choices: 'A|||B' + id: a1 + jinja: 'Pick the most correct option to answer the following question. 
+ + + {{question}} + + + Options: + + {% for letter, t in zip(answer_choices, choices) %} + - {{letter}}: {{t}} + + {% endfor %} + + |||{{answer_choices[label]}}' + metadata: !TemplateMetadata + choices_in_prompt: true + languages: + - en + metrics: + - Accuracy + original_task: true + name: pick_the_most_correct_option + reference: '' + a2: !Template + answer_choices: 'A|||B' + id: a2 + jinja: '{{question}} + + + Options: + + {% for letter, t in zip(answer_choices, choices) %} + - {{letter}}: {{t}} + + {% endfor %} + + Answer:|||{{answer_choices[label]}}' + metadata: !TemplateMetadata + choices_in_prompt: true + languages: + - en + metrics: + - Accuracy + original_task: true + name: qa_options + reference: '' + a3: !Template + answer_choices: 'A|||B' + id: a3 + jinja: 'I am hesitating between 2 options to answer the following question, which + option should I choose? + + + Question: {{question}} + + + Possibilities: + + {% for letter, t in zip(answer_choices, choices) %} + - {{letter}}: {{t}} + + {% endfor %} + + Answer:|||{{answer_choices[label]}}' + metadata: !TemplateMetadata + choices_in_prompt: true + languages: + - en + metrics: + - Accuracy + original_task: true + name: i_am_hesitating + reference: '' + a4: !Template + answer_choices: 'A|||B' + id: a4 + jinja: 'I gave my students this multiple choice question: {{question}} + + + Only one answer is correct among these 2 choices: + + {% for letter, t in zip(answer_choices, choices) %} + - {{letter}}: {{t}} + + {% endfor %} + + Answer:|||{{answer_choices[label]}}' + metadata: !TemplateMetadata + choices_in_prompt: true + languages: + - en + metrics: + - Accuracy + original_task: true + name: multiple_choice + reference: '' + a5: !Template + answer_choices: 'A|||B' + id: a6 + jinja: "Here's a problem to solve: {{question}} + + + Among the 2 following options, which is the correct answer? + + {% for letter, t in zip(answer_choices, choices) %} + {{letter}}: {{t}} + + {% endfor %} + + Answer:|||{{answer_choices[label]}}" + metadata: !TemplateMetadata + choices_in_prompt: true + languages: + - en + metrics: + - Accuracy + original_task: true + name: heres_a_problem + reference: '' diff --git a/elk/promptsource/templates/binarization/templates.yaml b/elk/promptsource/templates/binarization/templates.yaml new file mode 100644 index 00000000..608ff854 --- /dev/null +++ b/elk/promptsource/templates/binarization/templates.yaml @@ -0,0 +1,133 @@ +# Meta-template for binarizing a multiple choice dataset +dataset: binarization +templates: + a1: !Template + answer_choices: 'yes|||no' + id: a1 + jinja: "You will see a question followed by a response which may be incorrect. + You will be asked to figure out whether the answer is correct or not. + + + \"Q: {{question}} + + {% for letter, t in zip(['A', 'B', 'C', 'D'], choices) %} + - {{letter}}: {{t}} + + {% endfor %} + Response: {{['A', 'B', 'C', 'D'][response]}}, {{choices[response]}}\" + + + Is the response correct, yes or no?|||{{answer_choices[label]}}" + metadata: !TemplateMetadata + choices_in_prompt: true + languages: + - en + metrics: + - Accuracy + original_task: true + name: yes_no + reference: '' + a2: !Template + answer_choices: 'correct|||incorrect' + id: a2 + jinja: "The following quote contains a question followed by a response which may be + incorrect. Your task is to determine whether the answer is correct or incorrect. 
+ + + \"Q: {{question}} + + {% for letter, t in zip(['A', 'B', 'C', 'D'], choices) %} + - {{letter}}: {{t}} + + {% endfor %} + Response: {{['A', 'B', 'C', 'D'][response]}}, {{choices[response]}}\" + + + Was the response correct or incorrect?|||{{answer_choices[label]}}" + metadata: !TemplateMetadata + choices_in_prompt: true + languages: + - en + metrics: + - Accuracy + original_task: true + name: correct_incorrect + reference: '' + a3: !Template + answer_choices: 'yes|||no' + id: a3 + jinja: "The following quote is an exam question followed by a student's response. + You will determine whether student chose the right answer. + + + \"Q: {{question}} + + {% for letter, t in zip(['A', 'B', 'C', 'D'], choices) %} + - {{letter}}: {{t}} + + {% endfor %} + Student's response: {{['A', 'B', 'C', 'D'][response]}}, {{choices[response]}}\" + + + Did the student choose the right answer, yes or no?|||{{answer_choices[label]}}" + metadata: !TemplateMetadata + choices_in_prompt: true + languages: + - en + metrics: + - Accuracy + original_task: true + name: student_yes_no + reference: '' + a4: !Template + answer_choices: 'correct|||incorrect' + id: a4 + jinja: "The paragraph below is a quiz question followed by a student's response. + You will determine if the response was right or wrong. + + + \"Q: {{question}} + + {% for letter, t in zip(['A', 'B', 'C', 'D'], choices) %} + - {{letter}}: {{t}} + + {% endfor %} + Student's response: {{['A', 'B', 'C', 'D'][response]}}, {{choices[response]}}\" + + + Was the response right or wrong?|||{{answer_choices[label]}}" + metadata: !TemplateMetadata + choices_in_prompt: true + languages: + - en + metrics: + - Accuracy + original_task: true + name: student_correct_incorrect + reference: '' + a5: !Template + answer_choices: 'true|||false' + id: a5 + jinja: "The following is an exam question followed by a student's response. + You will guess if they received full credit for the answer. + + + \"Q: {{question}} + + {% for letter, t in zip(['A', 'B', 'C', 'D'], choices) %} + - {{letter}}: {{t}} + + {% endfor %} + Student's response: {{['A', 'B', 'C', 'D'][response]}}, {{choices[response]}}\" + + + Did the student get full credit for this answer (yes or no)?|||{{answer_choices[label]}}" + metadata: !TemplateMetadata + choices_in_prompt: true + languages: + - en + metrics: + - Accuracy + original_task: true + name: student_full_credit + reference: ''
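The "binarization" meta-templates above recast a multiple-choice item as a correct/incorrect judgment about a proposed response, while the --binarize flag added earlier in this series instead keeps the true label plus one sampled distractor. Below is a minimal standalone sketch of that sampling step, mirroring the logic inlined in _convert_to_prompts; the helper name binarize_choices is hypothetical and used only for illustration.

from random import Random

def binarize_choices(label_choices: list, label, rng: Random) -> list:
    # Keep the true label plus one randomly sampled false label, as in the DLK
    # paper, then shuffle so the true answer does not always land in the same slot.
    false = rng.choice([c for c in label_choices if c != label])
    pair = [false, label]
    rng.shuffle(pair)
    return pair

# With four classes and true label 2, the result is a two-element list such as
# [2, 0] or [3, 2], depending on the seed.
print(binarize_choices([0, 1, 2, 3], label=2, rng=Random(42)))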