Binarized meta-templates; some extraction refactoring #218

Closed
wants to merge 24 commits
Changes from 1 commit
Commits (24)
bbee489
Initial support for FEVER
norabelrose Apr 22, 2023
5ba1ddd
Start saving and fitting a reporter to the input embeddings
norabelrose Apr 22, 2023
3b1f74d
Merge branch 'input-embeddings' into template-filtering
norabelrose Apr 22, 2023
51ba54f
Rename layer 0 to 'input' to make it more clear
norabelrose Apr 22, 2023
544b485
Actually rename layer 0 correctly
norabelrose Apr 22, 2023
43da44e
Handle layer_stride correctly
norabelrose Apr 22, 2023
9056e00
Merge branch 'input-embeddings' into template-filtering
norabelrose Apr 22, 2023
756fa53
label_choices
norabelrose Apr 22, 2023
93b7ae0
Clean up train and eval commands; do transfer in sweep
norabelrose Apr 22, 2023
57d0b8b
Support INLP and split eval output into multiple CSVs
norabelrose Apr 22, 2023
228a6a0
Merge branch 'inlp' into template-filtering
norabelrose Apr 22, 2023
b086f0b
Merge branch 'inlp' into template-filtering
norabelrose Apr 25, 2023
934cd54
Log ensembled metrics
norabelrose Apr 26, 2023
dff69bf
Fixing pyright version
norabelrose Apr 26, 2023
b181d3e
Merge remote-tracking branch 'origin/main' into ensembling
norabelrose Apr 26, 2023
15254bf
Merge main
norabelrose Apr 26, 2023
69c2d55
Tons of stuff, preparing for sciq_binary experiment
norabelrose Apr 27, 2023
960ff01
Support --binarize again
norabelrose Apr 27, 2023
c9e62ea
Partial support for truthful_qa
norabelrose Apr 27, 2023
eb71a6c
Merge branch 'main' into template-filtering
norabelrose Apr 29, 2023
88bb15e
Merge remote-tracking branch 'origin/main' into template-filtering
norabelrose Apr 29, 2023
c648ff0
Remove crap
norabelrose Apr 29, 2023
ef12130
EleutherAI/truthful_qa_mc
norabelrose Apr 29, 2023
5d60ebd
Update templates
norabelrose Apr 30, 2023
Initial support for FEVER
norabelrose committed Apr 22, 2023
commit bbee4890ea180c0e9aedcf9f27fd9f15bcb4a377
26 changes: 15 additions & 11 deletions elk/extraction/balanced_sampler.py
@@ -1,8 +1,8 @@
from collections import deque
from dataclasses import dataclass, field
from dataclasses import InitVar, dataclass, field
from itertools import cycle
from random import Random
from typing import Iterable, Iterator, Optional
from typing import Hashable, Iterable, Iterator, Optional

from datasets import Features, IterableDataset
from torch.utils.data import IterableDataset as TorchIterableDataset
@@ -26,25 +26,29 @@ class BalancedSampler(TorchIterableDataset):
"""

data: Iterable[dict]
num_classes: int
label_choices: InitVar[set[Hashable]]
buffer_size: int = 1000
buffers: dict[int, deque[dict]] = field(default_factory=dict, init=False)
buffers: dict[Hashable, deque[dict]] = field(default_factory=dict, init=False)
label_col: str = "label"
strict: bool = True

def __post_init__(self):
def __post_init__(self, label_choices: set[Hashable]):
# Initialize empty buffers
self.buffers = {
label: deque(maxlen=self.buffer_size) for label in range(self.num_classes)
label: deque(maxlen=self.buffer_size) for label in label_choices
}

def __iter__(self):
for sample in self.data:
label = sample[self.label_col]

# This whole class is a no-op if the label is not an integer
if not isinstance(label, int):
yield sample
continue
if label not in self.buffers:
if self.strict:
raise ValueError(
f"Expected label to be one of {self.buffers}, got {label}"
)
else:
# Just skip this sample
continue

# Add the sample to the buffer for its class label
self.buffers[label].append(sample)
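For reference, a minimal usage sketch of the refactored sampler (not part of this commit; the toy records and label values below are invented, and only the constructor signature comes from the diff above). BalancedSampler now takes an explicit set of hashable label choices instead of a num_classes count, and the strict flag decides whether an unexpected label raises or is skipped.

from elk.extraction.balanced_sampler import BalancedSampler

# Invented toy records; in the real pipeline `data` is a streamed HuggingFace dataset.
examples = [
    {"text": "Water is wet.", "label": True},
    {"text": "The sky is green.", "label": False},
]

# Labels may be any hashable values (bools, ints, or strings), not just range(num_classes).
# strict=False silently skips samples whose label falls outside the expected set,
# mirroring how load_prompts constructs the sampler later in this commit.
sampler = BalancedSampler(examples, {False, True}, label_col="label", strict=False)

for sample in sampler:
    ...  # samples are drawn from per-label buffers so the classes stay balanced
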
24 changes: 10 additions & 14 deletions elk/extraction/extraction.py
@@ -30,7 +30,6 @@
assert_type,
float32_to_int16,
infer_label_column,
infer_num_classes,
instantiate_model,
instantiate_tokenizer,
is_autoregressive,
@@ -128,10 +127,7 @@ def extract_hiddens(

prompt_ds = load_prompts(
ds_names[0],
label_column=p_cfg.label_columns[0] if p_cfg.label_columns else None,
num_classes=p_cfg.num_classes,
split_type=split_type,
stream=p_cfg.stream,
rank=rank,
world_size=world_size,
)
@@ -202,15 +198,15 @@

input_ids = torch.cat([input_ids, answer], dim=-1)
if max_len := tokenizer.model_max_length:
input_ids = input_ids[..., -max_len:]
cur_len = input_ids.shape[-1]
input_ids = input_ids[..., -min(max_len, cur_len) :]

# Make sure we only pass the arguments that the model expects
inputs = dict(input_ids=input_ids)
if is_enc_dec:
inputs["labels"] = answer

with torch.autocast("cuda", enabled=torch.cuda.is_available()):
outputs = model(**inputs, output_hidden_states=True)
outputs = model(**inputs, output_hidden_states=True)

# Compute the log probability of the answer tokens if available
if has_lm_preds:
@@ -303,17 +299,17 @@ def get_splits() -> SplitDict:
ds_name, _, config_name = cfg.prompts.datasets[0].partition(" ")
info = get_dataset_config_info(ds_name, config_name or None)

prompter = DatasetTemplates(ds_name, config_name)
ds_features = assert_type(Features, info.features)
label_col = (
cfg.prompts.label_columns[0]
if cfg.prompts.label_columns
else infer_label_column(ds_features)
)
num_classes = cfg.prompts.num_classes or infer_num_classes(ds_features[label_col])
prompter.label_column or infer_label_column(ds_features)
num_classes = 2 # prompter.num_classes or infer_num_classes(ds_features[label_col])

num_variants = cfg.prompts.num_variants
if num_variants < 0:
prompter = DatasetTemplates(ds_name, config_name)
num_dropped = prompter.drop_non_mc_templates()
num_variants = len(prompter.templates)
if num_dropped:
print(f"Dropping {num_dropped} non-multiple choice templates")

layer_cols = {
f"hidden_{layer}": Array3D(
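For context, a sketch of the slimmed-down call that extract_hiddens now makes (illustrative values only; the dataset string is taken from the docstring examples in prompt_loading.py, and the keyword arguments are the ones left in the hunk above). label_column, num_classes, and stream are no longer passed; the prompt templates and the dataset schema determine label handling internally.

from elk.extraction.prompt_loading import load_prompts

prompt_iter = load_prompts(
    "super_glue boolq",  # space-delimited "<dataset> <config>" string
    num_shots=0,
    split_type="val",
    rank=0,
    world_size=1,
)

# Each yielded record carries an integer label index, the filled prompt variants,
# and the names of the templates that produced them.
example = next(prompt_iter)
print(example["label"], len(example["prompts"]), example["template_names"][:2])
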
138 changes: 62 additions & 76 deletions elk/extraction/prompt_loading.py
@@ -3,21 +3,15 @@
from dataclasses import dataclass
from itertools import zip_longest
from random import Random
from typing import Any, Iterator, Literal, Optional
from typing import Any, Iterator, Literal

from datasets import (
Dataset,
Features,
load_dataset,
)
from datasets.distributed import split_dataset_by_node
from datasets import ClassLabel, Dataset, Value, load_dataset
from simple_parsing.helpers import Serializable, field

from ..promptsource import DatasetTemplates
from ..utils import (
assert_type,
infer_label_column,
infer_num_classes,
select_train_val_splits,
)
from .balanced_sampler import BalancedSampler, FewShotSampler
@@ -31,9 +25,6 @@ class PromptConfig(Serializable):
`"super_glue boolq"` or `"imdb"`.
data_dir: The directory to use for caching the dataset. Defaults to
`~/.cache/huggingface/datasets`.
label_column: The column containing the labels. By default, we infer this from
the datatypes of the columns in the dataset; if there is only one column
with a `ClassLabel` datatype, we use that.
max_examples: The maximum number of examples to use from the val dataset.
If a single number, use at most that many examples for each split. If a list
of length 2, use the first element for the train split and the second for
@@ -43,18 +34,14 @@
num_variants: The number of prompt templates to apply to each predicate upon
call to __getitem__. Use -1 to apply all available templates. Defaults to 1.
seed: The seed to use for prompt randomization. Defaults to 42.
stream: Whether to stream the dataset from the Internet. Defaults to False.
"""

datasets: list[str] = field(positional=True)
data_dirs: list[str] = field(default_factory=list)
label_columns: list[str] = field(default_factory=list)
max_examples: list[int] = field(default_factory=lambda: [1000, 1000])
num_classes: int = 0
num_shots: int = 0
num_variants: int = -1
seed: int = 42
stream: bool = False

def __post_init__(self):
if len(self.max_examples) > 2:
@@ -69,7 +56,7 @@ def __post_init__(self):
if len(self.max_examples) == 1:
self.max_examples *= 2

# Broadcast the dataset name to all data_dirs and label_columns
# Broadcast the dataset name to all data_dirs
if len(self.data_dirs) == 1:
self.data_dirs *= len(self.datasets)
elif self.data_dirs and len(self.data_dirs) != len(self.datasets):
@@ -78,25 +65,14 @@
f" but got {len(self.data_dirs)}"
)

if len(self.label_columns) == 1:
self.label_columns *= len(self.datasets)
elif self.label_columns and len(self.label_columns) != len(self.datasets):
raise ValueError(
"label_columns should be a list of length 0, 1, or len(datasets),"
f" but got {len(self.label_columns)}"
)

def explode(self) -> list["PromptConfig"]:
"""Explode the config into a list of configs, one for each dataset."""
copies = []

for ds, data_dir, col in zip_longest(
self.datasets, self.data_dirs, self.label_columns
):
for ds, data_dir in zip_longest(self.datasets, self.data_dirs):
copy = deepcopy(self)
copy.datasets = [ds]
copy.data_dirs = [data_dir] if data_dir else []
copy.label_columns = [col] if col else []
copies.append(copy)

return copies
@@ -105,13 +81,10 @@ def explode(self) -> list["PromptConfig"]:
def load_prompts(
ds_string: str,
*,
label_column: Optional[str] = None,
num_classes: int = 0,
num_shots: int = 0,
num_variants: int = -1,
seed: int = 42,
split_type: Literal["train", "val"] = "train",
stream: bool = False,
rank: int = 0,
world_size: int = 1,
) -> Iterator[dict]:
@@ -120,15 +93,10 @@
Args:
ds_string: Space-delimited name of the HuggingFace dataset to use,
e.g. `"super_glue boolq"` or `"imdb"`.
label_column: The column containing the labels. By default, we infer this from
the datatypes of the columns in the dataset.
num_classes: The number of classes in the dataset. If zero, we infer this from
the datatypes of the columns in the dataset.
num_shots: The number of examples to use in few-shot prompts. If zero, prompts
are zero-shot.
seed: The seed to use for prompt randomization.
split_type: Whether to use the train or val split of the dataset.
stream: Whether to stream the dataset from the Internet. Defaults to False.
rank: The rank of the current process. Defaults to 0.
world_size: The number of processes. Defaults to 1.

@@ -137,25 +105,18 @@
"""
ds_name, _, config_name = ds_string.partition(" ")
prompter = DatasetTemplates(ds_name, config_name)
prompter.drop_non_mc_templates()

ds_dict = assert_type(
dict, load_dataset(ds_name, config_name or None, streaming=stream)
)
ds_dict = assert_type(dict, load_dataset(ds_name, config_name or None))
train_name, val_name = select_train_val_splits(ds_dict)
split_name = val_name if split_type == "val" else train_name

ds = ds_dict[split_name].shuffle(seed=seed)
train_ds = ds_dict[train_name].shuffle(seed=seed)
if not stream:
ds = assert_type(Dataset, ds)
if world_size > 1:
ds = ds.shard(world_size, rank)

ds = ds.to_iterable_dataset().cast(ds.features)

elif world_size > 1:
# This prints to stdout which is slightly annoying
ds = split_dataset_by_node(dataset=ds, rank=rank, world_size=world_size)
ds = assert_type(Dataset, ds)
if world_size > 1:
ds = ds.shard(world_size, rank)

num_templates = len(prompter.templates)
num_variants = (
@@ -165,10 +126,44 @@
if rank == 0:
print(f"Using {num_variants} variants of each prompt")

label_column = label_column or infer_label_column(ds.features)
num_classes = num_classes or infer_num_classes(ds.features[label_column])
rng = Random(seed)
# Which classes are actually present in this split of the dataset?
# This is shockingly fast since it uses an optimized Apache Arrow primitive.
label_column = prompter.label_column or infer_label_column(ds.features)
observed_labels = set(ds.unique(label_column))

# Now sanity check that the observed classes match the expected classes. This can
# sometimes fail if we picked an unlabeled split (e.g. everything is -1)
label_feature = ds.features[label_column]
if isinstance(label_feature, ClassLabel):
label_choices = {label_feature.str2int(label) for label in label_feature.names}
elif isinstance(label_feature, Value) and label_feature.dtype == "bool":
label_choices = {False, True}
else:
# We just have to assume that the observed labels are right
label_choices = observed_labels

if observed_labels != label_choices:
raise ValueError(
f"Observed labels {observed_labels} in split '{split_name}' do not match "
f"expected labels {label_choices} from the dataset features."
)

if prompt_choices := prompter.label_choices:
# The observed labels should be a superset of the prompt choices
if not (observed_labels >= set(prompt_choices)):
raise ValueError(
f"Observed labels {observed_labels} in split '{split_name}' do not "
f"match the prompt choices {prompt_choices}."
)

sorted_labels = prompt_choices
else:
# Impose a canonical order on the label choices. Theoretically the label column
# may be of a type that doesn't support comparison (so Pylance complains), but
# we'll just let it raise an exception if that happens.
sorted_labels = sorted(label_choices) # type: ignore[arg-type]

rng = Random(seed)
if num_shots > 0:
fewshot = FewShotSampler(
train_ds, # TODO: not iterator
@@ -179,15 +174,17 @@
else:
fewshot_iter = None

# Remove everything except the label column
extra_cols = list(assert_type(Features, ds.features))
extra_cols.remove(label_column)
ds = ds.to_iterable_dataset()
if rank == 0:
print(f"Label choices: {sorted_labels}")

for example in BalancedSampler(ds, num_classes, label_col=label_column):
for example in BalancedSampler(
ds, set(sorted_labels), label_col=label_column, strict=False
):
yield _convert_to_prompts(
example,
label_column=label_column,
num_classes=num_classes,
label_choices=sorted_labels, # type: ignore[arg-type]
num_variants=num_variants,
prompter=prompter,
rng=rng,
Expand All @@ -199,13 +196,12 @@ def _convert_to_prompts(
example: dict[str, Any],
prompter: DatasetTemplates,
label_column: str,
num_classes: int,
label_choices: list[bool | int | str],
num_variants: int,
rng: Random,
fewshot_iter: Optional[Iterator[list[dict]]] = None,
fewshot_iter: Iterator[list[dict]] | None = None,
) -> dict[str, Any]:
"""Prompt-generating function to pass to `IterableDataset.map`."""
labels_are_strings = isinstance(example[label_column], str)
prompts = []
templates = list(prompter.templates.values())
if num_variants < len(templates):
@@ -218,21 +214,14 @@ def qa_cat(q: str, a: str) -> str:

# For sanity checking that prompts are unique
prompt_counter = Counter()
label_indices = set()
label = example[label_column]

for template in templates:
choices = []
string_choices = template.get_answer_choices_list(example)

label = example[label_column]
label_indices.add(string_choices.index(label) if labels_are_strings else label)

for answer_idx in range(num_classes):
for pseudo_label in label_choices:
fake_example = example.copy()
if labels_are_strings:
fake_example[label_column] = string_choices[answer_idx]
else:
fake_example[label_column] = answer_idx
fake_example[label_column] = pseudo_label

q, a = template.apply(fake_example)
prompt_counter[(q, a)] += 1
@@ -261,14 +250,11 @@ def qa_cat(q: str, a: str) -> str:
if dup_count > 1:
raise ValueError(f'Prompt duplicated {dup_count} times! "{maybe_dup}"')

# Sanity check: label should be the same across all variants
if len(label_indices) > 1:
raise ValueError(
f"Label index should be the same all variants, but got {label_indices}"
)

# Our reporter training and evaluation code assumes that the labels are integers.
# If they're not, we need to convert them with index(). label_choices is guaranteed
# to be sorted (see above).
return dict(
label=label_indices.pop(),
label=label_choices.index(label),
prompts=prompts,
template_names=[template.name for template in templates],
)
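
To make the new label handling concrete, here is a standalone sketch of the inference and sanity-check logic this commit adds to load_prompts (the helper name infer_label_choices is invented for illustration, and the prompter.label_choices branch is omitted for brevity).

from datasets import ClassLabel, Value

def infer_label_choices(label_feature, observed_labels: set) -> list:
    """Hypothetical helper mirroring the checks load_prompts now performs."""
    if isinstance(label_feature, ClassLabel):
        expected = {label_feature.str2int(name) for name in label_feature.names}
    elif isinstance(label_feature, Value) and label_feature.dtype == "bool":
        expected = {False, True}
    else:
        # No schema to check against, so trust whatever the split contains.
        expected = observed_labels

    if observed_labels != expected:
        raise ValueError(f"Observed {observed_labels}, expected {expected}")

    # Impose a canonical (sorted) order so labels map to stable integer indices.
    return sorted(expected)

choices = infer_label_choices(ClassLabel(names=["neg", "pos"]), {0, 1})
assert choices.index(1) == 1  # what _convert_to_prompts now stores under "label"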