
Multi datasets #123

Merged · 46 commits · Mar 28, 2023
Changes from 1 commit
Commits (46)
681698d  add multiple datasets support (Benw8888, Mar 9, 2023)
ac1b9f1  Merge branch 'main' of github.com:EleutherAI/elk into multi-datasets (Benw8888, Mar 9, 2023)
b864c77  train_reporter works on a list of layers now (Benw8888, Mar 10, 2023)
7d7d97c  changing printed layer names (Benw8888, Mar 10, 2023)
4fe61e9  fixed concatenation bug (Benw8888, Mar 11, 2023)
fe61d67  minor edits (Benw8888, Mar 13, 2023)
74da878  fixed pyright issues (Benw8888, Mar 13, 2023)
569ef05  Merge branch 'main' of github.com:EleutherAI/elk into multi-datasets (Benw8888, Mar 14, 2023)
b62b679  Merge branch 'main' into multi-datasets (norabelrose, Mar 20, 2023)
fe94c22  Fix tests (norabelrose, Mar 20, 2023)
bba24d8  Now working sorta (norabelrose, Mar 22, 2023)
03ba6e0  Skip slow BalancedBatchSampler test (norabelrose, Mar 22, 2023)
15ab351  Slightly relax test_output_is_roughly_balanced (norabelrose, Mar 22, 2023)
a80369e  Make BalancedSampler deterministic (norabelrose, Mar 22, 2023)
d304ab3  InitVar (norabelrose, Mar 22, 2023)
761c82d  Support multi class again (norabelrose, Mar 22, 2023)
f29743b  Fix naming issue (norabelrose, Mar 22, 2023)
b7b7e23  Support few shot prompts (norabelrose, Mar 23, 2023)
1afb563  Merge branch 'main' into multi-datasets (norabelrose, Mar 23, 2023)
225d4c7  fix multiclass labels (AlexTMallen, Mar 23, 2023)
9368dc8  Merge branch 'multi-datasets' of github.com:EleutherAI/elk into multi… (AlexTMallen, Mar 23, 2023)
a858b65  Merge branch 'main' into multi-datasets (norabelrose, Mar 24, 2023)
5dc2ec6  Merge branch 'multi-datasets' of github.com:EleutherAI/elk into multi… (norabelrose, Mar 24, 2023)
b1b95e5  Fix dumb part of test failures (norabelrose, Mar 25, 2023)
ee3911e  Fix assert_allclose warning (norabelrose, Mar 25, 2023)
a55b3de  Switch to torch.testing.assert_close in EigenReporter test (norabelrose, Mar 25, 2023)
44dc25c  Shuffle load_prompts output by default (norabelrose, Mar 25, 2023)
93d8d87  Fix smoke test failure (norabelrose, Mar 25, 2023)
fad4d74  Remove debug prints (AlexTMallen, Mar 25, 2023)
0a054f4  Remove more debug print statements (AlexTMallen, Mar 25, 2023)
177eec2  make min_memory usable; broadcast mmax_examples in __post_init__ (AlexTMallen, Mar 26, 2023)
3a762b0  prompt loading refactor to enable better streaming (AlexTMallen, Mar 26, 2023)
f66c054  remove shuffle arg (AlexTMallen, Mar 26, 2023)
d3d87fc  remove unused @dataclass (lauritowal, Mar 26, 2023)
3d08147  merge (lauritowal, Mar 27, 2023)
c9a43e1  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Mar 27, 2023)
94290aa  add concatenated_layer_offset to eval (lauritowal, Mar 27, 2023)
f9298e4  Merge branch 'multi-datasets' of https://github.com/EleutherAI/elk in… (lauritowal, Mar 27, 2023)
3765c4f  add self. (lauritowal, Mar 27, 2023)
2b05193  replace target with data (lauritowal, Mar 27, 2023)
83731bb  add self. (lauritowal, Mar 27, 2023)
764fda9  remove second arg (lauritowal, Mar 27, 2023)
d2c66b0  fix passing the wrong params for world size / rank (thejaminator, Mar 28, 2023)
9186326  Update prompt_loading.py (lauritowal, Mar 28, 2023)
3f99a4d  fix pre-commit errors (lauritowal, Mar 28, 2023)
148130d  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Mar 28, 2023)
add multiple datasets support
Benw8888 committed Mar 9, 2023
commit 681698d08173d1bd673facb44b6b312ed94b33a6
37 changes: 26 additions & 11 deletions elk/extraction/extraction.py
@@ -1,6 +1,6 @@
"""Functions for extracting the hidden states of a model."""

from .prompt_dataset import Prompt, PromptDataset, PromptConfig
from .prompt_dataset import Prompt, PromptDataset, PromptConfig, Interleaved_Datasets
from ..utils import (
assert_type,
infer_label_column,
@@ -19,6 +19,7 @@
SplitDict,
SplitInfo,
Value,
interleave_datasets,
)
from simple_parsing.helpers import field, Serializable
from transformers import (
@@ -88,18 +89,27 @@ def extract_hiddens(
if rank != 0:
logging.disable(logging.CRITICAL)

prompt_ds = PromptDataset(cfg.prompts, rank, world_size, split)
if rank == 0:
prompt_names = prompt_ds.prompter.all_template_names
if cfg.prompts.num_variants >= 1:
print(
f"Using {cfg.prompts.num_variants} prompts per example: {prompt_names}"
)
print(f"Using {cfg.prompts.num_variants} prompts per example")
elif cfg.prompts.num_variants == -1:
print(f"Using all prompts per example: {prompt_names}")
print("Using all prompts per example")
else:
raise ValueError(f"Invalid prompt num_variants: {cfg.prompts.num_variants}")

prompt_datasets = []

# create a PromptDataset for each dataset in cfg.prompts
for dataset_index in range(len(cfg.prompts.datasets)):
dataset_name = cfg.prompts.datasets[dataset_index]
prompt_ds = PromptDataset(cfg.prompts, rank, world_size, split, dataset_index)
prompt_names = prompt_ds.prompter.all_template_names
print(f"Prompts for dataset {dataset_name}: {prompt_names}")
prompt_datasets.append(prompt_ds)

# combine each PromptDataset together, interleaving them
interleaved_prompt_datasets = Interleaved_Datasets(prompt_datasets)

# AutoModel should do the right thing here in nearly all cases. We don't actually
# care what head the model has, since we are just extracting hidden states.
model = AutoModel.from_pretrained(cfg.model, torch_dtype="auto").to(device)
@@ -114,7 +124,10 @@

# TODO: Make this configurable or something
# Token used to separate the question from the answer
num_choices = prompt_ds.num_classes
num_choices = prompt_datasets[0].num_classes
for i in range(1, len(prompt_datasets)):
assert prompt_datasets[i].num_classes == num_choices

sep_token = tokenizer.sep_token or "\n"

if not tokenizer.pad_token:
@@ -160,7 +173,8 @@ def collate(prompts: list[Prompt]) -> list[list[BatchEncoding]]:

# Iterating over questions
layer_indices = cfg.layers or tuple(range(model.config.num_hidden_layers))
for prompts in prompt_ds:

for prompts in interleaved_prompt_datasets:
inputs = collate(prompts)
hidden_dict = {
f"hidden_{layer_idx}": torch.empty(
Expand Down Expand Up @@ -228,7 +242,7 @@ def get_splits() -> SplitDict:
{
k: SplitInfo(
name=k,
num_examples=min(limit, v.num_examples),
num_examples=min(limit, v.num_examples) * len(cfg.prompts.datasets),
dataset_name=v.dataset_name,
)
for k, v in base_splits.items()
@@ -239,13 +253,14 @@

model_cfg = AutoConfig.from_pretrained(cfg.model)
num_variants = cfg.prompts.num_variants
ds_name, _, config_name = cfg.prompts.dataset.partition(" ")
ds_name, _, config_name = cfg.prompts.datasets[0].partition(" ")
info = get_dataset_config_info(ds_name, config_name or None)

features = assert_type(Features, info.features)
label_col = cfg.prompts.label_column or infer_label_column(features)

splits = get_splits()
print("SPLITS: ", splits)

layer_cols = {
f"hidden_{layer}": Array3D(
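Note that this commit adds `interleave_datasets` to the `datasets` import, but the actual mixing here goes through the custom `Interleaved_Datasets` wrapper defined in prompt_dataset.py below. For comparison, a minimal sketch (not this PR's code path; the toy data is made up) of the Hugging Face helper, which follows the same stop-at-the-smallest-dataset rule:

# Sketch only: datasets.interleave_datasets does round-robin mixing at the
# Dataset level and, with the default stopping_strategy="first_exhausted",
# stops once the smallest source runs out, like Interleaved_Datasets below.
from datasets import Dataset, interleave_datasets

ds_a = Dataset.from_dict({"text": ["a0", "a1", "a2"], "label": [0, 1, 0]})
ds_b = Dataset.from_dict({"text": ["b0", "b1"], "label": [1, 0]})

mixed = interleave_datasets([ds_a, ds_b])
print([row["text"] for row in mixed])  # ['a0', 'b0', 'a1', 'b1']
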
78 changes: 67 additions & 11 deletions elk/extraction/prompt_dataset.py
@@ -2,7 +2,13 @@
from ..promptsource import DatasetTemplates
from ..utils import assert_type, compute_class_balance, infer_label_column, undersample
from dataclasses import dataclass
from datasets import DatasetDict, load_dataset
from datasets import (
DatasetDict,
IterableDataset,
Dataset,
load_dataset,
concatenate_datasets,
)
from numpy.typing import NDArray
from random import Random
from simple_parsing.helpers import field, Serializable
@@ -44,7 +50,8 @@ class PromptConfig(Serializable):
call to __getitem__. Use -1 to apply all available templates. Defaults to 1.
"""

dataset: str = field(positional=True)
datasets: list[str] = field(positional=True)
# dataset2: str = field(positional=True)
balance: bool = False
label_column: Optional[str] = None
max_examples: int = 0
Expand All @@ -53,6 +60,16 @@ class PromptConfig(Serializable):
num_variants: int = 1


def create_prompt_dataset(
cfg: PromptConfig,
rank: int = 0,
world_size: int = 1,
split: str = "validation",
dataset_index: int = 0, # which dataset in cfg.datasets to use
):
pass


class PromptDataset(TorchDataset):
"""Wrapper for a HuggingFace dataset which generates prompts with `promptsource`.

@@ -79,8 +96,12 @@ def __init__(
rank: int = 0,
world_size: int = 1,
split: str = "validation",
dataset_index: int = 0, # which dataset in cfg.datasets to use
):
ds_name, _, config_name = cfg.dataset.partition(" ")
# super.__init__(self)

dataset = cfg.datasets[dataset_index]
ds_name, _, config_name = dataset.partition(" ")

self.num_shots = cfg.num_shots
self.prompter = DatasetTemplates(ds_name, config_name or None) # type: ignore
@@ -100,15 +121,16 @@ def __init__(
# instantiations of PromptDataset (unless you set the seed to something else).
# This allows you to just set split="train" and split="test" for any dataset
# and not worry about train-test leakage.
split_name, *others = ds_dict.keys()
if not others:
print("Creating a 75/25 train-test split...")

# Don't shuffle now because we're going to shuffle later
ds_dict = ds_dict[split_name].train_test_split(
seed=cfg.seed, shuffle=False, stratify_by_column=cfg.label_column
)
assert isinstance(ds_dict, DatasetDict)
# split_name, *others = ds_dict.keys()
# if not others:
# print("Creating a 75/25 train-test split...")

# # Don't shuffle now because we're going to shuffle later
# ds_dict = ds_dict[split_name].train_test_split(
# seed=cfg.seed, shuffle=False, stratify_by_column=cfg.label_column
# )
# assert isinstance(ds_dict, DatasetDict)

# The 'active' split is the one that gets queried by __getitem__
self.active_split = ds_dict[split]
@@ -225,3 +247,37 @@ def num_classes(self) -> int:

# We piggyback on the ClassLabel feature type to get the number of classes
return self.active_split.features[self.label_column].num_classes


class Interleaved_Datasets(TorchDataset):
def __init__(
self,
datasets: list[PromptDataset],
):
"""
Interleave several (PromptDataset) datasets into a single dataset,
alternating between the datasets.
Only samples as many datapoints from each dataset as the smallest dataset.
Args:
datasets (`List[PromptDataset]`):
List of datasets to interleave.
"""
self.datasets = datasets

if not datasets:
raise ValueError("Unable to interleave an empty list of datasets.")

lengths = [len(dset) for dset in datasets]
self.min_dataset_length = min(lengths)
self.num_datasets = len(datasets)

def __getitem__(self, index: int) -> list[Prompt]:
which_dataset = index % self.num_datasets
return self.datasets[which_dataset][int(index / self.num_datasets)]

def __iter__(self):
return (self[i] for i in range(len(self)))

def __len__(self):
"""Get the number of predicates in the dataset."""
return self.num_datasets * self.min_dataset_length
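To make the round-robin indexing above concrete, here is a minimal standalone sketch (plain lists stand in for PromptDataset objects; all names in it are hypothetical): `index % num_datasets` picks the source and `index // num_datasets` picks the row within it, so consecutive indices alternate between datasets.

from typing import Sequence, TypeVar

T = TypeVar("T")


class RoundRobin:
    """Toy version of Interleaved_Datasets over plain sequences."""

    def __init__(self, datasets: list[Sequence[T]]):
        if not datasets:
            raise ValueError("Unable to interleave an empty list of datasets.")
        self.datasets = datasets
        # Every source contributes only as many rows as the smallest one.
        self.min_len = min(len(d) for d in datasets)

    def __len__(self) -> int:
        return len(self.datasets) * self.min_len

    def __getitem__(self, index: int) -> T:
        n = len(self.datasets)
        # index % n alternates between sources; index // n walks through each.
        return self.datasets[index % n][index // n]


mixed = RoundRobin([["a0", "a1", "a2"], ["b0", "b1"]])
print([mixed[i] for i in range(len(mixed))])  # ['a0', 'b0', 'a1', 'b1']
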
25 changes: 21 additions & 4 deletions elk/training/train.py
@@ -22,6 +22,7 @@
import random
import torch
import torch.multiprocessing as mp
from typing import Union


@dataclass
@@ -42,20 +43,22 @@ class RunConfig(Serializable):
max_gpus: int = -1
normalization: Literal["legacy", "none", "elementwise", "meanonly"] = "meanonly"
skip_baseline: bool = False
concatenate_layers: int = 0
# if nonzero, appends the hidden states of the layer concatenate_layers before


def train_reporter(
cfg: RunConfig,
dataset: DatasetDict,
out_dir: Path,
layer: int,
layer: Union[int, list[int]],
devices: list[str],
world_size: int = 1,
):
"""Train a single reporter on a single layer."""

# Reproducibility
seed = cfg.net.seed + layer
seed = cfg.net.seed + layer if isinstance(layer, int) else layer[0]
np.random.seed(seed)
random.seed(seed)
torch.manual_seed(seed)
@@ -72,9 +75,17 @@
train_labels = cast(Tensor, train["label"])
val_labels = cast(Tensor, val["label"])

# concatenate hidden states across layers if multiple layers are inputted
if isinstance(layer, list):
train_hiddens = torch.cat([train[f"hidden_{lay}"] for lay in layer], dim=1)
val_hiddens = torch.cat([val[f"hidden_{lay}"] for lay in layer], dim=1)
else:
train_hiddens = train[f"hidden_{layer}"]
val_hiddens = val[f"hidden_{layer}"]

train_h, val_h = normalize(
int16_to_float32(assert_type(Tensor, train[f"hidden_{layer}"])),
int16_to_float32(assert_type(Tensor, val[f"hidden_{layer}"])),
int16_to_float32(assert_type(Tensor, train_hiddens)),
int16_to_float32(assert_type(Tensor, val_hiddens)),
method=cfg.normalization,
)
x0, x1 = train_h.unbind(dim=-2)
@@ -161,6 +172,12 @@ def train(cfg: RunConfig, out_dir: Optional[Path] = None):
for feat in ds["train"].features
if feat.startswith("hidden_")
]

# concatenate hidden states from a previous layer, if told to
if cfg.concatenate_layers > 0:
for i in range(cfg.concatenate_layers, len(layers)):
layers[i] = [layers[i], layers[i] - cfg.concatenate_layers]

# Train reporters for each layer in parallel
with mp.Pool(num_devices) as pool, open(out_dir / "eval.csv", "w") as f:
fn = partial(
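A short sketch of the concatenate_layers bookkeeping above (the layer count and tensor shape here are made up; the real hidden_{layer} columns come from extraction): entries at index i >= concatenate_layers become [i, i - concatenate_layers] pairs, and train_reporter then concatenates the two layers' hidden states with torch.cat.

import torch

# Replicate the pairing loop from train() with hypothetical layer indices.
layers: list = [0, 1, 2, 3, 4, 5]
concatenate_layers = 2  # append hiddens from the layer two steps earlier
for i in range(concatenate_layers, len(layers)):
    layers[i] = [layers[i], layers[i] - concatenate_layers]
print(layers)  # [0, 1, [2, 0], [3, 1], [4, 2], [5, 3]]

# A list entry takes the torch.cat branch in train_reporter; the (16, 2, 768)
# shape is a stand-in for whatever the stored hidden states actually look like.
train = {f"hidden_{l}": torch.randn(16, 2, 768) for l in range(6)}
layer = layers[4]  # [4, 2]
train_hiddens = torch.cat([train[f"hidden_{lay}"] for lay in layer], dim=1)
print(train_hiddens.shape)  # torch.Size([16, 4, 768])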