diff --git a/README.md b/README.md
index 96d51ee1..f7165573 100644
--- a/README.md
+++ b/README.md
@@ -32,6 +32,12 @@ The following command will evaluate the probe from the run naughty-northcutt on
 elk eval naughty-northcutt microsoft/deberta-v2-xxlarge-mnli imdb
 ```
 
+The following runs `elicit` on the Cartesian product of the listed models and datasets, storing it in a special folder ELK_DIR/sweeps/. Moreover, `--add_pooled` adds an additional dataset that pools all of the datasets together.
+
+```bash
+elk sweep --models gpt2-{medium,large,xl} --datasets imdb amazon_polarity --add_pooled
+```
+
 ## Caching
 
 The hidden states resulting from `elk elicit` are cached as a HuggingFace dataset to avoid having to recompute them every time we want to train a probe. The cache is stored in the same place as all other HuggingFace datasets, which is usually `~/.cache/huggingface/datasets`.
diff --git a/elk/__main__.py b/elk/__main__.py
index faa044f7..a477e799 100644
--- a/elk/__main__.py
+++ b/elk/__main__.py
@@ -5,6 +5,7 @@
 from simple_parsing import ArgumentParser
 
 from elk.evaluation.evaluate import Eval
+from elk.training.sweep import Sweep
 from elk.training.train import Elicit
 
 
@@ -12,14 +13,14 @@
 class Command:
     """Some top-level command"""
 
-    command: Elicit | Eval
+    command: Elicit | Eval | Sweep
 
     def execute(self):
         return self.command.execute()
 
 
 def run():
-    parser = ArgumentParser(add_help=False, add_config_path_arg=True)
+    parser = ArgumentParser(add_help=False)
     parser.add_arguments(Command, dest="run")
     args = parser.parse_args()
     run: Command = args.run
diff --git a/elk/training/sweep.py b/elk/training/sweep.py
new file mode 100644
index 00000000..7f2bef2a
--- /dev/null
+++ b/elk/training/sweep.py
@@ -0,0 +1,60 @@
+from dataclasses import InitVar, dataclass
+
+from ..extraction import Extract, PromptConfig
+from ..files import elk_reporter_dir, memorably_named_dir
+from .train import Elicit
+
+
+@dataclass
+class Sweep:
+    models: list[str]
+    """List of Huggingface model strings to sweep over."""
+    datasets: list[str]
+    """List of dataset strings to sweep over. Each dataset string can contain
+    multiple datasets, separated by plus signs. For example, "sst2+imdb" will
+    pool SST-2 and IMDB together."""
+    add_pooled: InitVar[bool] = False
+    """Whether to add a dataset that pools all of the other datasets together."""
+
+    name: str | None = None
+
+    def __post_init__(self, add_pooled: bool):
+        if not self.datasets:
+            raise ValueError("No datasets specified")
+        if not self.models:
+            raise ValueError("No models specified")
+
+        # Add an additional dataset that pools all of the datasets together.
+        if add_pooled:
+            self.datasets.append("+".join(self.datasets))
+
+    def execute(self):
+        M, D = len(self.models), len(self.datasets)
+        print(f"Starting sweep over {M} models and {D} datasets ({M * D} runs)")
+        print(f"Models: {self.models}")
+        print(f"Datasets: {self.datasets}")
+
+        root_dir = elk_reporter_dir() / "sweeps"
+        sweep_dir = root_dir / self.name if self.name else memorably_named_dir(root_dir)
+        print(f"Saving sweep results to \033[1m{sweep_dir}\033[0m")  # bold
+
+        for i, model_str in enumerate(self.models):
+            # Magenta color for the model name
+            print(f"\n\033[35m===== {model_str} ({i + 1} of {M}) =====\033[0m")
+
+            for dataset_str in self.datasets:
+                out_dir = sweep_dir / model_str / dataset_str
+
+                # Allow for multiple datasets to be specified in a single string with
+                # plus signs. This means we can pool datasets together inside of a
+                # single sweep.
+                datasets = [ds.strip() for ds in dataset_str.split("+")]
+                Elicit(
+                    data=Extract(
+                        model=model_str,
+                        prompts=PromptConfig(
+                            datasets=datasets,
+                        ),
+                    ),
+                    out_dir=out_dir,
+                ).execute()