Sweep MVP (EleutherAI#191)
norabelrose committed Apr 16, 2023
1 parent 16dc1ca commit 691d314
Showing 3 changed files with 69 additions and 2 deletions.
6 changes: 6 additions & 0 deletions README.md
@@ -32,6 +32,12 @@ The following command will evaluate the probe from the run naughty-northcutt on
elk eval naughty-northcutt microsoft/deberta-v2-xxlarge-mnli imdb
```

The following command runs `elicit` on the Cartesian product of the listed models and datasets, storing the results in a special folder `ELK_DIR/sweeps/<memorable_name>`. The `--add_pooled` flag adds an extra dataset that pools all of the listed datasets together.

```bash
elk sweep --models gpt2-{medium,large,xl} --datasets imdb amazon_polarity --add_pooled
```
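
Note that `gpt2-{medium,large,xl}` relies on shell brace expansion, so the shell passes three separate model names to `elk`. Assuming `--models` accepts multiple space-separated values (as the list-valued `models` field of `Sweep` suggests), the command above is equivalent to:

```bash
elk sweep --models gpt2-medium gpt2-large gpt2-xl --datasets imdb amazon_polarity --add_pooled
```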

## Caching

The hidden states resulting from `elk elicit` are cached as a HuggingFace dataset to avoid having to recompute them every time we want to train a probe. The cache is stored in the same place as all other HuggingFace datasets, which is usually `~/.cache/huggingface/datasets`.
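
For example, assuming the default cache location, you can see what has already been extracted with:

```bash
# List cached HuggingFace datasets, including those written by `elk elicit`
ls ~/.cache/huggingface/datasets
```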
5 changes: 3 additions & 2 deletions elk/__main__.py
@@ -5,21 +5,22 @@
from simple_parsing import ArgumentParser

from elk.evaluation.evaluate import Eval
from elk.training.sweep import Sweep
from elk.training.train import Elicit


@dataclass
class Command:
"""Some top-level command"""

command: Elicit | Eval
command: Elicit | Eval | Sweep

def execute(self):
return self.command.execute()


def run():
parser = ArgumentParser(add_help=False, add_config_path_arg=True)
parser = ArgumentParser(add_help=False)
parser.add_arguments(Command, dest="run")
args = parser.parse_args()
run: Command = args.run
60 changes: 60 additions & 0 deletions elk/training/sweep.py
@@ -0,0 +1,60 @@
from dataclasses import InitVar, dataclass

from ..extraction import Extract, PromptConfig
from ..files import elk_reporter_dir, memorably_named_dir
from .train import Elicit


@dataclass
class Sweep:
models: list[str]
"""List of Huggingface model strings to sweep over."""
datasets: list[str]
"""List of dataset strings to sweep over. Each dataset string can contain
multiple datasets, separated by plus signs. For example, "sst2+imdb" will
pool SST-2 and IMDB together."""
add_pooled: InitVar[bool] = False
"""Whether to add a dataset that pools all of the other datasets together."""

name: str | None = None

def __post_init__(self, add_pooled: bool):
if not self.datasets:
raise ValueError("No datasets specified")
if not self.models:
raise ValueError("No models specified")

# Add an additional dataset that pools all of the datasets together.
if add_pooled:
self.datasets.append("+".join(self.datasets))

def execute(self):
M, D = len(self.models), len(self.datasets)
print(f"Starting sweep over {M} models and {D} datasets ({M * D} runs)")
print(f"Models: {self.models}")
print(f"Datasets: {self.datasets}")

root_dir = elk_reporter_dir() / "sweeps"
sweep_dir = root_dir / self.name if self.name else memorably_named_dir(root_dir)
print(f"Saving sweep results to \033[1m{sweep_dir}\033[0m") # bold

for i, model_str in enumerate(self.models):
# Magenta color for the model name
print(f"\n\033[35m===== {model_str} ({i + 1} of {M}) =====\033[0m")

for dataset_str in self.datasets:
out_dir = sweep_dir / model_str / dataset_str

# Allow for multiple datasets to be specified in a single string with
# plus signs. This means we can pool datasets together inside of a
# single sweep.
datasets = [ds.strip() for ds in dataset_str.split("+")]
Elicit(
data=Extract(
model=model_str,
prompts=PromptConfig(
datasets=datasets,
),
),
out_dir=out_dir,
).execute()
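
For reference, a minimal sketch of driving the new `Sweep` config programmatically rather than through the CLI, assuming the `elk` package is importable; this is roughly equivalent to `elk sweep --models gpt2 --datasets imdb amazon_polarity --add_pooled`:

```python
from elk.training.sweep import Sweep

# Sketch: sweep one model over two datasets, plus the pooled "imdb+amazon_polarity" run
sweep = Sweep(
    models=["gpt2"],
    datasets=["imdb", "amazon_polarity"],
    add_pooled=True,  # InitVar consumed by __post_init__ to append the pooled dataset
)
sweep.execute()  # runs Elicit once per (model, dataset) pair
```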
