Multi datasets #123

Merged Mar 28, 2023 · 46 commits
Changes shown are from 3 of the 46 commits.

Commits
681698d
add multiple datasets support
Benw8888 Mar 9, 2023
ac1b9f1
Merge branch 'main' of github.com:EleutherAI/elk into multi-datasets
Benw8888 Mar 9, 2023
b864c77
train_reporter works on a list of layers now
Benw8888 Mar 10, 2023
7d7d97c
changing printed layer names
Benw8888 Mar 10, 2023
4fe61e9
fixed concatenation bug
Benw8888 Mar 11, 2023
fe61d67
minor edits
Benw8888 Mar 13, 2023
74da878
fixed pyright issues
Benw8888 Mar 13, 2023
569ef05
Merge branch 'main' of github.com:EleutherAI/elk into multi-datasets
Benw8888 Mar 14, 2023
b62b679
Merge branch 'main' into multi-datasets
norabelrose Mar 20, 2023
fe94c22
Fix tests
norabelrose Mar 20, 2023
bba24d8
Now working sorta
norabelrose Mar 22, 2023
03ba6e0
Skip slow BalancedBatchSampler test
norabelrose Mar 22, 2023
15ab351
Slightly relax test_output_is_roughly_balanced
norabelrose Mar 22, 2023
a80369e
Make BalancedSampler deterministic
norabelrose Mar 22, 2023
d304ab3
InitVar
norabelrose Mar 22, 2023
761c82d
Support multi class again
norabelrose Mar 22, 2023
f29743b
Fix naming issue
norabelrose Mar 22, 2023
b7b7e23
Support few shot prompts
norabelrose Mar 23, 2023
1afb563
Merge branch 'main' into multi-datasets
norabelrose Mar 23, 2023
225d4c7
fix multiclass labels
AlexTMallen Mar 23, 2023
9368dc8
Merge branch 'multi-datasets' of github.com:EleutherAI/elk into multi…
AlexTMallen Mar 23, 2023
a858b65
Merge branch 'main' into multi-datasets
norabelrose Mar 24, 2023
5dc2ec6
Merge branch 'multi-datasets' of github.com:EleutherAI/elk into multi…
norabelrose Mar 24, 2023
b1b95e5
Fix dumb part of test failures
norabelrose Mar 25, 2023
ee3911e
Fix assert_allclose warning
norabelrose Mar 25, 2023
a55b3de
Switch to torch.testing.assert_close in EigenReporter test
norabelrose Mar 25, 2023
44dc25c
Shuffle load_prompts output by default
norabelrose Mar 25, 2023
93d8d87
Fix smoke test failure
norabelrose Mar 25, 2023
fad4d74
Remove debug prints
AlexTMallen Mar 25, 2023
0a054f4
Remove more debug print statements
AlexTMallen Mar 25, 2023
177eec2
make min_memory usable; broadcast mmax_examples in __post_init__
AlexTMallen Mar 26, 2023
3a762b0
prompt loading refactor to enable better streaming
AlexTMallen Mar 26, 2023
f66c054
remove shuffle arg
AlexTMallen Mar 26, 2023
d3d87fc
remove unused @dataclass
lauritowal Mar 26, 2023
3d08147
merge
lauritowal Mar 27, 2023
c9a43e1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 27, 2023
94290aa
add concatenated_layer_offset to eval
lauritowal Mar 27, 2023
f9298e4
Merge branch 'multi-datasets' of https://github.com/EleutherAI/elk in…
lauritowal Mar 27, 2023
3765c4f
add self.
lauritowal Mar 27, 2023
2b05193
replace target with data
lauritowal Mar 27, 2023
83731bb
add self.
lauritowal Mar 27, 2023
764fda9
remove second arg
lauritowal Mar 27, 2023
d2c66b0
fix passing the wrong params for world size / rank
thejaminator Mar 28, 2023
9186326
Update prompt_loading.py
lauritowal Mar 28, 2023
3f99a4d
fix pre-commit errors
lauritowal Mar 28, 2023
148130d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 28, 2023
README.md (24 changes: 8 additions, 16 deletions)

@@ -4,7 +4,7 @@

Because language models are trained to predict the next token in naturally occurring text, they often reproduce common human errors and misconceptions, even when they "know better" in some sense. More worryingly, when models are trained to generate text that's rated highly by humans, they may learn to output false statements that human evaluators can't detect. We aim to circumvent this issue by directly [**eliciting latent knowledge**](https://docs.google.com/document/d/1WwsnJQstPq91_Yh-Ch2XRL8H_EpsnjrC1dwZXR37PC8/edit) (ELK) inside the activations of a language model.

Specifically, we're building on the **Contrast Consistent Search** (CCS) method described in the paper [Discovering Latent Knowledge in Language Models Without Supervision](https://arxiv.org/abs/2212.03827) by Burns et al. (2022). In CCS, we search for features in the hidden states of a language model which satisfy certain logical consistency requirements. It turns out that these features are often useful for question-answering and text classification tasks, even though the features are trained without labels.
Specifically, we're building on the **Contrastive Representation Clustering** (CRC) method described in the paper [Discovering Latent Knowledge in Language Models Without Supervision](https://arxiv.org/abs/2212.03827) by Burns et al. (2022). In CRC, we search for features in the hidden states of a language model which satisfy certain logical consistency requirements. It turns out that these features are often useful for question-answering and text classification tasks, even though the features are trained without labels.

### Quick **Start**

@@ -20,29 +20,21 @@ elk elicit microsoft/deberta-v2-xxlarge-mnli imdb

This will automatically download the model and dataset, run the model and extract the relevant representations if they aren't cached on disk, fit reporters on them, and save the reporter checkpoints to the `elk-reporters` folder in your home directory. It will also evaluate the reporters' classification performance on a held-out test set and save the results to a CSV file in the same folder.

The following will generate a CCS (Contrast Consistent Search) reporter instead of the CRC-based reporter, which is the default.

```bash
elk eval naughty-northcutt microsoft/deberta-v2-xxlarge-mnli imdb
elk elicit microsoft/deberta-v2-xxlarge-mnli imdb --net ccs
```

This will evaluate the probe from the run naughty-northcutt on the hidden states extracted from the model deberta-v2-xxlarge-mnli for the imdb dataset. It will result in an `eval.csv` and `cfg.yaml` file, which are stored under a subfolder in `elk-reporters/naughty-northcutt/transfer_eval`.

## Caching

The hidden states resulting from `elk elicit` are cached as a HuggingFace dataset to avoid having to recompute them every time we want to train a probe. The cache is stored in the same place as all other HuggingFace datasets, which is usually `~/.cache/huggingface/datasets`.

## Other commands

To only extract the hidden states for the model `model` and the dataset `dataset` and save them to `my_output_dir`, without training any reporters, you can run:
The following command will evaluate the probe from the run naughty-northcutt on the hidden states extracted from the model deberta-v2-xxlarge-mnli for the imdb dataset. It will result in an `eval.csv` and `cfg.yaml` file, which are stored under a subfolder in `elk-reporters/naughty-northcutt/transfer_eval`.

```bash
elk extract microsoft/deberta-v2-xxlarge-mnli imdb -o my_output_dir
elk eval naughty-northcutt microsoft/deberta-v2-xxlarge-mnli imdb
```

The following will generate a CCS reporter instead of the Eigen reporter, which is the default.
## Caching

```bash
elk elicit microsoft/deberta-v2-xxlarge-mnli imdb --net ccs
```
The hidden states resulting from `elk elicit` are cached as a HuggingFace dataset to avoid having to recompute them every time we want to train a probe. The cache is stored in the same place as all other HuggingFace datasets, which is usually `~/.cache/huggingface/datasets`.
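As a small illustration (not part of this diff), the cache can be redirected with the standard `HF_DATASETS_CACHE` environment variable; the path below is a hypothetical example:

```python
import os

# Hypothetical cache path; this must be set before `datasets` is first imported.
os.environ["HF_DATASETS_CACHE"] = "/scratch/hf_datasets"

from datasets import config

print(config.HF_DATASETS_CACHE)  # -> /scratch/hf_datasets
```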

## Development
Use `pip install pre-commit && pre-commit install` in the root folder before your first commit.

elk/extraction/__init__.py (2 changes: 1 addition, 1 deletion)

@@ -1,4 +1,4 @@
from .balanced_sampler import BalancedBatchSampler, BalancedSampler
from .balanced_sampler import BalancedSampler, FewShotSampler
from .extraction import ExtractionConfig, extract_hiddens, extract
from .generator import _GeneratorConfig, _GeneratorBuilder
from .prompt_loading import PromptConfig, load_prompts

elk/extraction/balanced_sampler.py (51 changes: 31 additions, 20 deletions)

@@ -1,8 +1,9 @@
from ..utils import infer_label_column
from collections import Counter
from ..math_util import stochastic_round_constrained
from dataclasses import dataclass, field, InitVar
from datasets import IterableDataset
from itertools import cycle
from random import Random
from torch.utils.data import IterableDataset as TorchIterableDataset
from typing import Iterator, Optional
import numpy as np
@@ -50,39 +51,49 @@ def __iter__(self):
yield sample


class BalancedBatchSampler:
"""Yields precisely balanced batches from a binary classification dataset.
class FewShotSampler:

Review comment (Collaborator): Any reason why you used @dataclass for the above class, but not here?

Review comment (lauritowal, Mar 23, 2023): I haven't had time yet to review the pull request in more detail and to test it. I can do that tomorrow evening, though.

Review comment (Member): Yeah, this is a little inconsistent, but I don't know how much it matters.

Review comment (Member): Arguably FewShotSampler might be better if we made it not an iterator, because infinite iterators are weird.

"""Yields batches of few-shot examples that are as balanced as possible.

Written by a human being because GPT-4 couldn't figure out how to do it.
If `num_shots` is even, each batch is exactly balanced between the two
classes. Otherwise, `stochastic_round_constrained` decides which class gets
the extra example, so each batch is balanced to within one example.
"""

def __init__(
self,
dataset: IterableDataset,
num_shots: int,
rng: Random,
label_col: Optional[str] = None,
batch_size: int = 32,
):
self.batch_size = batch_size
self.dataset = dataset
self.label_col = label_col or infer_label_column(dataset.features)
self.num_shots = num_shots
self.rng = rng

def __iter__(self) -> Iterator[list[dict]]:
batch = []
neg_buf, pos_buf = [], []

max_count = self.batch_size // 2
label_counts = Counter()

# Infinite loop!
# Infinite loop over the dataset!
for sample in cycle(self.dataset):
label = sample[self.label_col]
if label_counts[label] >= max_count:
continue

batch.append(sample)
label_counts[label] += 1
if label == 0:
neg_buf.append(sample)
elif label == 1:
pos_buf.append(sample)
else:
raise ValueError(f"Expected label to be 0 or 1, got {label}")

Review comment (Collaborator): So we don't support few-shot examples for dbpedia or ag_news (or other multiclass datasets)? I guess this is fine until someone wants to add that.

Review comment (Member): Yeah, I don't want to think about how to support that right now.

neg_count, pos_count = stochastic_round_constrained(
[self.num_shots / 2, self.num_shots / 2], self.rng
)
while len(neg_buf) >= neg_count and len(pos_buf) >= pos_count:
batch = []
for _ in range(neg_count):
batch.append(neg_buf.pop())
for _ in range(pos_count):
batch.append(pos_buf.pop())

if len(batch) == self.batch_size:
self.rng.shuffle(batch)
yield batch

batch = []
label_counts.clear()
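
To illustrate the balancing guarantee above, here is a minimal sketch (assuming `stochastic_round_constrained` takes a list of floats and returns integer counts that preserve the sum):

```python
from random import Random

from elk.math_util import stochastic_round_constrained

# With an odd shot count, the per-class counts are rounded stochastically but
# always sum to num_shots, so each batch is balanced to within one example.
rng = Random(0)
for _ in range(4):
    neg_count, pos_count = stochastic_round_constrained([2.5, 2.5], rng)
    assert neg_count + pos_count == 5
    assert abs(neg_count - pos_count) == 1
```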

elk/extraction/prompt_loading.py (40 changes: 34 additions, 6 deletions)

@@ -1,4 +1,3 @@
from ..math_util import stochastic_round_constrained
from ..promptsource import DatasetTemplates
from ..utils import (
assert_type,
@@ -7,6 +6,7 @@
infer_num_classes,
select_train_val_splits,
)
from .balanced_sampler import FewShotSampler
from dataclasses import dataclass
from datasets import (
interleave_datasets,
@@ -19,7 +19,7 @@
from datasets.distributed import split_dataset_by_node
from random import Random
from simple_parsing.helpers import field, Serializable
from typing import Any, Literal, Optional
from typing import Any, Iterator, Literal, Optional


@dataclass
@@ -65,6 +65,7 @@ def __post_init__(self):
def load_prompts(
*dataset_strings: str,
max_examples: int = 0,
num_shots: int = 0,
seed: int = 42,
split_type: Literal["train", "val"] = "train",
rank: int = 0,
@@ -76,6 +77,8 @@ def load_prompts(
dataset_strings: Space-delimited names of the HuggingFace datasets to use,
e.g. `"super_glue boolq"` or `"imdb"`.
max_examples: The maximum number of examples to use from the dataset.
num_shots: The number of examples to use in few-shot prompts. If zero, prompts
are zero-shot.
seed: The seed to use for prompt randomization.
split_type: Whether to use the train or val split of the dataset.
rank: The rank of the current process. Defaults to 0.
@@ -87,6 +90,7 @@
prompt_datasets = []
prompters = []
raw_datasets = []
train_datasets = []
rng = Random(seed)

# First load the datasets and prompters. We need to know the minimum number of
@@ -101,9 +105,10 @@
train_name, val_name = select_train_val_splits(ds_dict)
split_name = val_name if split_type == "val" else train_name
raw_datasets.append(assert_type(IterableDataset, ds_dict[split_name]))
train_datasets.append(assert_type(IterableDataset, ds_dict[train_name]))

num_variants = min(len(prompter.templates) for prompter in prompters)
for ds, prompter in zip(raw_datasets, prompters):
for ds, train_ds, prompter in zip(raw_datasets, train_datasets, prompters):
label_column = infer_label_column(ds.features)
num_classes = infer_num_classes(ds.features[label_column])

@@ -113,6 +118,15 @@

if label_column != "label":
ds = ds.rename_column(label_column, "label")
if num_shots > 0:
fewshot = FewShotSampler(
train_ds,
num_shots=num_shots,
rng=rng,
)

Review comment (Collaborator): The function signature is

    def split_dataset_by_node(dataset: DatasetType, rank: int, world_size: int) -> DatasetType:

so we were passing world_size to rank and vice versa.
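
A runnable sketch of the corrected argument order (toy dataset; keyword arguments make the order explicit):

```python
from datasets import Dataset
from datasets.distributed import split_dataset_by_node

# Toy dataset with 16 rows; rank comes before world_size in the signature.
ds = Dataset.from_dict({"label": [0, 1] * 8})
shard = split_dataset_by_node(ds, rank=0, world_size=2)
print(len(shard))  # 8
```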

fewshot_iter = iter(fewshot)
else:
fewshot_iter = None

# Canonicalize the name and dtype of the label column
ds = ds.map(
@@ -123,6 +137,7 @@
num_variants=num_variants,
prompter=prompter,
rng=rng,
fewshot_iter=fewshot_iter,
),
remove_columns=extra_cols,
).map(
@@ -155,6 +170,7 @@
if max_examples > 0:
master_ds = master_ds.take(max_examples)
if world_size > 1:
# This prints to stdout which is slightly annoying
master_ds = split_dataset_by_node(master_ds, rank, world_size)

return master_ds
@@ -167,13 +183,19 @@ def _convert_to_prompts(
num_classes: int,
num_variants: int,
rng: Random,
fewshot_iter: Optional[Iterator[list[dict]]] = None,
) -> dict[str, Any]:
"""Prompt-generating function to pass to `IterableDataset.map`."""
prompts = []
templates = list(prompter.templates.values())
if num_variants < len(templates):
templates = rng.sample(templates, num_variants)

def qa_cat(q: str, a: str) -> str:
# if the jinja template already adds whitespace, don't add more
sep = "" if not q or q[-1].isspace() or not a or a[0].isspace() else " "
return f"{q}{sep}{a}" if a and not a.isspace() else q

new_label = rng.choice([0, 1]) if num_classes > 2 else example[label_column]

for template in templates:
@@ -189,10 +211,16 @@
fake_example[label_column] = answer_idx

q, a = template.apply(fake_example)
text = qa_cat(q, a)

if fewshot_iter is not None:
# Infinite iterator so we don't need to worry about StopIteration
fewshot_examples = next(fewshot_iter)
fewshot_texts = [
qa_cat(q, a) for q, a in map(template.apply, fewshot_examples)
]
text = "\n\n".join(fewshot_texts) + "\n\n" + text

# if the jinja template already adds whitespace, don't add more
sep = "" if not q or q[-1].isspace() or not a or a[0].isspace() else " "
text = f"{q}{sep}{a}" if a and not a.isspace() else q
choices.append(
dict(
# Strip whitespace from the answer to make it easier to
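
As a self-contained sketch of the prompt assembly above (the review strings are made up; real `q`/`a` pairs come from `template.apply`):

```python
def qa_cat(q: str, a: str) -> str:
    # If the jinja template already adds whitespace, don't add more.
    sep = "" if not q or q[-1].isspace() or not a or a[0].isspace() else " "
    return f"{q}{sep}{a}" if a and not a.isspace() else q

# Hypothetical few-shot examples; in the real code they come from FewShotSampler.
fewshot_texts = [
    qa_cat("Review: 'Loved it.' Sentiment:", "positive"),
    qa_cat("Review: 'Dull and slow.' Sentiment:", "negative"),
]
query = qa_cat("Review: 'A classic.' Sentiment:", "")

# Few-shot blocks are joined by blank lines and prepended to the query.
text = "\n\n".join(fewshot_texts) + "\n\n" + query
print(text)
```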

tests/test_samplers.py (28 changes: 14 additions, 14 deletions)

@@ -1,12 +1,11 @@
from collections import Counter
from datasets import load_dataset, IterableDataset
from elk.extraction import BalancedBatchSampler, BalancedSampler
from elk.extraction import FewShotSampler, BalancedSampler
from elk.utils import assert_type, infer_label_column
from itertools import islice
import pytest
from random import Random


@pytest.mark.skip(reason="This test is too slow to run on every commit")
def test_output_batches_are_balanced():
# Load an example dataset for testing
dataset = assert_type(
@@ -15,20 +14,21 @@ def test_output_batches_are_balanced():
)
label_col = infer_label_column(dataset.features)

# Create the BalancedBatchSampler instance
batch_size = 32
balanced_batch_sampler = BalancedBatchSampler(dataset, batch_size=batch_size)

# Iterate through batches and check if they are balanced
for batch in balanced_batch_sampler:
# Start with an even number of shots; make sure they're exactly balanced
sampler = FewShotSampler(dataset, 6, rng=Random(42))
for batch in islice(sampler, 5):
counter = Counter(sample[label_col] for sample in batch)

# Check if the output batch is balanced
label_0_count = counter[0]
label_1_count = counter[1]
assert (
label_0_count == label_1_count
), f"Batch is not balanced: {label_0_count}, {label_1_count}"
assert counter[0] == counter[1]

# Start with an odd number of shots; make sure they're roughly balanced
sampler = FewShotSampler(dataset, 5, rng=Random(42))
for batch in islice(sampler, 5):
counter = Counter(sample[label_col] for sample in batch)

# The batch should be balanced to within 1 sample
assert abs(counter[0] - counter[1]) <= 1


def test_output_is_roughly_balanced():
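
For context, a usage sketch of the new sampler (assuming the streaming split exposes `features`, which `infer_label_column` needs):

```python
from itertools import islice
from random import Random

from datasets import load_dataset
from elk.extraction import FewShotSampler

# Draw a few 5-shot batches from the streaming IMDB train split.
ds = load_dataset("imdb", split="train", streaming=True)
sampler = FewShotSampler(ds, num_shots=5, rng=Random(42))

for batch in islice(sampler, 3):
    print([example["label"] for example in batch])  # e.g. [0, 1, 0, 1, 1]
```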