From 681698d08173d1bd673facb44b6b312ed94b33a6 Mon Sep 17 00:00:00 2001 From: Benjamin Date: Thu, 9 Mar 2023 22:32:05 +0000 Subject: [PATCH 01/37] add multiple datasets support --- elk/extraction/extraction.py | 37 ++++++++++----- elk/extraction/prompt_dataset.py | 78 +++++++++++++++++++++++++++----- elk/training/train.py | 25 ++++++++-- 3 files changed, 114 insertions(+), 26 deletions(-) diff --git a/elk/extraction/extraction.py b/elk/extraction/extraction.py index c6775f0e..1820c466 100644 --- a/elk/extraction/extraction.py +++ b/elk/extraction/extraction.py @@ -1,6 +1,6 @@ """Functions for extracting the hidden states of a model.""" -from .prompt_dataset import Prompt, PromptDataset, PromptConfig +from .prompt_dataset import Prompt, PromptDataset, PromptConfig, Interleaved_Datasets from ..utils import ( assert_type, infer_label_column, @@ -19,6 +19,7 @@ SplitDict, SplitInfo, Value, + interleave_datasets, ) from simple_parsing.helpers import field, Serializable from transformers import ( @@ -88,18 +89,27 @@ def extract_hiddens( if rank != 0: logging.disable(logging.CRITICAL) - prompt_ds = PromptDataset(cfg.prompts, rank, world_size, split) if rank == 0: - prompt_names = prompt_ds.prompter.all_template_names if cfg.prompts.num_variants >= 1: - print( - f"Using {cfg.prompts.num_variants} prompts per example: {prompt_names}" - ) + print(f"Using {cfg.prompts.num_variants} prompts per example") elif cfg.prompts.num_variants == -1: - print(f"Using all prompts per example: {prompt_names}") + print("Using all prompts per example") else: raise ValueError(f"Invalid prompt num_variants: {cfg.prompts.num_variants}") + prompt_datasets = [] + + # create a PromptDataset for each dataset in cfg.prompts + for dataset_index in range(len(cfg.prompts.datasets)): + dataset_name = cfg.prompts.datasets[dataset_index] + prompt_ds = PromptDataset(cfg.prompts, rank, world_size, split, dataset_index) + prompt_names = prompt_ds.prompter.all_template_names + print(f"Prompts for dataset {dataset_name}: {prompt_names}") + prompt_datasets.append(prompt_ds) + + # combine each PromptDataset together, interleaving them + interleaved_prompt_datasets = Interleaved_Datasets(prompt_datasets) + # AutoModel should do the right thing here in nearly all cases. We don't actually # care what head the model has, since we are just extracting hidden states. 
model = AutoModel.from_pretrained(cfg.model, torch_dtype="auto").to(device) @@ -114,7 +124,10 @@ def extract_hiddens( # TODO: Make this configurable or something # Token used to separate the question from the answer - num_choices = prompt_ds.num_classes + num_choices = prompt_datasets[0].num_classes + for i in range(1, len(prompt_datasets)): + assert prompt_datasets[i].num_classes == num_choices + sep_token = tokenizer.sep_token or "\n" if not tokenizer.pad_token: @@ -160,7 +173,8 @@ def collate(prompts: list[Prompt]) -> list[list[BatchEncoding]]: # Iterating over questions layer_indices = cfg.layers or tuple(range(model.config.num_hidden_layers)) - for prompts in prompt_ds: + + for prompts in interleaved_prompt_datasets: inputs = collate(prompts) hidden_dict = { f"hidden_{layer_idx}": torch.empty( @@ -228,7 +242,7 @@ def get_splits() -> SplitDict: { k: SplitInfo( name=k, - num_examples=min(limit, v.num_examples), + num_examples=min(limit, v.num_examples) * len(cfg.prompts.datasets), dataset_name=v.dataset_name, ) for k, v in base_splits.items() @@ -239,13 +253,14 @@ def get_splits() -> SplitDict: model_cfg = AutoConfig.from_pretrained(cfg.model) num_variants = cfg.prompts.num_variants - ds_name, _, config_name = cfg.prompts.dataset.partition(" ") + ds_name, _, config_name = cfg.prompts.datasets[0].partition(" ") info = get_dataset_config_info(ds_name, config_name or None) features = assert_type(Features, info.features) label_col = cfg.prompts.label_column or infer_label_column(features) splits = get_splits() + print("SPLITS: ", splits) layer_cols = { f"hidden_{layer}": Array3D( diff --git a/elk/extraction/prompt_dataset.py b/elk/extraction/prompt_dataset.py index 68efbce1..ef6819bf 100644 --- a/elk/extraction/prompt_dataset.py +++ b/elk/extraction/prompt_dataset.py @@ -2,7 +2,13 @@ from ..promptsource import DatasetTemplates from ..utils import assert_type, compute_class_balance, infer_label_column, undersample from dataclasses import dataclass -from datasets import DatasetDict, load_dataset +from datasets import ( + DatasetDict, + IterableDataset, + Dataset, + load_dataset, + concatenate_datasets, +) from numpy.typing import NDArray from random import Random from simple_parsing.helpers import field, Serializable @@ -44,7 +50,8 @@ class PromptConfig(Serializable): call to __getitem__. Use -1 to apply all available templates. Defaults to 1. """ - dataset: str = field(positional=True) + datasets: list[str] = field(positional=True) + # dataset2: str = field(positional=True) balance: bool = False label_column: Optional[str] = None max_examples: int = 0 @@ -53,6 +60,16 @@ class PromptConfig(Serializable): num_variants: int = 1 +def create_prompt_dataset( + cfg: PromptConfig, + rank: int = 0, + world_size: int = 1, + split: str = "validation", + dataset_index: int = 0, # which dataset in cfg.datasets to use +): + pass + + class PromptDataset(TorchDataset): """Wrapper for a HuggingFace dataset which generates prompts with `promptsource`. @@ -79,8 +96,12 @@ def __init__( rank: int = 0, world_size: int = 1, split: str = "validation", + dataset_index: int = 0, # which dataset in cfg.datasets to use ): - ds_name, _, config_name = cfg.dataset.partition(" ") + # super.__init__(self) + + dataset = cfg.datasets[dataset_index] + ds_name, _, config_name = dataset.partition(" ") self.num_shots = cfg.num_shots self.prompter = DatasetTemplates(ds_name, config_name or None) # type: ignore @@ -100,15 +121,16 @@ def __init__( # instantiations of PromptDataset (unless you set the seed to something else). 
# This allows you to just set split="train" and split="test" for any dataset # and not worry about train-test leakage. - split_name, *others = ds_dict.keys() - if not others: - print("Creating a 75/25 train-test split...") - # Don't shuffle now because we're going to shuffle later - ds_dict = ds_dict[split_name].train_test_split( - seed=cfg.seed, shuffle=False, stratify_by_column=cfg.label_column - ) - assert isinstance(ds_dict, DatasetDict) + # split_name, *others = ds_dict.keys() + # if not others: + # print("Creating a 75/25 train-test split...") + + # # Don't shuffle now because we're going to shuffle later + # ds_dict = ds_dict[split_name].train_test_split( + # seed=cfg.seed, shuffle=False, stratify_by_column=cfg.label_column + # ) + # assert isinstance(ds_dict, DatasetDict) # The 'active' split is the one that gets queried by __getitem__ self.active_split = ds_dict[split] @@ -225,3 +247,37 @@ def num_classes(self) -> int: # We piggyback on the ClassLabel feature type to get the number of classes return self.active_split.features[self.label_column].num_classes + + +class Interleaved_Datasets(TorchDataset): + def __init__( + self, + datasets: list[PromptDataset], + ): + """ + Interleave several (PromptDataset) datasets into a single dataset, + alternating between the datasets. + Only samples as many datapoints from each dataset as the smallest dataset. + Args: + datasets (`List[PromptDataset]`): + List of datasets to interleave. + """ + self.datasets = datasets + + if not datasets: + raise ValueError("Unable to interleave an empty list of datasets.") + + lengths = [len(dset) for dset in datasets] + self.min_dataset_length = min(lengths) + self.num_datasets = len(datasets) + + def __getitem__(self, index: int) -> list[Prompt]: + which_dataset = index % self.num_datasets + return self.datasets[which_dataset][int(index / self.num_datasets)] + + def __iter__(self): + return (self[i] for i in range(len(self))) + + def __len__(self): + """Get the number of predicates in the dataset.""" + return self.num_datasets * self.min_dataset_length diff --git a/elk/training/train.py b/elk/training/train.py index 0bec9427..43ada153 100644 --- a/elk/training/train.py +++ b/elk/training/train.py @@ -22,6 +22,7 @@ import random import torch import torch.multiprocessing as mp +from typing import Union @dataclass @@ -42,20 +43,22 @@ class RunConfig(Serializable): max_gpus: int = -1 normalization: Literal["legacy", "none", "elementwise", "meanonly"] = "meanonly" skip_baseline: bool = False + concatenate_layers: int = 0 + # if nonzero, appends the hidden states of the layer concatenate_layers before def train_reporter( cfg: RunConfig, dataset: DatasetDict, out_dir: Path, - layer: int, + layer: Union[int, list[int]], devices: list[str], world_size: int = 1, ): """Train a single reporter on a single layer.""" # Reproducibility - seed = cfg.net.seed + layer + seed = cfg.net.seed + layer if isinstance(layer, int) else layer[0] np.random.seed(seed) random.seed(seed) torch.manual_seed(seed) @@ -72,9 +75,17 @@ def train_reporter( train_labels = cast(Tensor, train["label"]) val_labels = cast(Tensor, val["label"]) + # concatenate hidden states across layers if multiple layers are inputted + if isinstance(layer, list): + train_hiddens = torch.cat([train[f"hidden_{lay}"] for lay in layer], dim=1) + val_hiddens = torch.cat([val[f"hidden_{lay}"] for lay in layer], dim=1) + else: + train_hiddens = train[f"hidden_{layer}"] + val_hiddens = val[f"hidden_{layer}"] + train_h, val_h = normalize( - 
int16_to_float32(assert_type(Tensor, train[f"hidden_{layer}"])), - int16_to_float32(assert_type(Tensor, val[f"hidden_{layer}"])), + int16_to_float32(assert_type(Tensor, train_hiddens)), + int16_to_float32(assert_type(Tensor, val_hiddens)), method=cfg.normalization, ) x0, x1 = train_h.unbind(dim=-2) @@ -161,6 +172,12 @@ def train(cfg: RunConfig, out_dir: Optional[Path] = None): for feat in ds["train"].features if feat.startswith("hidden_") ] + + # concatenate hidden states from a previous layer, if told to + if cfg.concatenate_layers > 0: + for i in range(cfg.concatenate_layers, len(layers)): + layers[i] = [layers[i], layers[i] - cfg.concatenate_layers] + # Train reporters for each layer in parallel with mp.Pool(num_devices) as pool, open(out_dir / "eval.csv", "w") as f: fn = partial( From b864c7702c9c71db918a0009539a6a1b02971db1 Mon Sep 17 00:00:00 2001 From: Benjamin Date: Fri, 10 Mar 2023 03:22:56 +0000 Subject: [PATCH 02/37] train_reporter works on a list of layers now --- elk/training/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/elk/training/train.py b/elk/training/train.py index 678dd944..5da8fe0d 100644 --- a/elk/training/train.py +++ b/elk/training/train.py @@ -56,7 +56,7 @@ def train_reporter( devices: list[str], world_size: int = 1, ): - """Train a single reporter on a single layer.""" + """Train a single reporter on a single layer, or a list of layers.""" # Reproducibility seed = cfg.net.seed + layer if isinstance(layer, int) else layer[0] From 7d7d97cd8494bb038ce7b3c73351ffae29c049a3 Mon Sep 17 00:00:00 2001 From: Benjamin Date: Fri, 10 Mar 2023 04:12:23 +0000 Subject: [PATCH 03/37] changing printed layer names --- elk/training/train.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/elk/training/train.py b/elk/training/train.py index 5da8fe0d..c3fd83fc 100644 --- a/elk/training/train.py +++ b/elk/training/train.py @@ -122,7 +122,8 @@ def train_reporter( lr_dir.mkdir(parents=True, exist_ok=True) reporter_dir.mkdir(parents=True, exist_ok=True) - stats = [layer, pseudo_auroc, train_loss, *val_result] + layer_name = layer if isinstance(layer, str) else " and ".join(layer) + stats = [layer_name, pseudo_auroc, train_loss, *val_result] if not cfg.skip_baseline: # repeat_interleave makes `num_variants` copies of each label, all within a From 4fe61e930dcb036242d155e0254d845a261ff6ed Mon Sep 17 00:00:00 2001 From: Benjamin Date: Sat, 11 Mar 2023 17:27:23 +0000 Subject: [PATCH 04/37] fixed concatenation bug --- elk/training/train.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/elk/training/train.py b/elk/training/train.py index c3fd83fc..5a749d4d 100644 --- a/elk/training/train.py +++ b/elk/training/train.py @@ -78,8 +78,8 @@ def train_reporter( # concatenate hidden states across layers if multiple layers are inputted if isinstance(layer, list): - train_hiddens = torch.cat([train[f"hidden_{lay}"] for lay in layer], dim=1) - val_hiddens = torch.cat([val[f"hidden_{lay}"] for lay in layer], dim=1) + train_hiddens = torch.cat([train[f"hidden_{lay}"] for lay in layer], dim=-1) + val_hiddens = torch.cat([val[f"hidden_{lay}"] for lay in layer], dim=-1) else: train_hiddens = train[f"hidden_{layer}"] val_hiddens = val[f"hidden_{layer}"] @@ -122,7 +122,7 @@ def train_reporter( lr_dir.mkdir(parents=True, exist_ok=True) reporter_dir.mkdir(parents=True, exist_ok=True) - layer_name = layer if isinstance(layer, str) else " and ".join(layer) + layer_name = layer if isinstance(layer, int) else max(layer) stats = [layer_name, 
pseudo_auroc, train_loss, *val_result] if not cfg.skip_baseline: From fe61d67cc335417ce0bda639690563208b66bb07 Mon Sep 17 00:00:00 2001 From: Benjamin Date: Mon, 13 Mar 2023 22:52:20 +0000 Subject: [PATCH 05/37] minor edits --- elk/extraction/prompt_dataset.py | 18 +++++++++--------- elk/training/train.py | 9 +++++---- 2 files changed, 14 insertions(+), 13 deletions(-) diff --git a/elk/extraction/prompt_dataset.py b/elk/extraction/prompt_dataset.py index c2d62cc9..13b76239 100644 --- a/elk/extraction/prompt_dataset.py +++ b/elk/extraction/prompt_dataset.py @@ -120,15 +120,15 @@ def __init__( # This allows you to just set split="train" and split="test" for any dataset # and not worry about train-test leakage. - # split_name, *others = ds_dict.keys() - # if not others: - # print("Creating a 75/25 train-test split...") + split_name, *others = ds_dict.keys() + if not others: + print("Creating a 75/25 train-test split...") - # # Don't shuffle now because we're going to shuffle later - # ds_dict = ds_dict[split_name].train_test_split( - # seed=cfg.seed, shuffle=False, stratify_by_column=cfg.label_column - # ) - # assert isinstance(ds_dict, DatasetDict) + # Don't shuffle now because we're going to shuffle later + ds_dict = ds_dict[split_name].train_test_split( + seed=cfg.seed, shuffle=False, stratify_by_column=cfg.label_column + ) + assert isinstance(ds_dict, DatasetDict) # The 'active' split is the one that gets queried by __getitem__ self.active_split = ds_dict[split] @@ -254,7 +254,7 @@ def num_classes(self) -> int: return self.active_split.features[self.label_column].num_classes -class Interleaved_Datasets(TorchDataset): +class InterleavedDatasets(TorchDataset): def __init__( self, datasets: list[PromptDataset], diff --git a/elk/training/train.py b/elk/training/train.py index 5a749d4d..f48413a7 100644 --- a/elk/training/train.py +++ b/elk/training/train.py @@ -52,14 +52,14 @@ def train_reporter( cfg: RunConfig, dataset: DatasetDict, out_dir: Path, - layer: Union[int, list[int]], + layer: list[int], devices: list[str], world_size: int = 1, ): """Train a single reporter on a single layer, or a list of layers.""" # Reproducibility - seed = cfg.net.seed + layer if isinstance(layer, int) else layer[0] + seed = cfg.net.seed + layer[0] np.random.seed(seed) random.seed(seed) torch.manual_seed(seed) @@ -180,8 +180,9 @@ def train(cfg: RunConfig, out_dir: Optional[Path] = None): if not cfg.skip_baseline: cols += ["lr_auroc", "lr_acc"] + # Create subsets of layers to train reporters on layers = [ - int(feat[len("hidden_") :]) + [int(feat[len("hidden_") :])] for feat in ds["train"].features if feat.startswith("hidden_") ] @@ -189,7 +190,7 @@ def train(cfg: RunConfig, out_dir: Optional[Path] = None): # concatenate hidden states from a previous layer, if told to if cfg.concatenate_layers > 0: for i in range(cfg.concatenate_layers, len(layers)): - layers[i] = [layers[i], layers[i] - cfg.concatenate_layers] + layers[i] = layers[i] + [layers[i][0] - cfg.concatenate_layers] # Train reporters for each layer in parallel with mp.Pool(num_devices) as pool, open(out_dir / "eval.csv", "w") as f: From 74da87806ea431a6044ad2dc28ebc4a2433071c4 Mon Sep 17 00:00:00 2001 From: Benjamin Date: Mon, 13 Mar 2023 23:11:09 +0000 Subject: [PATCH 06/37] fixed pyright issues --- elk/extraction/extraction.py | 7 +++---- elk/extraction/prompt_dataset.py | 4 ++-- elk/training/train.py | 14 +++++++------- 3 files changed, 12 insertions(+), 13 deletions(-) diff --git a/elk/extraction/extraction.py b/elk/extraction/extraction.py index 
4553a570..e975a915 100644 --- a/elk/extraction/extraction.py +++ b/elk/extraction/extraction.py @@ -1,6 +1,6 @@ """Functions for extracting the hidden states of a model.""" -from .prompt_dataset import Prompt, PromptDataset, PromptConfig, Interleaved_Datasets +from .prompt_dataset import Prompt, PromptDataset, PromptConfig, InterleavedDatasets from ..utils import ( assert_type, infer_label_column, @@ -19,7 +19,6 @@ SplitDict, SplitInfo, Value, - interleave_datasets, ) from simple_parsing.helpers import field, Serializable from transformers import ( @@ -108,7 +107,7 @@ def extract_hiddens( prompt_datasets.append(prompt_ds) # combine each PromptDataset together, interleaving them - interleaved_prompt_datasets = Interleaved_Datasets(prompt_datasets) + interleaved_prompt_datasets = InterleavedDatasets(prompt_datasets) # AutoModel should do the right thing here in nearly all cases. We don't actually # care what head the model has, since we are just extracting hidden states. @@ -178,7 +177,7 @@ def collate(prompts: list[Prompt]) -> list[list[BatchEncoding]]: inputs = collate(prompts) hidden_dict = { f"hidden_{layer_idx}": torch.empty( - prompt_ds.num_variants, + prompt_datasets[0].num_variants, num_choices, model.config.hidden_size, device=device, diff --git a/elk/extraction/prompt_dataset.py b/elk/extraction/prompt_dataset.py index 13b76239..c4b667b3 100644 --- a/elk/extraction/prompt_dataset.py +++ b/elk/extraction/prompt_dataset.py @@ -156,8 +156,8 @@ def __init__( # Sanity check to prevent train-test leakage via few-shot prompts if "train" not in ds_dict: raise ValueError( - f"Dataset {cfg.dataset} has no train split, so we can't create " - "few-shot prompts" + f"Dataset {cfg.datasets[dataset_index]} has no train split, " + "so we can't create few-shot prompts" ) self.fewshot_strata = [ diff --git a/elk/training/train.py b/elk/training/train.py index f48413a7..6489b9e6 100644 --- a/elk/training/train.py +++ b/elk/training/train.py @@ -77,12 +77,12 @@ def train_reporter( val_labels = cast(Tensor, val["label"]) # concatenate hidden states across layers if multiple layers are inputted - if isinstance(layer, list): - train_hiddens = torch.cat([train[f"hidden_{lay}"] for lay in layer], dim=-1) - val_hiddens = torch.cat([val[f"hidden_{lay}"] for lay in layer], dim=-1) - else: - train_hiddens = train[f"hidden_{layer}"] - val_hiddens = val[f"hidden_{layer}"] + train_hiddens = torch.cat( + [cast(Tensor, train[f"hidden_{lay}"]) for lay in layer], dim=-1 + ) + val_hiddens = torch.cat( + [cast(Tensor, val[f"hidden_{lay}"]) for lay in layer], dim=-1 + ) train_h, val_h = normalize( int16_to_float32(assert_type(Tensor, train_hiddens)), @@ -99,7 +99,7 @@ def train_reporter( ) if pseudo_auroc > 0.6: warnings.warn( - f"The pseudo-labels at layer {layer} are linearly separable with " + f"The pseudo-labels at layers {layer} are linearly separable with " f"an AUROC of {pseudo_auroc:.3f}. This may indicate that the " f"algorithm will not converge to a good solution." 
) From fe94c22c83579e093b610c317ede48524403b211 Mon Sep 17 00:00:00 2001 From: Nora Belrose Date: Mon, 20 Mar 2023 18:58:36 +0000 Subject: [PATCH 07/37] Fix tests --- tests/distilgpt2_copa_cfg.yaml | 3 ++- tests/distilgpt2_dbpedia_cfg.yaml | 3 ++- tests/test_prompt_dataset.py | 2 +- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/tests/distilgpt2_copa_cfg.yaml b/tests/distilgpt2_copa_cfg.yaml index bd17194e..cff7ec2f 100644 --- a/tests/distilgpt2_copa_cfg.yaml +++ b/tests/distilgpt2_copa_cfg.yaml @@ -2,7 +2,8 @@ layers: [] model: distilgpt2 prompts: balance: true - dataset: super_glue copa + datasets: + - "super_glue copa" label_column: null max_examples: - 2 diff --git a/tests/distilgpt2_dbpedia_cfg.yaml b/tests/distilgpt2_dbpedia_cfg.yaml index ac109642..629ab302 100644 --- a/tests/distilgpt2_dbpedia_cfg.yaml +++ b/tests/distilgpt2_dbpedia_cfg.yaml @@ -2,7 +2,8 @@ layers: [] model: distilgpt2 prompts: balance: true - dataset: dbpedia_14 + datasets: + - "dbpedia_14" label_column: null max_examples: - 3 diff --git a/tests/test_prompt_dataset.py b/tests/test_prompt_dataset.py index 40235217..e87158d8 100644 --- a/tests/test_prompt_dataset.py +++ b/tests/test_prompt_dataset.py @@ -7,7 +7,7 @@ def test_prompt_dataset_getitem_boolq(): def test_prompt_dataset_getitem(cfg: ExtractionConfig, split: str): prompt_ds = PromptDataset(cfg.prompts, rank=0, world_size=1, split=split) - ds_name, _, config_name = cfg.prompts.dataset.partition(" ") + ds_name, _, config_name = cfg.prompts.datasets[0].partition(" ") prompter = DatasetTemplates(ds_name, config_name or None) assert len(prompt_ds) == cfg.prompts.max_examples[-1] From bba24d8efe2b5d38a407b12169508155da885d0b Mon Sep 17 00:00:00 2001 From: Nora Belrose Date: Wed, 22 Mar 2023 07:35:16 +0000 Subject: [PATCH 08/37] Now working sorta --- elk/__init__.py | 2 +- elk/__main__.py | 2 +- elk/evaluation/__init__.py | 0 elk/evaluation/evaluate.py | 19 +- elk/extraction/__init__.py | 3 +- elk/extraction/balanced_sampler.py | 83 +++++++++ elk/extraction/extraction.py | 145 ++++++--------- elk/extraction/prompt_dataset.py | 290 ----------------------------- elk/extraction/prompt_loading.py | 209 +++++++++++++++++++++ elk/utils/__init__.py | 3 - elk/utils/data_utils.py | 64 +------ pyproject.toml | 4 +- tests/test_prompt_dataset.py | 80 ++++---- tests/test_samplers.py | 55 ++++++ 14 files changed, 462 insertions(+), 497 deletions(-) create mode 100644 elk/evaluation/__init__.py create mode 100644 elk/extraction/balanced_sampler.py delete mode 100644 elk/extraction/prompt_dataset.py create mode 100644 elk/extraction/prompt_loading.py create mode 100644 tests/test_samplers.py diff --git a/elk/__init__.py b/elk/__init__.py index b47551df..5a779b70 100644 --- a/elk/__init__.py +++ b/elk/__init__.py @@ -1 +1 @@ -from .extraction import extract_hiddens, ExtractionConfig, PromptDataset +from .extraction import extract_hiddens, ExtractionConfig diff --git a/elk/__main__.py b/elk/__main__.py index 2df1c2f3..ca68a57e 100644 --- a/elk/__main__.py +++ b/elk/__main__.py @@ -1,7 +1,7 @@ """Main entry point for `elk`.""" from .extraction import extract, ExtractionConfig -from elk.evaluation.evaluate import EvaluateConfig, evaluate_reporters +from .evaluation.evaluate import EvaluateConfig, evaluate_reporters from .training import RunConfig from .training.train import train from pathlib import Path diff --git a/elk/evaluation/__init__.py b/elk/evaluation/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/elk/evaluation/evaluate.py 
b/elk/evaluation/evaluate.py index 9d06efaf..a8be2da8 100644 --- a/elk/evaluation/evaluate.py +++ b/elk/evaluation/evaluate.py @@ -1,21 +1,16 @@ -import csv -import os -import pickle +from ..training.preprocessing import normalize from dataclasses import dataclass +from datasets import DatasetDict from functools import partial -from hashlib import md5 from pathlib import Path -from typing import List, Literal, Optional, cast - -import torch -import torch.multiprocessing as mp -import yaml from simple_parsing.helpers import Serializable, field from torch import Tensor from tqdm.auto import tqdm - -from datasets import DatasetDict -from elk.training.preprocessing import normalize +from typing import Literal, Optional, cast +import csv +import os +import torch +import torch.multiprocessing as mp from ..extraction import ExtractionConfig, extract from ..files import elk_reporter_dir, memorably_named_dir diff --git a/elk/extraction/__init__.py b/elk/extraction/__init__.py index a99c7d9d..1d22e800 100644 --- a/elk/extraction/__init__.py +++ b/elk/extraction/__init__.py @@ -1,3 +1,4 @@ +from .balanced_sampler import BalancedBatchSampler, BalancedSampler from .extraction import ExtractionConfig, extract_hiddens, extract from .generator import _GeneratorConfig, _GeneratorBuilder -from .prompt_dataset import PromptDataset, PromptConfig +from .prompt_loading import PromptConfig, load_prompts diff --git a/elk/extraction/balanced_sampler.py b/elk/extraction/balanced_sampler.py new file mode 100644 index 00000000..c4da69b5 --- /dev/null +++ b/elk/extraction/balanced_sampler.py @@ -0,0 +1,83 @@ +from ..utils import infer_label_column +from collections import Counter +from datasets import IterableDataset +from itertools import cycle +from torch.utils.data import IterableDataset as TorchIterableDataset +from typing import Iterator, Optional +import numpy as np + + +class BalancedSampler(TorchIterableDataset): + """ + Approximately balances a binary classification dataset in a streaming fashion. + Written mostly by GPT-4. + + Args: + dataset (IterableDataset): The HuggingFace IterableDataset to balance. + label_col (Optional[str], optional): The name of the column containing the + binary label. If not provided, the label column will be inferred from + the dataset features. Defaults to None. + buffer_size (int, optional): The total buffer size to use for balancing the + dataset. This value should be divisible by 2, as it will be equally + divided between the two binary label values (0 and 1). Defaults to 1000. + """ + + def __init__(self, dataset: IterableDataset): + self.dataset = dataset + self.class_counts = np.zeros(2) + + def __iter__(self): + for sample in self.dataset: + label = sample["label"] + + # Update class counts + self.class_counts[label] += 1 + current_balance = self.class_counts / self.class_counts.sum() + + # Check if the sample should be dropped + majority_class = np.argmax(current_balance) + if label == majority_class: + # Solution of n * p * q / [n * (1 - p) + n * p * q] = 0.5 for q + keep_prob = 1 / current_balance[majority_class] - 1 + if np.random.rand() < 1 - keep_prob: + continue + + yield sample + + +class BalancedBatchSampler: + """Yields precisely balanced batches from a binary classification dataset. + + Written by a human being because GPT-4 couldn't figure out how to do it. 
+ """ + + def __init__( + self, + dataset: IterableDataset, + label_col: Optional[str] = None, + batch_size: int = 32, + ): + self.batch_size = batch_size + self.dataset = dataset + self.label_col = label_col or infer_label_column(dataset.features) + + def __iter__(self) -> Iterator[list[dict]]: + batch = [] + + max_count = self.batch_size // 2 + label_counts = Counter() + + # Infinite loop! + for sample in cycle(self.dataset): + label = sample[self.label_col] + if label_counts[label] >= max_count: + continue + + batch.append(sample) + label_counts[label] += 1 + + if len(batch) == self.batch_size: + yield batch + + batch = [] + label_counts.clear() diff --git a/elk/extraction/extraction.py b/elk/extraction/extraction.py index 3f9d0e67..550b9c2e 100644 --- a/elk/extraction/extraction.py +++ b/elk/extraction/extraction.py @@ -1,6 +1,6 @@ """Functions for extracting the hidden states of a model.""" -from .prompt_dataset import Prompt, PromptDataset, PromptConfig, InterleavedDatasets +from .prompt_loading import load_prompts, PromptConfig from ..utils import ( assert_type, float32_to_int16, @@ -8,10 +8,12 @@ select_train_val_splits, select_usable_devices, ) +from .balanced_sampler import BalancedSampler from .generator import _GeneratorBuilder from dataclasses import dataclass, InitVar from datasets import ( Array3D, + ClassLabel, DatasetDict, Features, get_dataset_config_info, @@ -25,11 +27,11 @@ AutoConfig, AutoModel, AutoTokenizer, - BatchEncoding, PreTrainedModel, ) from typing import Iterable, Literal, Union import logging +import os import torch @@ -73,71 +75,44 @@ def extract_hiddens( cfg: ExtractionConfig, *, device: Union[str, torch.device] = "cpu", + split_type: Literal["train", "val"] = "train", rank: int = 0, - split: str, world_size: int = 1, ) -> Iterable[dict]: """Run inference on a model with a set of prompts, yielding the hidden states. This is a lightweight, functional version of the `Extractor` API. """ + os.environ["TOKENIZERS_PARALLELISM"] = "false" # Silence datasets logging messages from all but the first process if rank != 0: logging.disable(logging.CRITICAL) - - if rank == 0: - if cfg.prompts.num_variants >= 1: - print(f"Using {cfg.prompts.num_variants} prompts per example") - elif cfg.prompts.num_variants == -1: - print("Using all prompts per example") - else: - raise ValueError(f"Invalid prompt num_variants: {cfg.prompts.num_variants}") - - prompt_datasets = [] - - # create a PromptDataset for each dataset in cfg.prompts - for dataset_index in range(len(cfg.prompts.datasets)): - dataset_name = cfg.prompts.datasets[dataset_index] - prompt_ds = PromptDataset(cfg.prompts, rank, world_size, split, dataset_index) - prompt_names = prompt_ds.prompter.all_template_names - print(f"Prompts for dataset {dataset_name}: {prompt_names}") - prompt_datasets.append(prompt_ds) - - # combine each PromptDataset together, interleaving them - interleaved_prompt_datasets = InterleavedDatasets(prompt_datasets) + if rank == 0 and cfg.prompts.num_variants >= 1: + print(f"Using {cfg.prompts.num_variants} prompts per example") + + limits = cfg.prompts.max_examples + prompt_ds = load_prompts( + *cfg.prompts.datasets, + max_examples=limits[0 if split_type == "train" else 1], + split_type=split_type, + rank=rank, + world_size=world_size, + ) + num_variants = prompt_ds.features["prompts"].length # AutoModel should do the right thing here in nearly all cases. We don't actually # care what head the model has, since we are just extracting hidden states. 
model = AutoModel.from_pretrained(cfg.model, torch_dtype="auto").to(device) # TODO: Maybe also make this configurable? # We want to make sure the answer is never truncated - tokenizer = AutoTokenizer.from_pretrained(cfg.model, truncation_side="left") - - # TODO: test whether using sep_token is important, but this seems low priority - # sep_token = tokenizer.sep_token or "\n" - if not tokenizer.pad_token: - tokenizer.pad_token = tokenizer.eos_token - + tokenizer = AutoTokenizer.from_pretrained( + cfg.model, truncation_side="left", verbose=False + ) is_enc_dec = model.config.is_encoder_decoder - def tokenize(prompt: Prompt, idx: int, **kwargs): - return tokenizer( - ([prompt.to_string(idx)]), - padding=True, - return_tensors="pt", - truncation=True, - **kwargs, - ).to(device) - - # This function returns the flattened questions and answers. After inference we - # need to reshape the results. - def collate(prompts: list[Prompt]) -> list[list[BatchEncoding]]: - return [[tokenize(prompt, i) for i in range(2)] for prompt in prompts] - # If this is an encoder-decoder model we don't need to run the decoder at all. - # Just strip it off, making the problem - # equivalent to a regular encoder-only model. + # Just strip it off, making the problem equivalent to a regular encoder-only model. if is_enc_dec: # This isn't actually *guaranteed* by HF, but it's true for all existing models if not hasattr(model, "get_encoder") or not callable(model.get_encoder): @@ -149,12 +124,12 @@ def collate(prompts: list[Prompt]) -> list[list[BatchEncoding]]: # Iterating over questions layer_indices = cfg.layers or tuple(range(model.config.num_hidden_layers)) + # print(f"Using {prompt_ds} variants for each dataset") - for prompts in interleaved_prompt_datasets: - inputs = collate(prompts) + for example in BalancedSampler(prompt_ds): hidden_dict = { f"hidden_{layer_idx}": torch.empty( - prompt_datasets[0].num_variants, + num_variants, 2, # contrast pair model.config.hidden_size, device=device, @@ -162,25 +137,23 @@ def collate(prompts: list[Prompt]) -> list[list[BatchEncoding]]: ) for layer_idx in layer_indices } - variant_ids = [prompt.template.name for prompt in prompts] - # decode so that we know exactly what the input was - text_inputs = [ - [ - tokenizer.decode( - assert_type(torch.Tensor, variant_inputs[0].input_ids)[0] - ), - tokenizer.decode( - assert_type(torch.Tensor, variant_inputs[1].input_ids)[0] - ), - ] - for variant_inputs in inputs - ] + text_inputs = [] # Iterate over variants - for i, variant_inputs in enumerate(inputs): + for i, record in enumerate(example["prompts"]): + variant_inputs = [] + # Iterate over answers - for j, inpt in enumerate(variant_inputs): - outputs = model(**inpt, output_hidden_states=True) + for j in range(2): + text = record["text"][j] + variant_inputs.append(text) + + inputs = tokenizer( + text, + return_tensors="pt", + truncation=True, + ).to(device) + outputs = model(**inputs, output_hidden_states=True) hiddens = ( outputs.get("decoder_hidden_states") or outputs["hidden_states"] @@ -204,10 +177,11 @@ def collate(prompts: list[Prompt]) -> list[list[BatchEncoding]]: for layer_idx, hidden in zip(layer_indices, hiddens): hidden_dict[f"hidden_{layer_idx}"][i, j] = float32_to_int16(hidden) - assert all([prompts[0].label == prompt.label for prompt in prompts]) + text_inputs.append(variant_inputs) + yield dict( - label=prompts[0].label, - variant_ids=variant_ids, + label=example["label"], + variant_ids=example["template_names"], text_inputs=text_inputs, **hidden_dict, ) @@ -223,8 
+197,12 @@ def extract(cfg: ExtractionConfig, max_gpus: int = -1) -> DatasetDict: def get_splits() -> SplitDict: available_splits = assert_type(SplitDict, info.splits) - splits = select_train_val_splits(available_splits) - print(f"Using '{splits[0]}' for training and '{splits[1]}' for validation") + train_name, val_name = select_train_val_splits(available_splits) + print(f"Using '{train_name}' for training and '{val_name}' for validation") + + out_splits = SplitDict( + train=available_splits[train_name], val=available_splits[val_name] + ) # Empty list means no limit limit_list = cfg.prompts.max_examples @@ -233,33 +211,27 @@ def get_splits() -> SplitDict: # Broadcast the limit to all splits if len(limit_list) == 1: - limit_list *= len(splits) + limit_list *= len(out_splits) - limit = {k: v for k, v in zip(splits, limit_list)} return SplitDict( { k: SplitInfo( name=k, - num_examples=min(limit[k], v.num_examples) - * len(cfg.prompts.datasets), + num_examples=min(limit, v.num_examples) * len(cfg.prompts.datasets), dataset_name=v.dataset_name, ) - for k, v in available_splits.items() - if k in splits + for limit, (k, v) in zip(limit_list, available_splits.items()) }, dataset_name=available_splits.dataset_name, ) + devices = select_usable_devices(max_gpus) model_cfg = AutoConfig.from_pretrained(cfg.model) num_variants = cfg.prompts.num_variants + ds_name, _, config_name = cfg.prompts.datasets[0].partition(" ") info = get_dataset_config_info(ds_name, config_name or None) - features = assert_type(Features, info.features) - label_col = cfg.prompts.label_column or infer_label_column(features) - - splits = get_splits() - layer_cols = { f"hidden_{layer}": Array3D( dtype="int16", @@ -272,7 +244,7 @@ def get_splits() -> SplitDict: Value(dtype="string"), length=num_variants, ), - "label": features[label_col], + "label": ClassLabel(names=["neg", "pos"]), "text_inputs": Sequence( Sequence( Value(dtype="string"), @@ -281,7 +253,7 @@ def get_splits() -> SplitDict: length=num_variants, ), } - devices = select_usable_devices(max_gpus) + builders = { split_name: _GeneratorBuilder( cache_dir=None, @@ -293,15 +265,16 @@ def get_splits() -> SplitDict: cfg=[cfg] * len(devices), device=devices, rank=list(range(len(devices))), - split=[split_name] * len(devices), + split_type=[split_name] * len(devices), world_size=[len(devices)] * len(devices), ), ) - for (split_name, split_info) in splits.items() + for (split_name, split_info) in get_splits().items() } ds = dict() for split, builder in builders.items(): builder.download_and_prepare(num_proc=len(devices)) ds[split] = builder.as_dataset(split=split) + return DatasetDict(ds) diff --git a/elk/extraction/prompt_dataset.py b/elk/extraction/prompt_dataset.py deleted file mode 100644 index 4cea3103..00000000 --- a/elk/extraction/prompt_dataset.py +++ /dev/null @@ -1,290 +0,0 @@ -from ..math_util import stochastic_round_constrained -from ..promptsource import DatasetTemplates, Template -from ..utils import ( - apply_template, - assert_type, - binarize, - compute_class_balance, - infer_label_column, - infer_num_classes, - select_train_val_splits, - undersample, -) -from dataclasses import dataclass -from datasets import DatasetDict, load_dataset, ClassLabel, Value -from numpy.typing import NDArray -from random import Random -from simple_parsing.helpers import field, Serializable -from torch.utils.data import Dataset as TorchDataset -from typing import Optional, Any -import numpy as np - - -@dataclass -class Prompt: - """A prompt for a single example in a dataset""" - - 
prefix: str - template: Template - example: dict[str, Any] - label: int - label_column: str - - def to_string(self, answer_idx: int) -> str: - """Return the prompt as a string, with the answer at `answer_idx`.""" - fake_example = self.example.copy() - fake_example[self.label_column] = answer_idx - return self.prefix + apply_template(self.template, fake_example) - - -@dataclass -class PromptConfig(Serializable): - """ - Args: - dataset: Space-delimited name of the HuggingFace dataset to use, e.g. - `"super_glue boolq"` or `"imdb"`. - balance: Whether to force class balance in the dataset using undersampling. - data_dir: The directory to use for caching the dataset. Defaults to - `~/.cache/huggingface/datasets`. - label_column: The column containing the labels. By default, we infer this from - the datatypes of the columns in the dataset; if there is only one column - with a `ClassLabel` datatype, we use that. - num_classes: The number of classes in the dataset. By default, we infer this - from the datatypes of the columns in the dataset; if there is only one - column with a `ClassLabel` datatype, we use the number of classes in that - column. - max_examples: The maximum number of examples to use from the val dataset. - If a single number, use at most that many examples for each split. If a list - of length 2, use the first element for the train split and the second for - the val split. If empty, use all examples. Defaults to empty. - num_shots: The number of examples to use in few-shot prompts. If zero, prompts - are zero-shot. Defaults to 0. - seed: The seed to use for prompt randomization. Defaults to 42. - num_variants: The number of prompt templates to apply to each predicate upon - call to __getitem__. Use -1 to apply all available templates. Defaults to 1. - """ - - datasets: list[str] = field(positional=True) - balance: bool = False - data_dir: Optional[str] = None - label_column: Optional[str] = None - num_classes: Optional[int] = None - max_examples: list[int] = field(default_factory=lambda: [750, 250]) - num_shots: int = 0 - num_variants: int = -1 - seed: int = 42 - - def __post_init__(self): - if len(self.max_examples) > 2: - raise ValueError( - "max_examples should be a list of length 0, 1, or 2," - f"but got {len(self.max_examples)}" - ) - - -class PromptDataset(TorchDataset): - """Wrapper for a HuggingFace dataset which generates prompts with `promptsource`. - - Usually `promptsource` has multiple prompt templates for a given dataset. We sample - `num_variants` of these templates and apply them to each example in the dataset, up - to `max_examples` examples. If `num_shots` is greater than zero, we sample that - many examples from the dataset and use them to generate a prefix for the prompt. - - Example: - >>> prompts = PromptDataset("super_glue", "boolq", split="train") - >>> prompt = prompts[0] - >>> prompt.to_string(0) - "Henry Mills (Once Upon a Time) -- Henry Daniel Mills is a fictional character... 
- """ - - def __init__( - self, - cfg: PromptConfig, - rank: int = 0, - world_size: int = 1, - split: str = "validation", - dataset_index: int = 0, # which dataset in cfg.datasets to use - ): - dataset = cfg.datasets[dataset_index] - ds_name, _, config_name = dataset.partition(" ") - - self.num_shots = cfg.num_shots - self.prompter = DatasetTemplates(ds_name, config_name or None) # type: ignore - self.rng = Random(cfg.seed) - self.num_variants = ( - cfg.num_variants if cfg.num_variants > 0 else len(self.prompter.templates) - ) - - ds_dict = assert_type( - DatasetDict, # TODO: Should we support IterableDataset? - load_dataset(ds_name, config_name or None, data_dir=cfg.data_dir), - ) - - # By default we use the existing train-validation/test split in the dataset. - # If it doesn't exist, we create our own 75/25 train-test split. Crucially, - # because the RNG is always seeded, this split will be the same for independent - # instantiations of PromptDataset (unless you set the seed to something else). - # This allows you to just set split="train" and split="test" for any dataset - # and not worry about train-test leakage. - - split_name, *others = ds_dict.keys() - if not others: - print("Creating a 75/25 train-test split...") - - # Don't shuffle now because we're going to shuffle later - ds_dict = ds_dict[split_name].train_test_split( - seed=cfg.seed, shuffle=False, stratify_by_column=cfg.label_column - ) - assert isinstance(ds_dict, DatasetDict) - - # The 'active' split is the one that gets queried by __getitem__ - self.active_split = ds_dict[split] - label_col = cfg.label_column or infer_label_column(self.active_split.features) - self.label_column = label_col - self.num_classes = cfg.num_classes or infer_num_classes( - self.active_split.features[label_col] - ) - - # Enforce class balance if needed - if cfg.balance: - self.active_split = undersample( - self.active_split, self.rng, self.num_classes, label_col - ) - self.class_fracs = np.ones(self.num_classes) / self.num_classes - else: - class_sizes = compute_class_balance( - self.active_split, self.num_classes, label_col - ) - self.class_fracs: NDArray[np.floating] = class_sizes / class_sizes.sum() - - # We use stratified sampling to create few-shot prompts that are as balanced as - # possible. If needed, create the strata now so that we can use them later. 
- if cfg.num_shots > 0: - # Sanity check that we can fit an example from every class in the prompt - if self.num_classes > cfg.num_shots: - raise ValueError( - f"Few-shot prompts should contain at least one example from each " - f"class; got {cfg.num_shots} examples, {self.num_classes} classes" - ) - - train_split = select_train_val_splits(ds_dict)[0] - - self.fewshot_strata = [ - ds_dict[train_split].filter(lambda ex: ex[label_col] == i) - for i in range(self.num_classes) - ] - else: - self.fewshot_strata = [] - - # Now shuffle the active split and truncate it if needed - self.active_split = self.active_split.shuffle(seed=cfg.seed) - - if cfg.max_examples: - max_examples = ( - cfg.max_examples[0] - if split == "train" or len(cfg.max_examples) == 1 - else cfg.max_examples[1] - ) - if 0 < max_examples < len(self.active_split): - self.active_split = self.active_split.select(range(max_examples)) - - # Shard if needed - if world_size > 1: - self.active_split = self.active_split.shard(world_size, rank) - - def __getitem__(self, index: int) -> list[Prompt]: - """Get a list of prompts for a given predicate""" - # get self.num_variants unique prompts from the template pool - template_names = ( - self.rng.sample(list(self.prompter.templates), self.num_variants) - if self.num_variants < len(self.prompter.templates) - else list(self.prompter.templates) - ) - - example = self.active_split[index] - true_label = example[self.label_column] - new_label = self.rng.choice([0, 1]) if self.num_classes > 2 else None - - prompts = [] - for template_name in template_names: - template = self.prompter.templates[template_name] - - if self.num_shots > 0: - # Use stratified sampling to get `num_shots` examples from train set. - # If `num_shots` is not divisible by the number of classes, stochastic - # rounding is used to determine the number of examples per class. - example_counts = stochastic_round_constrained( - list(self.class_fracs * self.num_shots), self.rng - ) - examples = [] - - for count, stratum in zip(example_counts, self.fewshot_strata): - indices = self.rng.sample(range(len(stratum)), count) - - for idx in indices: - examples.append(apply_template(template, stratum[idx])) - - self.rng.shuffle(examples) - few_shot_prefix = "\n\n".join(examples) + "\n\n" - else: - few_shot_prefix = "" - - if self.num_classes > 2: - # remove all but the true answer and one random other answer - variant_template, variant_label = binarize( - template, true_label, assert_type(int, new_label), self.rng - ) - else: - variant_template, variant_label = template, true_label - - prompts.append( - Prompt( - template=variant_template, - example=example, - label=variant_label, - label_column=self.label_column, - prefix=few_shot_prefix, - ) - ) - return prompts - - def __iter__(self): - return (self[i] for i in range(len(self.active_split))) - - def __len__(self): - """Get the number of predicates in the dataset.""" - return len(self.active_split) - - -class InterleavedDatasets(TorchDataset): - def __init__( - self, - datasets: list[PromptDataset], - ): - """ - Interleave several (PromptDataset) datasets into a single dataset, - alternating between the datasets. - Only samples as many datapoints from each dataset as the smallest dataset. - Args: - datasets (`List[PromptDataset]`): - List of datasets to interleave. 
- """ - self.datasets = datasets - - if not datasets: - raise ValueError("Unable to interleave an empty list of datasets.") - - lengths = [len(dset) for dset in datasets] - self.min_dataset_length = min(lengths) - self.num_datasets = len(datasets) - - def __getitem__(self, index: int) -> list[Prompt]: - which_dataset = index % self.num_datasets - return self.datasets[which_dataset][int(index / self.num_datasets)] - - def __iter__(self): - return (self[i] for i in range(len(self))) - - def __len__(self): - """Get the number of predicates in the dataset.""" - return self.num_datasets * self.min_dataset_length diff --git a/elk/extraction/prompt_loading.py b/elk/extraction/prompt_loading.py new file mode 100644 index 00000000..130898ec --- /dev/null +++ b/elk/extraction/prompt_loading.py @@ -0,0 +1,209 @@ +from ..math_util import stochastic_round_constrained +from ..promptsource import DatasetTemplates +from ..utils import ( + assert_type, + binarize, + infer_label_column, + infer_num_classes, + select_train_val_splits, +) +from dataclasses import dataclass +from datasets import ( + interleave_datasets, + load_dataset, + ClassLabel, + Features, + IterableDataset, + Sequence, +) +from datasets.distributed import split_dataset_by_node +from random import Random +from simple_parsing.helpers import field, Serializable +from typing import Any, Literal, Optional + + +@dataclass +class PromptConfig(Serializable): + """ + Args: + dataset: Space-delimited name of the HuggingFace dataset to use, e.g. + `"super_glue boolq"` or `"imdb"`. + balance: Whether to force class balance in the dataset using undersampling. + data_dir: The directory to use for caching the dataset. Defaults to + `~/.cache/huggingface/datasets`. + label_column: The column containing the labels. By default, we infer this from + the datatypes of the columns in the dataset; if there is only one column + with a `ClassLabel` datatype, we use that. + num_classes: The number of classes in the dataset. By default, we infer this + from the datatypes of the columns in the dataset; if there is only one + column with a `ClassLabel` datatype, we use the number of classes in that + column. + max_examples: The maximum number of examples to use from the val dataset. + If a single number, use at most that many examples for each split. If a list + of length 2, use the first element for the train split and the second for + the val split. If empty, use all examples. Defaults to empty. + num_shots: The number of examples to use in few-shot prompts. If zero, prompts + are zero-shot. Defaults to 0. + seed: The seed to use for prompt randomization. Defaults to 42. + num_variants: The number of prompt templates to apply to each predicate upon + call to __getitem__. Use -1 to apply all available templates. Defaults to 1. 
+ """ + + datasets: list[str] = field(positional=True) + balance: bool = False + data_dir: Optional[str] = None + label_column: Optional[str] = None + num_classes: Optional[int] = None + max_examples: list[int] = field(default_factory=lambda: [750, 250]) + num_shots: int = 0 + num_variants: int = -1 + seed: int = 42 + + def __post_init__(self): + if len(self.max_examples) > 2: + raise ValueError( + "max_examples should be a list of length 0, 1, or 2," + f"but got {len(self.max_examples)}" + ) + + +def load_prompts( + *dataset_strings: str, + max_examples: int = 0, + seed: int = 42, + split_type: Literal["train", "val"] = "train", + rank: int = 0, + world_size: int = 1, +) -> IterableDataset: + """Load a dataset full of prompts generated from the specified datasets. + + Args: + dataset_strings: Space-delimited names of the HuggingFace datasets to use, + e.g. `"super_glue boolq"` or `"imdb"`. + max_examples: The maximum number of examples to use from the dataset. + seed: The seed to use for prompt randomization. + split_type: Whether to use the train or val split of the dataset. + rank: The rank of the current process. Defaults to 0. + world_size: The number of processes. Defaults to 1. + + Returns: + An iterable dataset of prompts. + """ + prompt_datasets = [] + prompters = [] + raw_datasets = [] + rng = Random(seed) + + # First load the datasets and prompters. We need to know the minimum number of + # templates for any dataset in order to make sure we don't run out of prompts. + for ds_string in dataset_strings: + ds_name, _, config_name = ds_string.partition(" ") + prompters.append(DatasetTemplates(ds_name, config_name)) + + ds_dict = assert_type( + dict, load_dataset(ds_name, config_name or None, streaming=True) + ) + train_name, val_name = select_train_val_splits(ds_dict) + split_name = val_name if split_type == "val" else train_name + raw_datasets.append(assert_type(IterableDataset, ds_dict[split_name])) + + num_variants = min(len(prompter.templates) for prompter in prompters) + for ds, prompter in zip(raw_datasets, prompters): + label_column = infer_label_column(ds.features) + num_classes = infer_num_classes(ds.features[label_column]) + assert num_classes == 2 + + # Remove everything but the label column + extra_cols = list(assert_type(Features, ds.features)) + extra_cols.remove(label_column) + + if label_column != "label": + ds = ds.rename_column(label_column, "label") + + # Canonicalize the name and dtype of the label column + ds = ds.map( + _convert_to_prompts, + fn_kwargs=dict( + label_column=label_column, + num_classes=num_classes, + num_variants=num_variants, + prompter=prompter, + rng=rng, + ), + remove_columns=extra_cols, + ).map( + # Add the builder and config name to the records directly to make + # sure we don't forget what dataset they came from. 
+ lambda _: dict( + builder_name=ds.info.builder_name, + config_name=ds.info.config_name, + ), + # Explicit typing makes interleave_datasets work a lot faster + features=Features( + { + label_column: ClassLabel(names=["neg", "pos"]), + "builder_name": "string", + "config_name": "string", + "prompts": Sequence( + Sequence( + {"answer": "string", "text": "string"}, + length=num_classes, + ), + length=num_variants, + ), + "template_names": Sequence("string"), + } + ), + ) + prompt_datasets.append(ds) + + master_ds = interleave_datasets(prompt_datasets) + if max_examples > 0: + master_ds = master_ds.take(max_examples) + if world_size > 1: + master_ds = split_dataset_by_node(master_ds, rank, world_size) + + return master_ds + + +def _convert_to_prompts( + example: dict[str, Any], + prompter: DatasetTemplates, + label_column: str, + num_classes: int, + num_variants: int, + rng: Random, +) -> dict[str, Any]: + """Prompt-generating function to pass to `IterableDataset.map`.""" + prompts = [] + templates = list(prompter.templates.values()) + if num_variants < len(templates): + templates = rng.sample(templates, num_variants) + + for template in templates: + choices = [] + + for answer_idx in range(num_classes): + fake_example = example.copy() + fake_example[label_column] = answer_idx + + q, a = template.apply(fake_example) + + # if the jinja template already adds whitespace, don't add more + sep = "" if not q or q[-1].isspace() or not a or a[0].isspace() else " " + text = f"{q}{sep}{a}" if a and not a.isspace() else q + choices.append( + dict( + # Strip whitespace from the answer to make it easier to + # compare with the model's output + answer=a.strip(), + text=text, + ) + ) + + prompts.append(choices) + + return dict( + prompts=prompts, + template_names=prompter.all_template_names, + ) diff --git a/elk/utils/__init__.py b/elk/utils/__init__.py index d49e7538..eb6dbc83 100644 --- a/elk/utils/__init__.py +++ b/elk/utils/__init__.py @@ -1,13 +1,10 @@ from .data_utils import ( binarize, - compute_class_balance, get_columns_all_equal, infer_label_column, infer_num_classes, - undersample, float32_to_int16, int16_to_float32, - apply_template, select_train_val_splits, ) from .gpu_utils import select_usable_devices diff --git a/elk/utils/data_utils.py b/elk/utils/data_utils.py index 05ad8a13..6531c82e 100644 --- a/elk/utils/data_utils.py +++ b/elk/utils/data_utils.py @@ -2,43 +2,17 @@ from ..promptsource.templates import Template from datasets import ( ClassLabel, - Dataset, DatasetDict, Features, Split, Value, - concatenate_datasets, ) from random import Random -from typing import Optional, Iterable, Any -import numpy as np +from typing import Iterable, Any import torch import copy -def compute_class_balance( - dataset: Dataset, - num_classes: int, - label_column: Optional[str] = None, -) -> np.ndarray: - """Compute the class balance of a `Dataset`.""" - - features = dataset.features - name = dataset.info.builder_name - if label_column is None: - label_column = infer_label_column(dataset.features) - elif label_column not in features: - raise ValueError(f"{name} has no column '{label_column}'") - - class_sizes = np.bincount(dataset[label_column], minlength=num_classes) - - if not np.all(class_sizes > 0): - missing = np.flatnonzero(class_sizes == 0).tolist() - raise ValueError(f"{name} has missing classes: {missing}") - - return class_sizes - - def get_columns_all_equal(dataset: DatasetDict) -> list[str]: """Get columns of a `DatasetDict`, asserting all splits have the same columns.""" pivot, *rest = 
dataset.column_names.values() @@ -104,33 +78,6 @@ def infer_num_classes(label_feature: Any) -> int: ) -def undersample( - dataset: Dataset, rng: Random, num_classes: int, label_column: Optional[str] = None -) -> Dataset: - """Undersample a `Dataset` to the smallest class size.""" - label_column = label_column or infer_label_column(dataset.features) - class_sizes = compute_class_balance(dataset, num_classes, label_column) - smallest_size = class_sizes.min() - - # First group the active split by class - strata = ( - dataset.filter(lambda ex: ex[label_column] == i) - for i in range(len(class_sizes)) - ) - # Then randomly sample `smallest_size` examples from each class and merge - strata = [ - stratum.select(rng.sample(range(len(stratum)), k=smallest_size)) - for stratum in strata - ] - dataset = assert_type(Dataset, concatenate_datasets(strata)) - - # Sanity check that we successfully balanced the classes - class_sizes = np.bincount(dataset[label_column], minlength=len(class_sizes)) - assert np.all(class_sizes == smallest_size) - - return dataset - - def float32_to_int16(x: torch.Tensor) -> torch.Tensor: """Converts float32 to float16, then reinterprets as int16.""" return x.type(torch.float16).view(torch.int16) @@ -141,15 +88,6 @@ def int16_to_float32(x: torch.Tensor) -> torch.Tensor: return x.view(torch.float16).type(torch.float32) -def apply_template(template: Template, example: dict) -> str: - """Concatenate question and answer if answer is not empty or whitespace.""" - q, a = template.apply(example) - - # if the jinja template already adds whitespace, don't add more - sep = "" if not q or q[-1].isspace() or not a or a[0].isspace() else " " - return f"{q}{sep}{a}" if a and not a.isspace() else q - - def binarize( template: Template, label: int, new_label: int, rng: Random ) -> tuple[Template, int]: diff --git a/pyproject.toml b/pyproject.toml index b93ea61d..e10aacfd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,8 +10,8 @@ requires-python = ">=3.9" keywords = ["nlp", "interpretability", "language-models", "explainable-ai"] license = {text = "MIT License"} dependencies = [ - # Added Dataset.from_generator() method - "datasets>=2.5.0", + # Added distributed.split_dataset_by_node for IterableDatasets + "datasets>=2.9.0", # Introduced numpy.typing module "numpy>=1.20.0", # This version is old, but it's needed for certain HF tokenizers to work. 
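The streaming prompt pipeline introduced in this patch can be exercised roughly as follows. This is a minimal illustrative sketch, not part of the patch itself: the dataset name, example limit, and print statements are arbitrary choices, and the calls simply mirror what the reworked extraction.py above does with load_prompts and BalancedSampler.

# Illustrative usage of the new streaming prompt loading (assumed setup, not from the diff)
from itertools import islice

from elk.extraction import BalancedSampler, load_prompts

# Stream prompts for a space-delimited dataset spec; limits here are just examples
prompt_ds = load_prompts("super_glue boolq", max_examples=32, split_type="val")

# BalancedSampler approximately balances the binary labels on the fly
for example in islice(BalancedSampler(prompt_ds), 4):
    print(example["label"], example["template_names"])
    print(example["prompts"][0])  # contrast pair(s) for the first prompt template
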
diff --git a/tests/test_prompt_dataset.py b/tests/test_prompt_dataset.py index e87158d8..91fb82d4 100644 --- a/tests/test_prompt_dataset.py +++ b/tests/test_prompt_dataset.py @@ -1,42 +1,46 @@ -from elk.extraction import ExtractionConfig, PromptDataset +from elk.extraction import load_prompts, ExtractionConfig from elk.promptsource.templates import DatasetTemplates import pytest -@pytest.mark.filterwarnings("ignore:Unable to find a decoding function") -def test_prompt_dataset_getitem_boolq(): - def test_prompt_dataset_getitem(cfg: ExtractionConfig, split: str): - prompt_ds = PromptDataset(cfg.prompts, rank=0, world_size=1, split=split) - ds_name, _, config_name = cfg.prompts.datasets[0].partition(" ") - - prompter = DatasetTemplates(ds_name, config_name or None) - assert len(prompt_ds) == cfg.prompts.max_examples[-1] - for i in range(len(prompt_ds)): - true_templates_ids = [ - template.id for template in prompter.templates.values() - ] - returned_prompts = prompt_ds[i] - returned_templates_ids = [prompt.template.id for prompt in returned_prompts] - - # check for using the right example - assert all( - [ - prompt_ds.active_split[i] == prompt.example - for prompt in returned_prompts - ] - ) - - # check for using the same templates - assert set(true_templates_ids) == set(returned_templates_ids) - # check for them being in the same order - assert true_templates_ids == returned_templates_ids - - # the case where the dataset has 2 classes - # this dataset is small - cfg = ExtractionConfig.load_yaml("tests/distilgpt2_copa_cfg.yaml") - test_prompt_dataset_getitem(cfg, "validation") - - # the case where the dataset has more than 2 classes - # TODO: I'm not sure if we want to force people to download the whole dataset - cfg = ExtractionConfig.load_yaml("tests/distilgpt2_dbpedia_cfg.yaml") - test_prompt_dataset_getitem(cfg, "test") +# @pytest.mark.filterwarnings("ignore:Unable to find a decoding function") +# def test_prompt_dataset_getitem_boolq(): +# def test_prompt_dataset_getitem(cfg: ExtractionConfig, split: str): +# prompt_ds = load_prompts( +# cfg.prompts.datasets[0], rank=0, world_size=1 +# ) +# ds_name, _, config_name = cfg.prompts.datasets[0].partition(" ") +# +# prompter = DatasetTemplates(ds_name, config_name or None) +# assert len(prompt_ds) == cfg.prompts.max_examples[-1] +# for i in range(len(prompt_ds)): +# true_templates_ids = [ +# template.id for template in prompter.templates.values() +# ] +# returned_prompts = prompt_ds[i] +# returned_templates_ids = [ +# prompt.template.id for prompt in returned_prompts +# ] +# +# # check for using the right example +# assert all( +# [ +# prompt_ds.active_split[i] == prompt.example +# for prompt in returned_prompts +# ] +# ) +# +# # check for using the same templates +# assert set(true_templates_ids) == set(returned_templates_ids) +# # check for them being in the same order +# assert true_templates_ids == returned_templates_ids +# +# # the case where the dataset has 2 classes +# # this dataset is small +# cfg = ExtractionConfig.load_yaml("tests/distilgpt2_copa_cfg.yaml") +# test_prompt_dataset_getitem(cfg, "validation") +# +# # the case where the dataset has more than 2 classes +# # TODO: I'm not sure if we want to force people to download the whole dataset +# cfg = ExtractionConfig.load_yaml("tests/distilgpt2_dbpedia_cfg.yaml") +# test_prompt_dataset_getitem(cfg, "test") diff --git a/tests/test_samplers.py b/tests/test_samplers.py new file mode 100644 index 00000000..30b86a49 --- /dev/null +++ b/tests/test_samplers.py @@ -0,0 +1,55 @@ +from 
collections import Counter +from datasets import load_dataset, IterableDataset +from elk.extraction import BalancedBatchSampler, BalancedSampler +from elk.utils import assert_type, infer_label_column +from itertools import islice +import numpy as np + + +def test_output_batches_are_balanced(): + # Load an example dataset for testing + dataset = assert_type( + IterableDataset, + load_dataset("super_glue", "boolq", split="train", streaming=True), + ) + label_col = infer_label_column(dataset.features) + + # Create the BalancedBatchSampler instance + batch_size = 32 + balanced_batch_sampler = BalancedBatchSampler(dataset, batch_size=batch_size) + + # Iterate through batches and check if they are balanced + for batch in balanced_batch_sampler: + counter = Counter(sample[label_col] for sample in batch) + + # Check if the output batch is balanced + label_0_count = counter[0] + label_1_count = counter[1] + assert ( + label_0_count == label_1_count + ), f"Batch is not balanced: {label_0_count}, {label_1_count}" + + +def test_output_is_roughly_balanced(): + # Load an example dataset for testing + dataset = assert_type( + IterableDataset, + load_dataset("super_glue", "boolq", split="train", streaming=True), + ) + + col = infer_label_column(dataset.features) + reservoir = BalancedSampler(dataset) + + # Count the number of samples for each label + counter = Counter() + for sample in islice(reservoir, 2000): + counter[sample[col]] += 1 + + # Check if the output is roughly balanced + label_0_count = counter[0] + label_1_count = counter[1] + imbalance = abs(label_0_count - label_1_count) / (label_0_count + label_1_count) + + # Set a tolerance threshold for the imbalance ratio (e.g., 1%) + tol = 0.01 + assert imbalance < tol, f"Imbalance ratio {imbalance} exceeded tolerance {tol}" From 03ba6e07e69344d4d63742235a5016dc73cf3969 Mon Sep 17 00:00:00 2001 From: Nora Belrose Date: Wed, 22 Mar 2023 17:06:35 +0000 Subject: [PATCH 09/37] Skip slow BalancedBatchSampler test --- tests/test_samplers.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_samplers.py b/tests/test_samplers.py index 30b86a49..c25a25a1 100644 --- a/tests/test_samplers.py +++ b/tests/test_samplers.py @@ -3,9 +3,10 @@ from elk.extraction import BalancedBatchSampler, BalancedSampler from elk.utils import assert_type, infer_label_column from itertools import islice -import numpy as np +import pytest +@pytest.mark.skip(reason="This test is too slow to run on every commit") def test_output_batches_are_balanced(): # Load an example dataset for testing dataset = assert_type( From 15ab351ba1a78a65f539a0a725f82a62de789f96 Mon Sep 17 00:00:00 2001 From: Nora Belrose Date: Wed, 22 Mar 2023 17:31:49 +0000 Subject: [PATCH 10/37] Slightly relax test_output_is_roughly_balanced --- tests/test_samplers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_samplers.py b/tests/test_samplers.py index c25a25a1..841b3307 100644 --- a/tests/test_samplers.py +++ b/tests/test_samplers.py @@ -43,7 +43,7 @@ def test_output_is_roughly_balanced(): # Count the number of samples for each label counter = Counter() - for sample in islice(reservoir, 2000): + for sample in islice(reservoir, 3000): counter[sample[col]] += 1 # Check if the output is roughly balanced From a80369e586bdf8517de320af5d05e5df9d80d2eb Mon Sep 17 00:00:00 2001 From: Nora Belrose Date: Wed, 22 Mar 2023 17:43:09 +0000 Subject: [PATCH 11/37] Make BalancedSampler deterministic --- elk/extraction/balanced_sampler.py | 17 +++++++++++------ 1 file changed, 
11 insertions(+), 6 deletions(-) diff --git a/elk/extraction/balanced_sampler.py b/elk/extraction/balanced_sampler.py index c4da69b5..8f9d1bc6 100644 --- a/elk/extraction/balanced_sampler.py +++ b/elk/extraction/balanced_sampler.py @@ -1,5 +1,6 @@ from ..utils import infer_label_column from collections import Counter +from dataclasses import dataclass, field, InitVar from datasets import IterableDataset from itertools import cycle from torch.utils.data import IterableDataset as TorchIterableDataset @@ -7,6 +8,7 @@ import numpy as np +@dataclass class BalancedSampler(TorchIterableDataset): """ Approximately balances a binary classification dataset in a streaming fashion. @@ -22,24 +24,27 @@ class BalancedSampler(TorchIterableDataset): divided between the two binary label values (0 and 1). Defaults to 1000. """ - def __init__(self, dataset: IterableDataset): - self.dataset = dataset - self.class_counts = np.zeros(2) + dataset: IterableDataset + label_counts: np.ndarray = field(default_factory=lambda: np.zeros(2)) + seed: int = 42 + + def __post_init__(self): + self.rng = np.random.default_rng(self.seed) def __iter__(self): for sample in self.dataset: label = sample["label"] # Update class counts - self.class_counts[label] += 1 - current_balance = self.class_counts / self.class_counts.sum() + self.label_counts[label] += 1 + current_balance = self.label_counts / self.label_counts.sum() # Check if the sample should be dropped majority_class = np.argmax(current_balance) if label == majority_class: # Solution of n * p * q / [n * (1 - p) + n * p * q] = 0.5 for q keep_prob = 1 / current_balance[majority_class] - 1 - if np.random.rand() < 1 - keep_prob: + if self.rng.uniform() < 1 - keep_prob: continue yield sample From d304ab37f2bc188fa6124740960688b45e7f959d Mon Sep 17 00:00:00 2001 From: Nora Belrose Date: Wed, 22 Mar 2023 18:17:06 +0000 Subject: [PATCH 12/37] InitVar --- elk/extraction/balanced_sampler.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/elk/extraction/balanced_sampler.py b/elk/extraction/balanced_sampler.py index 8f9d1bc6..0d1a7d42 100644 --- a/elk/extraction/balanced_sampler.py +++ b/elk/extraction/balanced_sampler.py @@ -26,10 +26,10 @@ class BalancedSampler(TorchIterableDataset): dataset: IterableDataset label_counts: np.ndarray = field(default_factory=lambda: np.zeros(2)) - seed: int = 42 + seed: InitVar[int] = 42 - def __post_init__(self): - self.rng = np.random.default_rng(self.seed) + def __post_init__(self, seed: int): + self.rng = np.random.default_rng(seed) def __iter__(self): for sample in self.dataset: From 761c82deda95d79cb4275d53f2a411b23861cb05 Mon Sep 17 00:00:00 2001 From: Nora Belrose Date: Wed, 22 Mar 2023 23:26:23 +0000 Subject: [PATCH 13/37] Support multi class again --- elk/extraction/prompt_loading.py | 11 ++++---- tests/dbpedia_prompts.yaml | 10 +++++++ tests/distilgpt2_copa_cfg.yaml | 16 ----------- tests/distilgpt2_dbpedia_cfg.yaml | 16 ----------- tests/super_glue_prompts.yaml | 11 ++++++++ tests/test_load_prompts.py | 39 ++++++++++++++++++++++++++ tests/test_prompt_dataset.py | 46 ------------------------------- 7 files changed, 65 insertions(+), 84 deletions(-) create mode 100644 tests/dbpedia_prompts.yaml delete mode 100644 tests/distilgpt2_copa_cfg.yaml delete mode 100644 tests/distilgpt2_dbpedia_cfg.yaml create mode 100644 tests/super_glue_prompts.yaml create mode 100644 tests/test_load_prompts.py delete mode 100644 tests/test_prompt_dataset.py diff --git a/elk/extraction/prompt_loading.py b/elk/extraction/prompt_loading.py 
index 130898ec..34ea79e4 100644 --- a/elk/extraction/prompt_loading.py +++ b/elk/extraction/prompt_loading.py @@ -34,10 +34,6 @@ class PromptConfig(Serializable): label_column: The column containing the labels. By default, we infer this from the datatypes of the columns in the dataset; if there is only one column with a `ClassLabel` datatype, we use that. - num_classes: The number of classes in the dataset. By default, we infer this - from the datatypes of the columns in the dataset; if there is only one - column with a `ClassLabel` datatype, we use the number of classes in that - column. max_examples: The maximum number of examples to use from the val dataset. If a single number, use at most that many examples for each split. If a list of length 2, use the first element for the train split and the second for @@ -53,7 +49,6 @@ class PromptConfig(Serializable): balance: bool = False data_dir: Optional[str] = None label_column: Optional[str] = None - num_classes: Optional[int] = None max_examples: list[int] = field(default_factory=lambda: [750, 250]) num_shots: int = 0 num_variants: int = -1 @@ -111,7 +106,6 @@ def load_prompts( for ds, prompter in zip(raw_datasets, prompters): label_column = infer_label_column(ds.features) num_classes = infer_num_classes(ds.features[label_column]) - assert num_classes == 2 # Remove everything but the label column extra_cols = list(assert_type(Features, ds.features)) @@ -183,6 +177,11 @@ def _convert_to_prompts( for template in templates: choices = [] + if num_classes > 2: + template, _ = binarize( + template, example[label_column], rng.choice([0, 1]), rng + ) + for answer_idx in range(num_classes): fake_example = example.copy() fake_example[label_column] = answer_idx diff --git a/tests/dbpedia_prompts.yaml b/tests/dbpedia_prompts.yaml new file mode 100644 index 00000000..76bc18b5 --- /dev/null +++ b/tests/dbpedia_prompts.yaml @@ -0,0 +1,10 @@ +balance: true +datasets: + - "dbpedia_14" +label_column: null +max_examples: +- 5 +- 5 +num_shots: 0 +num_variants: -1 +seed: 42 diff --git a/tests/distilgpt2_copa_cfg.yaml b/tests/distilgpt2_copa_cfg.yaml deleted file mode 100644 index cff7ec2f..00000000 --- a/tests/distilgpt2_copa_cfg.yaml +++ /dev/null @@ -1,16 +0,0 @@ -layers: [] -model: distilgpt2 -prompts: - balance: true - datasets: - - "super_glue copa" - label_column: null - max_examples: - - 2 - - 2 - num_classes: null - num_shots: 0 - num_variants: -1 - seed: 42 -token_loc: last -use_encoder_states: false diff --git a/tests/distilgpt2_dbpedia_cfg.yaml b/tests/distilgpt2_dbpedia_cfg.yaml deleted file mode 100644 index 629ab302..00000000 --- a/tests/distilgpt2_dbpedia_cfg.yaml +++ /dev/null @@ -1,16 +0,0 @@ -layers: [] -model: distilgpt2 -prompts: - balance: true - datasets: - - "dbpedia_14" - label_column: null - max_examples: - - 3 - - 5 - num_classes: null - num_shots: 0 - num_variants: -1 - seed: 42 -token_loc: last -use_encoder_states: false diff --git a/tests/super_glue_prompts.yaml b/tests/super_glue_prompts.yaml new file mode 100644 index 00000000..196267af --- /dev/null +++ b/tests/super_glue_prompts.yaml @@ -0,0 +1,11 @@ +balance: true +datasets: + - "super_glue boolq" + - "super_glue copa" +label_column: null +max_examples: +- 5 +- 5 +num_shots: 0 +num_variants: -1 +seed: 42 diff --git a/tests/test_load_prompts.py b/tests/test_load_prompts.py new file mode 100644 index 00000000..51132618 --- /dev/null +++ b/tests/test_load_prompts.py @@ -0,0 +1,39 @@ +from elk.extraction import load_prompts, PromptConfig +from elk.promptsource.templates import 
DatasetTemplates +from itertools import cycle +from typing import Literal +import pytest + + +@pytest.mark.filterwarnings("ignore:Unable to find a decoding function") +def test_load_prompts(): + def test_single_split(cfg: PromptConfig, split_type: Literal["train", "val"]): + prompt_ds = load_prompts( + *cfg.datasets, max_examples=cfg.max_examples[0], split_type=split_type + ) + prompters = [] + + for ds in cfg.datasets: + ds_name, _, config_name = ds.partition(" ") + prompter = DatasetTemplates(ds_name, config_name or None) + prompters.append(prompter) + + for prompter, record in zip(cycle(prompters), prompt_ds): + true_template_names = prompter.all_template_names + returned_template_names = record["template_names"] + + # check for using the same templates + assert set(true_template_names) == set(returned_template_names) + # check for them being in the same order + assert true_template_names == true_template_names + + # the case where the dataset has 2 classes + # this dataset is small + cfg = PromptConfig.load_yaml("tests/super_glue_prompts.yaml") + test_single_split(cfg, "train") + test_single_split(cfg, "val") + + # the case where the dataset has more than 2 classes + cfg = PromptConfig.load_yaml("tests/dbpedia_prompts.yaml") + test_single_split(cfg, "train") + test_single_split(cfg, "val") diff --git a/tests/test_prompt_dataset.py b/tests/test_prompt_dataset.py deleted file mode 100644 index 91fb82d4..00000000 --- a/tests/test_prompt_dataset.py +++ /dev/null @@ -1,46 +0,0 @@ -from elk.extraction import load_prompts, ExtractionConfig -from elk.promptsource.templates import DatasetTemplates -import pytest - - -# @pytest.mark.filterwarnings("ignore:Unable to find a decoding function") -# def test_prompt_dataset_getitem_boolq(): -# def test_prompt_dataset_getitem(cfg: ExtractionConfig, split: str): -# prompt_ds = load_prompts( -# cfg.prompts.datasets[0], rank=0, world_size=1 -# ) -# ds_name, _, config_name = cfg.prompts.datasets[0].partition(" ") -# -# prompter = DatasetTemplates(ds_name, config_name or None) -# assert len(prompt_ds) == cfg.prompts.max_examples[-1] -# for i in range(len(prompt_ds)): -# true_templates_ids = [ -# template.id for template in prompter.templates.values() -# ] -# returned_prompts = prompt_ds[i] -# returned_templates_ids = [ -# prompt.template.id for prompt in returned_prompts -# ] -# -# # check for using the right example -# assert all( -# [ -# prompt_ds.active_split[i] == prompt.example -# for prompt in returned_prompts -# ] -# ) -# -# # check for using the same templates -# assert set(true_templates_ids) == set(returned_templates_ids) -# # check for them being in the same order -# assert true_templates_ids == returned_templates_ids -# -# # the case where the dataset has 2 classes -# # this dataset is small -# cfg = ExtractionConfig.load_yaml("tests/distilgpt2_copa_cfg.yaml") -# test_prompt_dataset_getitem(cfg, "validation") -# -# # the case where the dataset has more than 2 classes -# # TODO: I'm not sure if we want to force people to download the whole dataset -# cfg = ExtractionConfig.load_yaml("tests/distilgpt2_dbpedia_cfg.yaml") -# test_prompt_dataset_getitem(cfg, "test") From f29743bf83c4f269db4377928e24881b0d129f9d Mon Sep 17 00:00:00 2001 From: Nora Belrose Date: Wed, 22 Mar 2023 23:31:47 +0000 Subject: [PATCH 14/37] Fix naming issue --- elk/training/train.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/elk/training/train.py b/elk/training/train.py index bcfc192d..42db21aa 100644 --- a/elk/training/train.py 
+++ b/elk/training/train.py @@ -70,14 +70,14 @@ def train_reporter( cfg: RunConfig, dataset: DatasetDict, out_dir: Path, - layer: list[int], + layers: list[int], devices: list[str], world_size: int = 1, ): """Train a single reporter on a single layer, or a list of layers.""" # Reproducibility - seed = cfg.net.seed + layer[0] + seed = cfg.net.seed + layers[0] np.random.seed(seed) random.seed(seed) torch.manual_seed(seed) @@ -96,17 +96,17 @@ def train_reporter( train_labels = cast(Tensor, train["label"]) val_labels = cast(Tensor, val["label"]) - # concatenate hidden states across layers if multiple layers are inputted + # Concatenate hidden states across layers if multiple layers are requested train_hiddens = torch.cat( - [cast(Tensor, train[f"hidden_{lay}"]) for lay in layer], dim=-1 + [assert_type(Tensor, train[f"hidden_{layer}"]) for layer in layers], dim=-1 ) val_hiddens = torch.cat( - [cast(Tensor, val[f"hidden_{lay}"]) for lay in layer], dim=-1 + [assert_type(Tensor, val[f"hidden_{lay}"]) for lay in layers], dim=-1 ) train_h, val_h = normalize( - int16_to_float32(assert_type(Tensor, train_hiddens)), - int16_to_float32(assert_type(Tensor, val_hiddens)), + int16_to_float32(train_hiddens), + int16_to_float32(val_hiddens), method=cfg.normalization, ) @@ -119,7 +119,7 @@ def train_reporter( ) if pseudo_auroc > 0.6: warnings.warn( - f"The pseudo-labels at layers {layer} are linearly separable with " + f"The pseudo-labels at layers {layers} are linearly separable with " f"an AUROC of {pseudo_auroc:.3f}. This may indicate that the " f"algorithm will not converge to a good solution." ) @@ -143,7 +143,7 @@ def train_reporter( lr_dir.mkdir(parents=True, exist_ok=True) reporter_dir.mkdir(parents=True, exist_ok=True) - layer_name = layer if isinstance(layer, int) else max(layer) + layer_name = max(layers) stats = [layer_name, pseudo_auroc, train_loss, *val_result] if not cfg.skip_baseline: @@ -170,10 +170,10 @@ def train_reporter( lr_auroc = roc_auc_score(val_labels_aug, lr_preds) stats += [lr_auroc, lr_acc] - with open(lr_dir / f"layer_{layer}.pt", "wb") as file: + with open(lr_dir / f"layer_{layers}.pt", "wb") as file: pickle.dump(lr_model, file) - with open(reporter_dir / f"layer_{layer}.pt", "wb") as file: + with open(reporter_dir / f"layer_{layers}.pt", "wb") as file: torch.save(reporter, file) return stats From b7b7e23f858c1bb663698b46d3968a54bce21305 Mon Sep 17 00:00:00 2001 From: Nora Belrose Date: Thu, 23 Mar 2023 01:00:59 +0000 Subject: [PATCH 15/37] Support few shot prompts --- elk/extraction/__init__.py | 2 +- elk/extraction/balanced_sampler.py | 51 ++++++++++++++++++------------ elk/extraction/prompt_loading.py | 40 +++++++++++++++++++---- tests/test_samplers.py | 28 ++++++++-------- 4 files changed, 80 insertions(+), 41 deletions(-) diff --git a/elk/extraction/__init__.py b/elk/extraction/__init__.py index 1d22e800..e3419679 100644 --- a/elk/extraction/__init__.py +++ b/elk/extraction/__init__.py @@ -1,4 +1,4 @@ -from .balanced_sampler import BalancedBatchSampler, BalancedSampler +from .balanced_sampler import BalancedSampler, FewShotSampler from .extraction import ExtractionConfig, extract_hiddens, extract from .generator import _GeneratorConfig, _GeneratorBuilder from .prompt_loading import PromptConfig, load_prompts diff --git a/elk/extraction/balanced_sampler.py b/elk/extraction/balanced_sampler.py index 0d1a7d42..6dc28b97 100644 --- a/elk/extraction/balanced_sampler.py +++ b/elk/extraction/balanced_sampler.py @@ -1,8 +1,9 @@ from ..utils import infer_label_column -from 
collections import Counter +from ..math_util import stochastic_round_constrained from dataclasses import dataclass, field, InitVar from datasets import IterableDataset from itertools import cycle +from random import Random from torch.utils.data import IterableDataset as TorchIterableDataset from typing import Iterator, Optional import numpy as np @@ -50,39 +51,49 @@ def __iter__(self): yield sample -class BalancedBatchSampler: - """Yields precisely balanced batches from a binary classification dataset. +class FewShotSampler: + """Yields batches of few-shot examples that are as balanced as possible. - Written by a human being because GPT-4 couldn't figure out how to do it. + If the number of examples is divisible by the number of shots, this sampler + will yield batches of exactly `num_shots` examples. Otherwise, it will + use `stochastic_round_constrained` to get as close to balanced batches as + possible. """ def __init__( self, dataset: IterableDataset, + num_shots: int, + rng: Random, label_col: Optional[str] = None, - batch_size: int = 32, ): - self.batch_size = batch_size self.dataset = dataset self.label_col = label_col or infer_label_column(dataset.features) + self.num_shots = num_shots + self.rng = rng def __iter__(self) -> Iterator[list[dict]]: - batch = [] + neg_buf, pos_buf = [], [] - max_count = self.batch_size // 2 - label_counts = Counter() - - # Infinite loop! + # Infinite loop over the dataset! for sample in cycle(self.dataset): label = sample[self.label_col] - if label_counts[label] >= max_count: - continue - - batch.append(sample) - label_counts[label] += 1 + if label == 0: + neg_buf.append(sample) + elif label == 1: + pos_buf.append(sample) + else: + raise ValueError(f"Expected label to be 0 or 1, got {label}") + + neg_count, pos_count = stochastic_round_constrained( + [self.num_shots / 2, self.num_shots / 2], self.rng + ) + while len(neg_buf) >= neg_count and len(pos_buf) >= pos_count: + batch = [] + for _ in range(neg_count): + batch.append(neg_buf.pop()) + for _ in range(pos_count): + batch.append(pos_buf.pop()) - if len(batch) == self.batch_size: + self.rng.shuffle(batch) yield batch - - batch = [] - label_counts.clear() diff --git a/elk/extraction/prompt_loading.py b/elk/extraction/prompt_loading.py index 34ea79e4..38d21212 100644 --- a/elk/extraction/prompt_loading.py +++ b/elk/extraction/prompt_loading.py @@ -1,4 +1,3 @@ -from ..math_util import stochastic_round_constrained from ..promptsource import DatasetTemplates from ..utils import ( assert_type, @@ -7,6 +6,7 @@ infer_num_classes, select_train_val_splits, ) +from .balanced_sampler import FewShotSampler from dataclasses import dataclass from datasets import ( interleave_datasets, @@ -19,7 +19,7 @@ from datasets.distributed import split_dataset_by_node from random import Random from simple_parsing.helpers import field, Serializable -from typing import Any, Literal, Optional +from typing import Any, Iterator, Literal, Optional @dataclass @@ -65,6 +65,7 @@ def __post_init__(self): def load_prompts( *dataset_strings: str, max_examples: int = 0, + num_shots: int = 0, seed: int = 42, split_type: Literal["train", "val"] = "train", rank: int = 0, @@ -76,6 +77,8 @@ def load_prompts( dataset_strings: Space-delimited names of the HuggingFace datasets to use, e.g. `"super_glue boolq"` or `"imdb"`. max_examples: The maximum number of examples to use from the dataset. + num_shots: The number of examples to use in few-shot prompts. If zero, prompts + are zero-shot. seed: The seed to use for prompt randomization. 
split_type: Whether to use the train or val split of the dataset. rank: The rank of the current process. Defaults to 0. @@ -87,6 +90,7 @@ def load_prompts( prompt_datasets = [] prompters = [] raw_datasets = [] + train_datasets = [] rng = Random(seed) # First load the datasets and prompters. We need to know the minimum number of @@ -101,9 +105,10 @@ def load_prompts( train_name, val_name = select_train_val_splits(ds_dict) split_name = val_name if split_type == "val" else train_name raw_datasets.append(assert_type(IterableDataset, ds_dict[split_name])) + train_datasets.append(assert_type(IterableDataset, ds_dict[train_name])) num_variants = min(len(prompter.templates) for prompter in prompters) - for ds, prompter in zip(raw_datasets, prompters): + for ds, train_ds, prompter in zip(raw_datasets, train_datasets, prompters): label_column = infer_label_column(ds.features) num_classes = infer_num_classes(ds.features[label_column]) @@ -113,6 +118,15 @@ def load_prompts( if label_column != "label": ds = ds.rename_column(label_column, "label") + if num_shots > 0: + fewshot = FewShotSampler( + train_ds, + num_shots=num_shots, + rng=rng, + ) + fewshot_iter = iter(fewshot) + else: + fewshot_iter = None # Canonicalize the name and dtype of the label column ds = ds.map( @@ -123,6 +137,7 @@ def load_prompts( num_variants=num_variants, prompter=prompter, rng=rng, + fewshot_iter=fewshot_iter, ), remove_columns=extra_cols, ).map( @@ -155,6 +170,7 @@ def load_prompts( if max_examples > 0: master_ds = master_ds.take(max_examples) if world_size > 1: + # This prints to stdout which is slightly annoying master_ds = split_dataset_by_node(master_ds, rank, world_size) return master_ds @@ -167,6 +183,7 @@ def _convert_to_prompts( num_classes: int, num_variants: int, rng: Random, + fewshot_iter: Optional[Iterator[list[dict]]] = None, ) -> dict[str, Any]: """Prompt-generating function to pass to `IterableDataset.map`.""" prompts = [] @@ -174,6 +191,11 @@ def _convert_to_prompts( if num_variants < len(templates): templates = rng.sample(templates, num_variants) + def qa_cat(q: str, a: str) -> str: + # if the jinja template already adds whitespace, don't add more + sep = "" if not q or q[-1].isspace() or not a or a[0].isspace() else " " + return f"{q}{sep}{a}" if a and not a.isspace() else q + for template in templates: choices = [] @@ -187,10 +209,16 @@ def _convert_to_prompts( fake_example[label_column] = answer_idx q, a = template.apply(fake_example) + text = qa_cat(q, a) + + if fewshot_iter is not None: + # Infinite iterator so we don't need to worry about StopIteration + fewshot_examples = next(fewshot_iter) + fewshot_texts = [ + qa_cat(q, a) for q, a in map(template.apply, fewshot_examples) + ] + text = "\n\n".join(fewshot_texts) + "\n\n" + text - # if the jinja template already adds whitespace, don't add more - sep = "" if not q or q[-1].isspace() or not a or a[0].isspace() else " " - text = f"{q}{sep}{a}" if a and not a.isspace() else q choices.append( dict( # Strip whitespace from the answer to make it easier to diff --git a/tests/test_samplers.py b/tests/test_samplers.py index 841b3307..cb5e1225 100644 --- a/tests/test_samplers.py +++ b/tests/test_samplers.py @@ -1,12 +1,11 @@ from collections import Counter from datasets import load_dataset, IterableDataset -from elk.extraction import BalancedBatchSampler, BalancedSampler +from elk.extraction import FewShotSampler, BalancedSampler from elk.utils import assert_type, infer_label_column from itertools import islice -import pytest +from random import Random 
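To make the few-shot construction above concrete: each example drawn from fewshot_iter is rendered with the same template as the query, the rendered shots are joined with blank lines, and the final line is the query itself completed with one of the two candidate answers for that half of the contrast pair. FewShotSampler splits the shots as evenly as possible between the classes, e.g. 3 shots become 2/1 or 1/2 via stochastic_round_constrained. A sketch with invented review text, assuming an IMDB-style template:

few_shot_text = (
    "A waste of two hours. Was this review positive or negative? negative\n\n"
    "A beautiful, moving film. Was this review positive or negative? positive\n\n"
    # The query itself, completed with one of the two candidate answers:
    "I loved every minute of it. Was this review positive or negative? positive"
)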
-@pytest.mark.skip(reason="This test is too slow to run on every commit") def test_output_batches_are_balanced(): # Load an example dataset for testing dataset = assert_type( @@ -15,20 +14,21 @@ def test_output_batches_are_balanced(): ) label_col = infer_label_column(dataset.features) - # Create the BalancedBatchSampler instance - batch_size = 32 - balanced_batch_sampler = BalancedBatchSampler(dataset, batch_size=batch_size) - - # Iterate through batches and check if they are balanced - for batch in balanced_batch_sampler: + # Start with an even number of shots; make sure they're exactly balanced + sampler = FewShotSampler(dataset, 6, rng=Random(42)) + for batch in islice(sampler, 5): counter = Counter(sample[label_col] for sample in batch) # Check if the output batch is balanced - label_0_count = counter[0] - label_1_count = counter[1] - assert ( - label_0_count == label_1_count - ), f"Batch is not balanced: {label_0_count}, {label_1_count}" + assert counter[0] == counter[1] + + # Start with an odd number of shots; make sure they're roughly balanced + sampler = FewShotSampler(dataset, 5, rng=Random(42)) + for batch in islice(sampler, 5): + counter = Counter(sample[label_col] for sample in batch) + + # The batch should be balanced to within 1 sample + assert abs(counter[0] - counter[1]) <= 1 def test_output_is_roughly_balanced(): From 225d4c79ba4527155ebc1b2b87dda7fd99204ae6 Mon Sep 17 00:00:00 2001 From: Alex Mallen Date: Thu, 23 Mar 2023 02:26:25 +0000 Subject: [PATCH 16/37] fix multiclass labels --- elk/extraction/prompt_loading.py | 13 ++++++++----- elk/utils/data_utils.py | 6 ++---- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/elk/extraction/prompt_loading.py b/elk/extraction/prompt_loading.py index 34ea79e4..a6954225 100644 --- a/elk/extraction/prompt_loading.py +++ b/elk/extraction/prompt_loading.py @@ -107,7 +107,7 @@ def load_prompts( label_column = infer_label_column(ds.features) num_classes = infer_num_classes(ds.features[label_column]) - # Remove everything but the label column + # Remove everything except the label column extra_cols = list(assert_type(Features, ds.features)) extra_cols.remove(label_column) @@ -141,7 +141,7 @@ def load_prompts( "prompts": Sequence( Sequence( {"answer": "string", "text": "string"}, - length=num_classes, + length=2, # contrast pair ), length=num_variants, ), @@ -174,15 +174,17 @@ def _convert_to_prompts( if num_variants < len(templates): templates = rng.sample(templates, num_variants) + new_label = rng.choice([0, 1]) if num_classes > 2 else example[label_column] + for template in templates: choices = [] if num_classes > 2: - template, _ = binarize( - template, example[label_column], rng.choice([0, 1]), rng + template = binarize( + template, example[label_column], assert_type(int, new_label), rng ) - for answer_idx in range(num_classes): + for answer_idx in range(2): fake_example = example.copy() fake_example[label_column] = answer_idx @@ -203,6 +205,7 @@ def _convert_to_prompts( prompts.append(choices) return dict( + label=new_label, prompts=prompts, template_names=prompter.all_template_names, ) diff --git a/elk/utils/data_utils.py b/elk/utils/data_utils.py index 6531c82e..97e24c11 100644 --- a/elk/utils/data_utils.py +++ b/elk/utils/data_utils.py @@ -88,9 +88,7 @@ def int16_to_float32(x: torch.Tensor) -> torch.Tensor: return x.view(torch.float16).type(torch.float32) -def binarize( - template: Template, label: int, new_label: int, rng: Random -) -> tuple[Template, int]: +def binarize(template: Template, label: int, new_label: 
int, rng: Random) -> Template: """Binarize a template with >2 answer choices, returning a new template and label. Returns: @@ -117,4 +115,4 @@ def binarize( f"{false} ||| {true}" if new_label else f"{true} ||| {false}" ) - return new_template, new_label + return new_template From b1b95e5602d92315e7e1e772ed39164a93d5bfcc Mon Sep 17 00:00:00 2001 From: Nora Belrose Date: Sat, 25 Mar 2023 01:16:48 +0000 Subject: [PATCH 17/37] Fix dumb part of test failures --- tests/test_smoke_elicit.py | 21 ++------------------- 1 file changed, 2 insertions(+), 19 deletions(-) diff --git a/tests/test_smoke_elicit.py b/tests/test_smoke_elicit.py index 3211271c..cdce5e0f 100644 --- a/tests/test_smoke_elicit.py +++ b/tests/test_smoke_elicit.py @@ -5,12 +5,6 @@ from elk.training import CcsReporterConfig, EigenReporterConfig from elk.training.train import train, RunConfig -""" -TODO: These tests should work with deberta -but you'll need to make deberta fp32 instead of fp16 -because pytorch cpu doesn't support fp16 -""" - def test_smoke_elicit_run_tiny_gpt2_ccs(tmp_path: Path): model_path = "sshleifer/tiny-gpt2" @@ -18,7 +12,7 @@ def test_smoke_elicit_run_tiny_gpt2_ccs(tmp_path: Path): config = RunConfig( data=ExtractionConfig( model=model_path, - prompts=PromptConfig(dataset=dataset_name, max_examples=[10]), + prompts=PromptConfig(datasets=[dataset_name], max_examples=[10]), # run on all layers, tiny-gpt only has 2 layers ), net=CcsReporterConfig(), @@ -33,23 +27,12 @@ def test_smoke_elicit_run_tiny_gpt2_ccs(tmp_path: Path): def test_smoke_elicit_run_tiny_gpt2_eigen(tmp_path: Path): - """ - Currently this test fails with - u -= torch.einsum("...ij,...i->...j", V[..., :k, :], proj) - V[..., k, :] = F.normalize(u, dim=-1) - ~~~~~~~~~ <--- HERE - - u[:] = torch.einsum("...ij,...j->...i", A, V[..., k, :]) - - RuntimeError: select(): index 1 out of range for tensor of size [1, 2] - at dimension 0 - """ model_path = "sshleifer/tiny-gpt2" dataset_name = "imdb" config = RunConfig( data=ExtractionConfig( model=model_path, - prompts=PromptConfig(dataset=dataset_name, max_examples=[10]), + prompts=PromptConfig(datasets=[dataset_name], max_examples=[10]), # run on all layers, tiny-gpt only has 2 layers ), net=EigenReporterConfig(), From ee3911e262a757fcbed688d35716b0fc082ea6bf Mon Sep 17 00:00:00 2001 From: Nora Belrose Date: Sat, 25 Mar 2023 02:30:19 +0000 Subject: [PATCH 18/37] Fix assert_allclose warning --- tests/test_eigsh.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_eigsh.py b/tests/test_eigsh.py index dc206d90..b208dc23 100644 --- a/tests/test_eigsh.py +++ b/tests/test_eigsh.py @@ -20,7 +20,7 @@ def test_lanczos_eigsh(n, which): w_scipy, v_scipy = eigsh(A.numpy(), which=which) # Check that the eigenvalues match to within the tolerance - torch.testing.assert_allclose(w, torch.from_numpy(w_scipy), atol=1e-3, rtol=1e-3) + torch.testing.assert_close(w, torch.from_numpy(w_scipy), atol=1e-3, rtol=1e-3) # Normalize the sign of the eigenvectors for i in range(v.shape[-1]): @@ -30,4 +30,4 @@ def test_lanczos_eigsh(n, which): v_scipy[:, i] *= -1 # Check that the eigenvectors match to within the tolerance - torch.testing.assert_allclose(v, torch.from_numpy(v_scipy), atol=1e-3, rtol=1e-3) + torch.testing.assert_close(v, torch.from_numpy(v_scipy), atol=1e-3, rtol=1e-3) From a55b3de6e91793e7eb3361a7c810b67f6b23aed7 Mon Sep 17 00:00:00 2001 From: Nora Belrose Date: Sat, 25 Mar 2023 02:31:55 +0000 Subject: [PATCH 19/37] Switch to torch.testing.assert_close in EigenReporter test --- 
tests/test_eigen_reporter.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/test_eigen_reporter.py b/tests/test_eigen_reporter.py index 4564afc7..9d548f67 100644 --- a/tests/test_eigen_reporter.py +++ b/tests/test_eigen_reporter.py @@ -19,21 +19,21 @@ def test_eigen_reporter(): # Check that the streaming mean is correct pos_mu, neg_mu = x_pos.mean(dim=(0, 1)), x_neg.mean(dim=(0, 1)) - assert torch.allclose(reporter.pos_mean, pos_mu) - assert torch.allclose(reporter.neg_mean, neg_mu) + torch.testing.assert_close(reporter.pos_mean, pos_mu) + torch.testing.assert_close(reporter.neg_mean, neg_mu) # Check that the streaming covariance is correct pos_centroids, neg_centroids = x_pos.mean(dim=1), x_neg.mean(dim=1) expected_var = batch_cov(pos_centroids) + batch_cov(neg_centroids) - assert torch.allclose(reporter.intercluster_cov, expected_var) + torch.testing.assert_close(reporter.intercluster_cov, expected_var) # Check that the streaming invariance (intra-cluster variance) is correct expected_invariance = cov_mean_fused(x_pos) + cov_mean_fused(x_neg) - assert torch.allclose(reporter.intracluster_cov, expected_invariance) + torch.testing.assert_close(reporter.intracluster_cov, expected_invariance) # Check that the streaming negative covariance is correct cross_cov = (pos_centroids - pos_mu).mT @ (neg_centroids - neg_mu) / num_clusters cross_cov = cross_cov + cross_cov.mT - assert torch.allclose(reporter.contrastive_xcov, cross_cov) + torch.testing.assert_close(reporter.contrastive_xcov, cross_cov) assert reporter.n == num_clusters From 44dc25c75f21bc1bc614d73d4c53cfb06e9d679b Mon Sep 17 00:00:00 2001 From: Nora Belrose Date: Sat, 25 Mar 2023 02:37:16 +0000 Subject: [PATCH 20/37] Shuffle load_prompts output by default --- elk/extraction/prompt_loading.py | 3 +++ tests/test_load_prompts.py | 5 ++++- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/elk/extraction/prompt_loading.py b/elk/extraction/prompt_loading.py index 870fcee2..82f549dc 100644 --- a/elk/extraction/prompt_loading.py +++ b/elk/extraction/prompt_loading.py @@ -67,6 +67,7 @@ def load_prompts( max_examples: int = 0, num_shots: int = 0, seed: int = 42, + shuffle: bool = True, split_type: Literal["train", "val"] = "train", rank: int = 0, world_size: int = 1, @@ -172,6 +173,8 @@ def load_prompts( if world_size > 1: # This prints to stdout which is slightly annoying master_ds = split_dataset_by_node(master_ds, rank, world_size) + if shuffle: + master_ds = master_ds.shuffle(seed=seed) return master_ds diff --git a/tests/test_load_prompts.py b/tests/test_load_prompts.py index 51132618..678052e3 100644 --- a/tests/test_load_prompts.py +++ b/tests/test_load_prompts.py @@ -9,7 +9,10 @@ def test_load_prompts(): def test_single_split(cfg: PromptConfig, split_type: Literal["train", "val"]): prompt_ds = load_prompts( - *cfg.datasets, max_examples=cfg.max_examples[0], split_type=split_type + *cfg.datasets, + max_examples=cfg.max_examples[0], + shuffle=False, + split_type=split_type, ) prompters = [] From 93d8d87c3eaa7317b7ce1bcfc06a31d51b3be156 Mon Sep 17 00:00:00 2001 From: Nora Belrose Date: Sat, 25 Mar 2023 06:48:24 +0000 Subject: [PATCH 21/37] Fix smoke test failure --- elk/extraction/balanced_sampler.py | 35 +++++++++++++----------------- elk/extraction/extraction.py | 3 +++ elk/extraction/generator.py | 3 +++ elk/extraction/prompt_loading.py | 31 ++++++++++++++++++++++---- 4 files changed, 48 insertions(+), 24 deletions(-) diff --git a/elk/extraction/balanced_sampler.py 
b/elk/extraction/balanced_sampler.py index 6dc28b97..cbb2bd97 100644 --- a/elk/extraction/balanced_sampler.py +++ b/elk/extraction/balanced_sampler.py @@ -1,12 +1,12 @@ -from ..utils import infer_label_column from ..math_util import stochastic_round_constrained -from dataclasses import dataclass, field, InitVar +from ..utils import infer_label_column +from collections import deque +from dataclasses import dataclass from datasets import IterableDataset from itertools import cycle from random import Random from torch.utils.data import IterableDataset as TorchIterableDataset from typing import Iterator, Optional -import numpy as np @dataclass @@ -25,30 +25,25 @@ class BalancedSampler(TorchIterableDataset): divided between the two binary label values (0 and 1). Defaults to 1000. """ - dataset: IterableDataset - label_counts: np.ndarray = field(default_factory=lambda: np.zeros(2)) - seed: InitVar[int] = 42 + def __init__(self, dataset: IterableDataset, buffer_size: int = 1000): + self.dataset = dataset - def __post_init__(self, seed: int): - self.rng = np.random.default_rng(seed) + self.neg_buffer = deque(maxlen=buffer_size) + self.pos_buffer = deque(maxlen=buffer_size) def __iter__(self): for sample in self.dataset: label = sample["label"] - # Update class counts - self.label_counts[label] += 1 - current_balance = self.label_counts / self.label_counts.sum() - - # Check if the sample should be dropped - majority_class = np.argmax(current_balance) - if label == majority_class: - # Solution of n * p * q / [n * (1 - p) + n * p * q] = 0.5 for q - keep_prob = 1 / current_balance[majority_class] - 1 - if self.rng.uniform() < 1 - keep_prob: - continue + # Add the sample to the appropriate buffer + if label == 0: + self.neg_buffer.append(sample) + else: + self.pos_buffer.append(sample) - yield sample + while self.neg_buffer and self.pos_buffer: + yield self.neg_buffer.popleft() + yield self.pos_buffer.popleft() class FewShotSampler: diff --git a/elk/extraction/extraction.py b/elk/extraction/extraction.py index 233e07f2..f0e186dc 100644 --- a/elk/extraction/extraction.py +++ b/elk/extraction/extraction.py @@ -96,6 +96,7 @@ def extract_hiddens( *cfg.prompts.datasets, max_examples=limits[0 if split_type == "train" else 1], split_type=split_type, + stream=cfg.prompts.stream, rank=rank, world_size=world_size, ) @@ -128,7 +129,9 @@ def extract_hiddens( layer_indices = cfg.layers or tuple(range(model.config.num_hidden_layers)) # print(f"Using {prompt_ds} variants for each dataset") + print("wowza") for example in BalancedSampler(prompt_ds): + print("holy crap") hidden_dict = { f"hidden_{layer_idx}": torch.empty( num_variants, diff --git a/elk/extraction/generator.py b/elk/extraction/generator.py index a581ea9c..1981bdd0 100644 --- a/elk/extraction/generator.py +++ b/elk/extraction/generator.py @@ -53,5 +53,8 @@ def _split_generators(self, dl_manager): def _generate_examples(self, **gen_kwargs): assert self.config.generator is not None, "generator must be specified" + + print("wow") for idx, ex in enumerate(self.config.generator(**gen_kwargs)): + print(f"iter {idx}") yield idx, ex diff --git a/elk/extraction/prompt_loading.py b/elk/extraction/prompt_loading.py index 82f549dc..50ef9198 100644 --- a/elk/extraction/prompt_loading.py +++ b/elk/extraction/prompt_loading.py @@ -12,6 +12,7 @@ interleave_datasets, load_dataset, ClassLabel, + Dataset, Features, IterableDataset, Sequence, @@ -40,9 +41,10 @@ class PromptConfig(Serializable): the val split. If empty, use all examples. Defaults to empty. 
num_shots: The number of examples to use in few-shot prompts. If zero, prompts are zero-shot. Defaults to 0. - seed: The seed to use for prompt randomization. Defaults to 42. num_variants: The number of prompt templates to apply to each predicate upon call to __getitem__. Use -1 to apply all available templates. Defaults to 1. + seed: The seed to use for prompt randomization. Defaults to 42. + stream: Whether to stream the dataset from the Internet. Defaults to False. """ datasets: list[str] = field(positional=True) @@ -53,6 +55,7 @@ class PromptConfig(Serializable): num_shots: int = 0 num_variants: int = -1 seed: int = 42 + stream: bool = False def __post_init__(self): if len(self.max_examples) > 2: @@ -69,6 +72,7 @@ def load_prompts( seed: int = 42, shuffle: bool = True, split_type: Literal["train", "val"] = "train", + stream: bool = False, rank: int = 0, world_size: int = 1, ) -> IterableDataset: @@ -82,6 +86,7 @@ def load_prompts( are zero-shot. seed: The seed to use for prompt randomization. split_type: Whether to use the train or val split of the dataset. + stream: Whether to stream the dataset from the Internet. Defaults to False. rank: The rank of the current process. Defaults to 0. world_size: The number of processes. Defaults to 1. @@ -101,14 +106,26 @@ def load_prompts( prompters.append(DatasetTemplates(ds_name, config_name)) ds_dict = assert_type( - dict, load_dataset(ds_name, config_name or None, streaming=True) + dict, load_dataset(ds_name, config_name or None, streaming=stream) ) train_name, val_name = select_train_val_splits(ds_dict) split_name = val_name if split_type == "val" else train_name - raw_datasets.append(assert_type(IterableDataset, ds_dict[split_name])) - train_datasets.append(assert_type(IterableDataset, ds_dict[train_name])) + + # If we're not streaming, take the opportunity to shuffle the dataset. + if not stream: + ds = assert_type(Dataset, ds_dict[split_name].shuffle(seed=seed)) + train_ds = assert_type(Dataset, ds_dict[train_name].shuffle(seed=seed)) + split = ds.to_iterable_dataset().cast(ds.features) + else: + train_ds = assert_type(IterableDataset, ds_dict[train_name]) + split = assert_type(IterableDataset, ds_dict[split_name]) + + raw_datasets.append(split) + train_datasets.append(train_ds) num_variants = min(len(prompter.templates) for prompter in prompters) + assert num_variants > 0 + for ds, train_ds, prompter in zip(raw_datasets, train_datasets, prompters): label_column = infer_label_column(ds.features) num_classes = infer_num_classes(ds.features[label_column]) @@ -176,6 +193,11 @@ def load_prompts( if shuffle: master_ds = master_ds.shuffle(seed=seed) + # Try to approximately shuffle the dataset if we're streaming. Note that this is + # NOT an adequate shuffle for datasets like IMDB, which are sorted by label. 
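For context on the caveat above: a streaming shuffle only permutes examples inside a fixed-size buffer (1000 by default in the datasets library), so a stream that is sorted by label and much longer than the buffer still comes out nearly sorted. Illustrative call with the buffer made explicit:

# Approximate: holds at most `buffer_size` examples in memory and samples from them.
master_ds = master_ds.shuffle(seed=seed, buffer_size=1000)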
+ if stream: + master_ds = master_ds.shuffle(seed=seed) + return master_ds @@ -189,6 +211,7 @@ def _convert_to_prompts( fewshot_iter: Optional[Iterator[list[dict]]] = None, ) -> dict[str, Any]: """Prompt-generating function to pass to `IterableDataset.map`.""" + print(f"label: {example[label_column]}") prompts = [] templates = list(prompter.templates.values()) if num_variants < len(templates): From fad4d74957c391b9611dd78efd82e76446ff7e9c Mon Sep 17 00:00:00 2001 From: Alex Mallen <35092692+AlexTMallen@users.noreply.github.com> Date: Sat, 25 Mar 2023 09:41:54 -0700 Subject: [PATCH 22/37] Remove debug prints --- elk/extraction/generator.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/elk/extraction/generator.py b/elk/extraction/generator.py index 1981bdd0..fbf10848 100644 --- a/elk/extraction/generator.py +++ b/elk/extraction/generator.py @@ -54,7 +54,5 @@ def _split_generators(self, dl_manager): def _generate_examples(self, **gen_kwargs): assert self.config.generator is not None, "generator must be specified" - print("wow") for idx, ex in enumerate(self.config.generator(**gen_kwargs)): - print(f"iter {idx}") yield idx, ex From 0a054f4cf0a6bd68c1302d4dfe2e916fb21e3d6b Mon Sep 17 00:00:00 2001 From: Alex Mallen <35092692+AlexTMallen@users.noreply.github.com> Date: Sat, 25 Mar 2023 09:44:23 -0700 Subject: [PATCH 23/37] Remove more debug print statements --- elk/extraction/extraction.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/elk/extraction/extraction.py b/elk/extraction/extraction.py index f0e186dc..73063547 100644 --- a/elk/extraction/extraction.py +++ b/elk/extraction/extraction.py @@ -129,9 +129,7 @@ def extract_hiddens( layer_indices = cfg.layers or tuple(range(model.config.num_hidden_layers)) # print(f"Using {prompt_ds} variants for each dataset") - print("wowza") for example in BalancedSampler(prompt_ds): - print("holy crap") hidden_dict = { f"hidden_{layer_idx}": torch.empty( num_variants, From 177eec2e02f11366f70b9c9213a267bec170eb59 Mon Sep 17 00:00:00 2001 From: Alex Mallen Date: Sun, 26 Mar 2023 00:43:57 +0000 Subject: [PATCH 24/37] make min_memory usable; broadcast mmax_examples in __post_init__ --- elk/evaluation/evaluate.py | 2 +- elk/extraction/extraction.py | 17 ++++------------- elk/extraction/prompt_loading.py | 6 ++++++ elk/training/train.py | 2 +- elk/utils/gpu_utils.py | 10 +++++++--- tests/test_smoke_elicit.py | 14 ++++++++++---- 6 files changed, 29 insertions(+), 22 deletions(-) diff --git a/elk/evaluation/evaluate.py b/elk/evaluation/evaluate.py index 4dffa809..ca484db2 100644 --- a/elk/evaluation/evaluate.py +++ b/elk/evaluation/evaluate.py @@ -77,7 +77,7 @@ def evaluate_reporters(cfg: EvaluateConfig, out_dir: Optional[Path] = None): if feat.startswith("hidden_") ] - devices = select_usable_devices(cfg.num_gpus) + devices = select_usable_devices(cfg.num_gpus, min_memory=cfg.target.min_gpu_mem) num_devices = len(devices) transfer_eval = elk_reporter_dir() / cfg.source / "transfer_eval" diff --git a/elk/extraction/extraction.py b/elk/extraction/extraction.py index 73063547..f8964af4 100644 --- a/elk/extraction/extraction.py +++ b/elk/extraction/extraction.py @@ -29,7 +29,7 @@ AutoTokenizer, PreTrainedModel, ) -from typing import Iterable, Literal, Union +from typing import Iterable, Literal, Union, Optional import logging import os import torch @@ -46,6 +46,7 @@ class ExtractionConfig(Serializable): layer_stride: Shortcut for setting `layers` to `range(0, num_layers, stride)`. token_loc: The location of the token to extract hidden states from. 
Can be either "first", "last", or "mean". Defaults to "last". + min_gpu_mem: Minimum amount of free memory (in bytes) required to select a GPU. """ prompts: PromptConfig @@ -54,6 +55,7 @@ class ExtractionConfig(Serializable): layers: tuple[int, ...] = () layer_stride: InitVar[int] = 1 token_loc: Literal["first", "last", "mean"] = "last" + min_gpu_mem: Optional[int] = None def __post_init__(self, layer_stride: int): if self.layers and layer_stride > 1: @@ -203,18 +205,7 @@ def get_splits() -> SplitDict: train_name, val_name = select_train_val_splits(available_splits) print(f"Using '{train_name}' for training and '{val_name}' for validation") - out_splits = SplitDict( - train=available_splits[train_name], val=available_splits[val_name] - ) - - # Empty list means no limit limit_list = cfg.prompts.max_examples - if not limit_list: - limit_list = [int(1e100)] - - # Broadcast the limit to all splits - if len(limit_list) == 1: - limit_list *= len(out_splits) return SplitDict( { @@ -255,7 +246,7 @@ def get_splits() -> SplitDict: length=num_variants, ), } - devices = select_usable_devices(num_gpus) + devices = select_usable_devices(num_gpus, min_memory=cfg.min_gpu_mem) builders = { split_name: _GeneratorBuilder( cache_dir=None, diff --git a/elk/extraction/prompt_loading.py b/elk/extraction/prompt_loading.py index 50ef9198..bacf7672 100644 --- a/elk/extraction/prompt_loading.py +++ b/elk/extraction/prompt_loading.py @@ -63,6 +63,12 @@ def __post_init__(self): "max_examples should be a list of length 0, 1, or 2," f"but got {len(self.max_examples)}" ) + if not self.max_examples: + self.max_examples = [int(1e100)] + + # Broadcast the limit to all splits + if len(self.max_examples) == 1: + self.max_examples *= 2 def load_prompts( diff --git a/elk/training/train.py b/elk/training/train.py index decc6686..7ba1e4c7 100644 --- a/elk/training/train.py +++ b/elk/training/train.py @@ -200,7 +200,7 @@ def train(cfg: RunConfig, out_dir: Optional[Path] = None): with open(out_dir / "metadata.yaml", "w") as meta_f: yaml.dump(meta, meta_f) - devices = select_usable_devices(cfg.num_gpus) + devices = select_usable_devices(cfg.num_gpus, min_memory=cfg.data.min_gpu_mem) num_devices = len(devices) cols = [ diff --git a/elk/utils/gpu_utils.py b/elk/utils/gpu_utils.py index 858781ab..9074bb8a 100644 --- a/elk/utils/gpu_utils.py +++ b/elk/utils/gpu_utils.py @@ -1,6 +1,7 @@ """Utilities that use PyNVML to get GPU usage info, and select GPUs accordingly.""" from .typing import assert_type +from typing import Optional import os import pynvml import torch @@ -8,8 +9,11 @@ import time -def select_usable_devices(num_gpus: int = -1, *, min_memory: int = -1) -> list[str]: +def select_usable_devices( + num_gpus: int = -1, *, min_memory: Optional[int] = None +) -> list[str]: """Select a set of devices that have at least `min_memory` bytes of free memory. + Blocks until at least `num_gpus` devices are available. When there are more than enough GPUs to satisfy the request, the GPUs with the most free memory will be selected. With default arguments, this function will @@ -30,7 +34,7 @@ def select_usable_devices(num_gpus: int = -1, *, min_memory: int = -1) -> list[s num_gpus: Number of GPUs to select. If negative, all available GPUs meeting the criteria will be selected. min_memory: Minimum amount of free memory (in bytes) required to select a GPU. - If negative, `min_memory` is set to 90% of the per-GPU memory. + If None, `min_memory` is set to 90% of the per-GPU memory. 
Returns: A list of suitable PyTorch device strings, in ascending numerical order, with @@ -85,7 +89,7 @@ def select_usable_devices(num_gpus: int = -1, *, min_memory: int = -1) -> list[s assert num_installed == num_visible, "PyNVML and PyTorch disagree on GPU count" # Set default value for `min_memory` - if min_memory < 0: + if min_memory is None: min_device_ram = min( ( assert_type( diff --git a/tests/test_smoke_elicit.py b/tests/test_smoke_elicit.py index cdce5e0f..424c2341 100644 --- a/tests/test_smoke_elicit.py +++ b/tests/test_smoke_elicit.py @@ -7,19 +7,22 @@ def test_smoke_elicit_run_tiny_gpt2_ccs(tmp_path: Path): - model_path = "sshleifer/tiny-gpt2" + # we need about 5 mb of gpu memory to run this test + model_path, min_mem = "sshleifer/tiny-gpt2", 10 * 1024**2 dataset_name = "imdb" config = RunConfig( data=ExtractionConfig( model=model_path, prompts=PromptConfig(datasets=[dataset_name], max_examples=[10]), + min_gpu_mem=min_mem, # run on all layers, tiny-gpt only has 2 layers ), + num_gpus=2, net=CcsReporterConfig(), ) train(config, tmp_path) # get the files in the tmp_path - files: Path = list(tmp_path.iterdir()) + files: list[Path] = list(tmp_path.iterdir()) created_file_names = {file.name for file in files} expected_files = ["cfg.yaml", "metadata.yaml", "lr_models", "reporters", "eval.csv"] for file in expected_files: @@ -27,19 +30,22 @@ def test_smoke_elicit_run_tiny_gpt2_ccs(tmp_path: Path): def test_smoke_elicit_run_tiny_gpt2_eigen(tmp_path: Path): - model_path = "sshleifer/tiny-gpt2" + # we need about 5 mb of gpu memory to run this test + model_path, min_mem = "sshleifer/tiny-gpt2", 10 * 1024**2 dataset_name = "imdb" config = RunConfig( data=ExtractionConfig( model=model_path, prompts=PromptConfig(datasets=[dataset_name], max_examples=[10]), + min_gpu_mem=min_mem, # run on all layers, tiny-gpt only has 2 layers ), + num_gpus=2, net=EigenReporterConfig(), ) train(config, tmp_path) # get the files in the tmp_path - files: Path = list(tmp_path.iterdir()) + files: list[Path] = list(tmp_path.iterdir()) created_file_names = {file.name for file in files} expected_files = ["cfg.yaml", "metadata.yaml", "lr_models", "reporters", "eval.csv"] for file in expected_files: From 3a762b0dd995ed61ee61820214ee67f804647b35 Mon Sep 17 00:00:00 2001 From: Alex Mallen Date: Sun, 26 Mar 2023 03:02:24 +0000 Subject: [PATCH 25/37] prompt loading refactor to enable better streaming --- elk/extraction/balanced_sampler.py | 8 +- elk/extraction/extraction.py | 14 ++- elk/extraction/prompt_loading.py | 138 +++++++++++++---------------- tests/test_load_prompts.py | 6 +- 4 files changed, 75 insertions(+), 91 deletions(-) diff --git a/elk/extraction/balanced_sampler.py b/elk/extraction/balanced_sampler.py index cbb2bd97..578771bf 100644 --- a/elk/extraction/balanced_sampler.py +++ b/elk/extraction/balanced_sampler.py @@ -6,7 +6,7 @@ from itertools import cycle from random import Random from torch.utils.data import IterableDataset as TorchIterableDataset -from typing import Iterator, Optional +from typing import Iterator, Optional, Iterable @dataclass @@ -25,14 +25,14 @@ class BalancedSampler(TorchIterableDataset): divided between the two binary label values (0 and 1). Defaults to 1000. 
""" - def __init__(self, dataset: IterableDataset, buffer_size: int = 1000): - self.dataset = dataset + def __init__(self, data: Iterable[dict], buffer_size: int = 1000): + self.data = data self.neg_buffer = deque(maxlen=buffer_size) self.pos_buffer = deque(maxlen=buffer_size) def __iter__(self): - for sample in self.dataset: + for sample in self.data: label = sample["label"] # Add the sample to the appropriate buffer diff --git a/elk/extraction/extraction.py b/elk/extraction/extraction.py index f8964af4..3e5c3651 100644 --- a/elk/extraction/extraction.py +++ b/elk/extraction/extraction.py @@ -33,6 +33,7 @@ import logging import os import torch +from itertools import islice @dataclass @@ -90,19 +91,14 @@ def extract_hiddens( # Silence datasets logging messages from all but the first process if rank != 0: logging.disable(logging.CRITICAL) - if rank == 0 and cfg.prompts.num_variants >= 1: - print(f"Using {cfg.prompts.num_variants} prompts per example") - limits = cfg.prompts.max_examples prompt_ds = load_prompts( *cfg.prompts.datasets, - max_examples=limits[0 if split_type == "train" else 1], split_type=split_type, stream=cfg.prompts.stream, rank=rank, world_size=world_size, - ) - num_variants = prompt_ds.features["prompts"].length + ) # this dataset is already sharded, but hasn't been truncated to max_examples # AutoModel should do the right thing here in nearly all cases. We don't actually # care what head the model has, since we are just extracting hidden states. @@ -131,7 +127,9 @@ def extract_hiddens( layer_indices = cfg.layers or tuple(range(model.config.num_hidden_layers)) # print(f"Using {prompt_ds} variants for each dataset") - for example in BalancedSampler(prompt_ds): + max_examples = cfg.prompts.max_examples[0 if split_type == "train" else 1] + for example in islice(BalancedSampler(prompt_ds), max_examples): + num_variants = len(example["prompts"]) hidden_dict = { f"hidden_{layer_idx}": torch.empty( num_variants, @@ -150,7 +148,7 @@ def extract_hiddens( # Iterate over answers for j in range(2): - text = record["text"][j] + text = record[j]["text"] variant_inputs.append(text) inputs = tokenizer( diff --git a/elk/extraction/prompt_loading.py b/elk/extraction/prompt_loading.py index bacf7672..eb064591 100644 --- a/elk/extraction/prompt_loading.py +++ b/elk/extraction/prompt_loading.py @@ -73,21 +73,20 @@ def __post_init__(self): def load_prompts( *dataset_strings: str, - max_examples: int = 0, num_shots: int = 0, + num_variants: int = -1, seed: int = 42, shuffle: bool = True, split_type: Literal["train", "val"] = "train", stream: bool = False, rank: int = 0, world_size: int = 1, -) -> IterableDataset: +) -> Iterator[dict]: """Load a dataset full of prompts generated from the specified datasets. Args: dataset_strings: Space-delimited names of the HuggingFace datasets to use, e.g. `"super_glue boolq"` or `"imdb"`. - max_examples: The maximum number of examples to use from the dataset. num_shots: The number of examples to use in few-shot prompts. If zero, prompts are zero-shot. seed: The seed to use for prompt randomization. @@ -99,7 +98,6 @@ def load_prompts( Returns: An iterable dataset of prompts. """ - prompt_datasets = [] prompters = [] raw_datasets = [] train_datasets = [] @@ -117,94 +115,82 @@ def load_prompts( train_name, val_name = select_train_val_splits(ds_dict) split_name = val_name if split_type == "val" else train_name - # If we're not streaming, take the opportunity to shuffle the dataset. 
+ # Note that when streaming we can only approximately shuffle the dataset + # using a buffer. Streaming shuffling is NOT an adequate shuffle for + # datasets like IMDB, which are sorted by label. + bad_streaming_datasets = ["imdb"] + assert not ( + stream and ds_name in bad_streaming_datasets + ), f"Streaming is not supported for {ds_name}." + split = ds_dict[split_name].shuffle(seed=seed) + train_ds = ds_dict[train_name].shuffle(seed=seed) if not stream: - ds = assert_type(Dataset, ds_dict[split_name].shuffle(seed=seed)) - train_ds = assert_type(Dataset, ds_dict[train_name].shuffle(seed=seed)) - split = ds.to_iterable_dataset().cast(ds.features) - else: - train_ds = assert_type(IterableDataset, ds_dict[train_name]) - split = assert_type(IterableDataset, ds_dict[split_name]) + split = assert_type(Dataset, split) + split = split.to_iterable_dataset().cast(split.features) + + # only keep the datapoints relevant to the current process + if world_size > 1: + # This prints to stdout which is slightly annoying + split = split_dataset_by_node(split, world_size, rank) raw_datasets.append(split) train_datasets.append(train_ds) - num_variants = min(len(prompter.templates) for prompter in prompters) + min_num_templates = min(len(prompter.templates) for prompter in prompters) + num_variants = ( + min_num_templates + if num_variants == -1 + else min(num_variants, min_num_templates) + ) assert num_variants > 0 + if rank == 0: + print(f"Using {num_variants} variants of each prompt") + + ds_iterators = [iter(ds) for ds in raw_datasets] + while True: # terminates when the first dataset runs out of examples + for ds_iterator, ds, train_ds, prompter in zip( + ds_iterators, raw_datasets, train_datasets, prompters + ): + label_column = infer_label_column(ds.features) + num_classes = infer_num_classes(ds.features[label_column]) + + # Remove everything except the label column + extra_cols = list(assert_type(Features, ds.features)) + extra_cols.remove(label_column) + + if label_column != "label": + ds = ds.rename_column(label_column, "label") + if num_shots > 0: + fewshot = FewShotSampler( + train_ds, # TODO: not iterator + num_shots=num_shots, + rng=rng, + ) + fewshot_iter = iter(fewshot) + else: + fewshot_iter = None - for ds, train_ds, prompter in zip(raw_datasets, train_datasets, prompters): - label_column = infer_label_column(ds.features) - num_classes = infer_num_classes(ds.features[label_column]) - - # Remove everything except the label column - extra_cols = list(assert_type(Features, ds.features)) - extra_cols.remove(label_column) + try: + example = next(ds_iterator) + except StopIteration: + return - if label_column != "label": - ds = ds.rename_column(label_column, "label") - if num_shots > 0: - fewshot = FewShotSampler( - train_ds, - num_shots=num_shots, - rng=rng, - ) - fewshot_iter = iter(fewshot) - else: - fewshot_iter = None - - # Canonicalize the name and dtype of the label column - ds = ds.map( - _convert_to_prompts, - fn_kwargs=dict( + example = _convert_to_prompts( + example, label_column=label_column, num_classes=num_classes, num_variants=num_variants, prompter=prompter, rng=rng, fewshot_iter=fewshot_iter, - ), - remove_columns=extra_cols, - ).map( + ) + # Add the builder and config name to the records directly to make # sure we don't forget what dataset they came from. 
- lambda _: dict( - builder_name=ds.info.builder_name, - config_name=ds.info.config_name, - ), - # Explicit typing makes interleave_datasets work a lot faster - features=Features( - { - label_column: ClassLabel(names=["neg", "pos"]), - "builder_name": "string", - "config_name": "string", - "prompts": Sequence( - Sequence( - {"answer": "string", "text": "string"}, - length=2, # contrast pair - ), - length=num_variants, - ), - "template_names": Sequence("string"), - } - ), - ) - prompt_datasets.append(ds) - - master_ds = interleave_datasets(prompt_datasets) - if max_examples > 0: - master_ds = master_ds.take(max_examples) - if world_size > 1: - # This prints to stdout which is slightly annoying - master_ds = split_dataset_by_node(master_ds, rank, world_size) - if shuffle: - master_ds = master_ds.shuffle(seed=seed) - - # Try to approximately shuffle the dataset if we're streaming. Note that this is - # NOT an adequate shuffle for datasets like IMDB, which are sorted by label. - if stream: - master_ds = master_ds.shuffle(seed=seed) - - return master_ds + example["builder_name"] = ds.info.builder_name + example["config_name"] = ds.info.config_name + + yield example def _convert_to_prompts( diff --git a/tests/test_load_prompts.py b/tests/test_load_prompts.py index 678052e3..15159870 100644 --- a/tests/test_load_prompts.py +++ b/tests/test_load_prompts.py @@ -1,6 +1,6 @@ from elk.extraction import load_prompts, PromptConfig from elk.promptsource.templates import DatasetTemplates -from itertools import cycle +from itertools import cycle, islice from typing import Literal import pytest @@ -10,7 +10,6 @@ def test_load_prompts(): def test_single_split(cfg: PromptConfig, split_type: Literal["train", "val"]): prompt_ds = load_prompts( *cfg.datasets, - max_examples=cfg.max_examples[0], shuffle=False, split_type=split_type, ) @@ -21,7 +20,8 @@ def test_single_split(cfg: PromptConfig, split_type: Literal["train", "val"]): prompter = DatasetTemplates(ds_name, config_name or None) prompters.append(prompter) - for prompter, record in zip(cycle(prompters), prompt_ds): + limit = cfg.max_examples[0 if split_type == "train" else 1] + for prompter, record in zip(cycle(prompters), islice(prompt_ds, limit)): true_template_names = prompter.all_template_names returned_template_names = record["template_names"] From f66c054b498d202777b66e34c520576ac8fdd7ed Mon Sep 17 00:00:00 2001 From: Alex Mallen Date: Sun, 26 Mar 2023 03:23:50 +0000 Subject: [PATCH 26/37] remove shuffle arg --- elk/extraction/prompt_loading.py | 1 - tests/test_load_prompts.py | 1 - 2 files changed, 2 deletions(-) diff --git a/elk/extraction/prompt_loading.py b/elk/extraction/prompt_loading.py index eb064591..16ba415e 100644 --- a/elk/extraction/prompt_loading.py +++ b/elk/extraction/prompt_loading.py @@ -76,7 +76,6 @@ def load_prompts( num_shots: int = 0, num_variants: int = -1, seed: int = 42, - shuffle: bool = True, split_type: Literal["train", "val"] = "train", stream: bool = False, rank: int = 0, diff --git a/tests/test_load_prompts.py b/tests/test_load_prompts.py index 15159870..c9a45f03 100644 --- a/tests/test_load_prompts.py +++ b/tests/test_load_prompts.py @@ -10,7 +10,6 @@ def test_load_prompts(): def test_single_split(cfg: PromptConfig, split_type: Literal["train", "val"]): prompt_ds = load_prompts( *cfg.datasets, - shuffle=False, split_type=split_type, ) prompters = [] From d3d87fcb07db41ca6066e88e1794966c4d5c6ed6 Mon Sep 17 00:00:00 2001 From: Walter Laurito Date: Sun, 26 Mar 2023 08:33:57 -0700 Subject: [PATCH 27/37] remove unused 
@dataclass --- elk/extraction/balanced_sampler.py | 1 - 1 file changed, 1 deletion(-) diff --git a/elk/extraction/balanced_sampler.py b/elk/extraction/balanced_sampler.py index 578771bf..9e41c582 100644 --- a/elk/extraction/balanced_sampler.py +++ b/elk/extraction/balanced_sampler.py @@ -9,7 +9,6 @@ from typing import Iterator, Optional, Iterable -@dataclass class BalancedSampler(TorchIterableDataset): """ Approximately balances a binary classification dataset in a streaming fashion. From c9a43e1247bb5c0ba106846d650b7f24ee553873 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 27 Mar 2023 22:30:12 +0000 Subject: [PATCH 28/37] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- elk/evaluation/evaluate.py | 4 +++- elk/extraction/__init__.py | 2 +- elk/extraction/extraction.py | 21 +++++++++++++++++---- elk/run.py | 4 +++- 4 files changed, 24 insertions(+), 7 deletions(-) diff --git a/elk/evaluation/evaluate.py b/elk/evaluation/evaluate.py index 87794728..753246f4 100644 --- a/elk/evaluation/evaluate.py +++ b/elk/evaluation/evaluate.py @@ -97,7 +97,9 @@ def evaluate_reporter( def evaluate(self): """Evaluate the reporter on all layers.""" - devices = select_usable_devices(self.cfg.num_gpus, min_memory=cfg.target.min_gpu_mem) + devices = select_usable_devices( + self.cfg.num_gpus, min_memory=cfg.target.min_gpu_mem + ) num_devices = len(devices) func: Callable[[int], EvalLog] = partial( diff --git a/elk/extraction/__init__.py b/elk/extraction/__init__.py index 3337b3ef..c9bd8ba1 100644 --- a/elk/extraction/__init__.py +++ b/elk/extraction/__init__.py @@ -1,4 +1,4 @@ from .balanced_sampler import BalancedSampler, FewShotSampler from .extraction import Extract, extract_hiddens, extract from .generator import _GeneratorConfig, _GeneratorBuilder -from .prompt_loading import PromptConfig, load_prompts \ No newline at end of file +from .prompt_loading import PromptConfig, load_prompts diff --git a/elk/extraction/extraction.py b/elk/extraction/extraction.py index a4dacc45..e4ee740d 100644 --- a/elk/extraction/extraction.py +++ b/elk/extraction/extraction.py @@ -9,12 +9,25 @@ from simple_parsing import Serializable, field from transformers import AutoConfig, AutoModel, AutoTokenizer, PreTrainedModel -from datasets import (Array3D, ClassLabel, DatasetDict, Features, Sequence, - SplitDict, SplitInfo, Value, get_dataset_config_info) +from datasets import ( + Array3D, + ClassLabel, + DatasetDict, + Features, + Sequence, + SplitDict, + SplitInfo, + Value, + get_dataset_config_info, +) from elk.utils.typing import float32_to_int16 -from ..utils import (assert_type, infer_label_column, select_train_val_splits, - select_usable_devices) +from ..utils import ( + assert_type, + infer_label_column, + select_train_val_splits, + select_usable_devices, +) from .balanced_sampler import BalancedSampler from .generator import _GeneratorBuilder from .prompt_loading import PromptConfig, load_prompts diff --git a/elk/run.py b/elk/run.py index f2fd07b6..142a7607 100644 --- a/elk/run.py +++ b/elk/run.py @@ -93,7 +93,9 @@ def prepare_data( def concatenate(self, layers): """Concatenate hidden states from a previous layer.""" for layer in range(self.cfg.concatenated_layer_offset, len(layers)): - layers[layer] = layers[layer] + [layers[layer][0] - cfg.concatenated_layer_offset] + layers[layer] = layers[layer] + [ + layers[layer][0] - cfg.concatenated_layer_offset + ] return layers def apply_to_layers( From 
94290aa613b4c16139e19f8cdbded928ad9b09c9 Mon Sep 17 00:00:00 2001 From: Walter Laurito Date: Mon, 27 Mar 2023 22:38:16 +0000 Subject: [PATCH 29/37] add concatenated_layer_offset to eval --- elk/evaluation/evaluate.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/elk/evaluation/evaluate.py b/elk/evaluation/evaluate.py index 87794728..c6a67ac0 100644 --- a/elk/evaluation/evaluate.py +++ b/elk/evaluation/evaluate.py @@ -56,6 +56,8 @@ class Eval(Serializable): out_dir: Optional[Path] = None num_gpus: int = -1 + concatenated_layer_offset: int = 0 + def execute(self): transfer_eval = elk_reporter_dir() / self.source / "transfer_eval" From 3765c4ff057f36d2a10a166a12b22dfa1b3b8608 Mon Sep 17 00:00:00 2001 From: Walter Laurito Date: Mon, 27 Mar 2023 22:41:25 +0000 Subject: [PATCH 30/37] add self. --- elk/evaluation/evaluate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/elk/evaluation/evaluate.py b/elk/evaluation/evaluate.py index 4a385920..6f4245d9 100644 --- a/elk/evaluation/evaluate.py +++ b/elk/evaluation/evaluate.py @@ -100,7 +100,7 @@ def evaluate_reporter( def evaluate(self): """Evaluate the reporter on all layers.""" devices = select_usable_devices( - self.cfg.num_gpus, min_memory=cfg.target.min_gpu_mem + self.cfg.num_gpus, min_memory=self.cfg.target.min_gpu_mem ) num_devices = len(devices) From 2b051933d75a0ba1d9fc8e79ac542fa8a214a4c5 Mon Sep 17 00:00:00 2001 From: Walter Laurito Date: Mon, 27 Mar 2023 22:55:35 +0000 Subject: [PATCH 31/37] replace target with data --- elk/evaluation/evaluate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/elk/evaluation/evaluate.py b/elk/evaluation/evaluate.py index 6f4245d9..1e756894 100644 --- a/elk/evaluation/evaluate.py +++ b/elk/evaluation/evaluate.py @@ -100,7 +100,7 @@ def evaluate_reporter( def evaluate(self): """Evaluate the reporter on all layers.""" devices = select_usable_devices( - self.cfg.num_gpus, min_memory=self.cfg.target.min_gpu_mem + self.cfg.num_gpus, min_memory=self.cfg.data.min_gpu_mem ) num_devices = len(devices) From 83731bbf3817ebf9708edb9b76adc605ab1730e8 Mon Sep 17 00:00:00 2001 From: Walter Laurito Date: Mon, 27 Mar 2023 22:57:04 +0000 Subject: [PATCH 32/37] add self. --- elk/run.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/elk/run.py b/elk/run.py index 142a7607..2ed3666b 100644 --- a/elk/run.py +++ b/elk/run.py @@ -94,7 +94,7 @@ def concatenate(self, layers): """Concatenate hidden states from a previous layer.""" for layer in range(self.cfg.concatenated_layer_offset, len(layers)): layers[layer] = layers[layer] + [ - layers[layer][0] - cfg.concatenated_layer_offset + layers[layer][0] - self.cfg.concatenated_layer_offset ] return layers @@ -120,7 +120,7 @@ def apply_to_layers( layers: list[int] = get_layers(self.dataset) - if cfg.concatenated_layer_offset > 0: + if self.cfg.concatenated_layer_offset > 0: layers = self.concatenate(cfg, layers) # Should we write to different CSV files for elicit vs eval? 
From 764fda96b1a331aa7d9d0f62427651fe08a3f052 Mon Sep 17 00:00:00 2001 From: Walter Laurito Date: Mon, 27 Mar 2023 22:57:47 +0000 Subject: [PATCH 33/37] remove second arg --- elk/run.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/elk/run.py b/elk/run.py index 2ed3666b..7b25ccd7 100644 --- a/elk/run.py +++ b/elk/run.py @@ -121,7 +121,7 @@ def apply_to_layers( layers: list[int] = get_layers(self.dataset) if self.cfg.concatenated_layer_offset > 0: - layers = self.concatenate(cfg, layers) + layers = self.concatenate(layers) # Should we write to different CSV files for elicit vs eval? with mp.Pool(num_devices) as pool, open(self.out_dir / "eval.csv", "w") as f: From d2c66b0c1405d8a3a303a13205a5ecd497f9e41a Mon Sep 17 00:00:00 2001 From: James Chua Date: Tue, 28 Mar 2023 22:52:39 +0800 Subject: [PATCH 34/37] fix passing the wrong params for world size / rank --- elk/extraction/prompt_loading.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/elk/extraction/prompt_loading.py b/elk/extraction/prompt_loading.py index 16ba415e..74ef71b6 100644 --- a/elk/extraction/prompt_loading.py +++ b/elk/extraction/prompt_loading.py @@ -130,7 +130,9 @@ def load_prompts( # only keep the datapoints relevant to the current process if world_size > 1: # This prints to stdout which is slightly annoying - split = split_dataset_by_node(split, world_size, rank) + split = split_dataset_by_node( + dataset=split, rank=rank, world_size=world_size + ) raw_datasets.append(split) train_datasets.append(train_ds) From 918632683c8ecb5de3c815d79e2b66ff8009440f Mon Sep 17 00:00:00 2001 From: Walter Laurito Date: Tue, 28 Mar 2023 10:48:15 -0700 Subject: [PATCH 35/37] Update prompt_loading.py Remove print label --- elk/extraction/prompt_loading.py | 1 - 1 file changed, 1 deletion(-) diff --git a/elk/extraction/prompt_loading.py b/elk/extraction/prompt_loading.py index 74ef71b6..0117829f 100644 --- a/elk/extraction/prompt_loading.py +++ b/elk/extraction/prompt_loading.py @@ -204,7 +204,6 @@ def _convert_to_prompts( fewshot_iter: Optional[Iterator[list[dict]]] = None, ) -> dict[str, Any]: """Prompt-generating function to pass to `IterableDataset.map`.""" - print(f"label: {example[label_column]}") prompts = [] templates = list(prompter.templates.values()) if num_variants < len(templates): From 3f99a4d1e1c7636d7ee3cd092601cd99ddd5d328 Mon Sep 17 00:00:00 2001 From: Walter Laurito Date: Tue, 28 Mar 2023 22:34:28 +0000 Subject: [PATCH 36/37] fix pre-commit errors --- elk/evaluation/evaluate.py | 34 +++++++++++++--------------------- 1 file changed, 13 insertions(+), 21 deletions(-) diff --git a/elk/evaluation/evaluate.py b/elk/evaluation/evaluate.py index 1e756894..18662f49 100644 --- a/elk/evaluation/evaluate.py +++ b/elk/evaluation/evaluate.py @@ -1,34 +1,26 @@ -from ..training.preprocessing import normalize +import csv +import os from dataclasses import dataclass -from datasets import DatasetDict from functools import partial from pathlib import Path +from typing import Callable, Literal, Optional, cast + +import torch +import torch.multiprocessing as mp from simple_parsing.helpers import Serializable, field from torch import Tensor from tqdm.auto import tqdm -from typing import Literal, Optional, cast, Callable -import csv -import os -import torch -import torch.multiprocessing as mp +from datasets import DatasetDict +from elk.evaluation.evaluate_log import EvalLog from elk.extraction.extraction import Extract -from ..files import elk_reporter_dir, memorably_named_dir -from ..utils import ( - 
assert_type, - int16_to_float32, - select_train_val_splits, - select_usable_devices, -) - -import torch -from simple_parsing import Serializable, field - -from elk.files import elk_reporter_dir from elk.run import Run from elk.training import Reporter -from elk.evaluation.evaluate_log import EvalLog -from elk.utils import select_usable_devices + +from ..files import elk_reporter_dir, memorably_named_dir +from ..training.preprocessing import normalize +from ..utils import (assert_type, int16_to_float32, select_train_val_splits, + select_usable_devices) @dataclass From 148130d2ae099ddd51b57ddd3343837b4cdda28d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 28 Mar 2023 22:35:17 +0000 Subject: [PATCH 37/37] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- elk/evaluation/evaluate.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/elk/evaluation/evaluate.py b/elk/evaluation/evaluate.py index 18662f49..c26aabf8 100644 --- a/elk/evaluation/evaluate.py +++ b/elk/evaluation/evaluate.py @@ -19,8 +19,12 @@ from ..files import elk_reporter_dir, memorably_named_dir from ..training.preprocessing import normalize -from ..utils import (assert_type, int16_to_float32, select_train_val_splits, - select_usable_devices) +from ..utils import ( + assert_type, + int16_to_float32, + select_train_val_splits, + select_usable_devices, +) @dataclass
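
Editor's note on the streaming refactor (patches 25-26): load_prompts ends up as a plain generator of example dicts; per-process sharding happens inside it (via split_dataset_by_node), while balancing and truncation move to the consumer, which wraps the stream in BalancedSampler and cuts it off with itertools.islice. The sketch below is a minimal, self-contained illustration of that consumption pattern only; ToyBalancedSampler and toy_prompt_stream are simplified stand-ins invented for this note, not the actual elk implementations.

    from collections import deque
    from itertools import islice
    from typing import Iterable, Iterator


    class ToyBalancedSampler:
        """Approximately balance a binary-labeled stream (simplified sketch).

        One bounded buffer per label; whenever both buffers hold at least one
        example, yield one of each, so the output stays balanced even if the
        input is skewed or sorted by label.
        """

        def __init__(self, data: Iterable[dict], buffer_size: int = 1000):
            self.data = data
            self.buffers = {0: deque(maxlen=buffer_size), 1: deque(maxlen=buffer_size)}

        def __iter__(self) -> Iterator[dict]:
            for sample in self.data:
                self.buffers[sample["label"]].append(sample)
                while self.buffers[0] and self.buffers[1]:
                    yield self.buffers[0].popleft()
                    yield self.buffers[1].popleft()


    def toy_prompt_stream(n: int) -> Iterator[dict]:
        """Stand-in for load_prompts(): a generator of example dicts."""
        for i in range(n):
            # Deliberately imbalanced: roughly three positives per negative.
            yield {"text": f"example {i}", "label": int(i % 4 != 0)}


    if __name__ == "__main__":
        max_examples = 8  # truncation happens at the consumer, via islice
        for example in islice(ToyBalancedSampler(toy_prompt_stream(100)), max_examples):
            print(example["label"], example["text"])

Because the upstream stream may be heavily skewed or even sorted by label (as IMDB is), balancing before truncation is what keeps the first max_examples items usable; this is why the series applies islice after BalancedSampler rather than limiting inside load_prompts.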