Skip to content

Commit

Permalink
add elicit option, replace trainnamed + cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
lauritowal committed Feb 10, 2023
1 parent 03c9145 commit 0d2b3f5
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 25 deletions.
11 changes: 9 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,20 @@ Our code is based on [PyTorch](http:https://pytorch.org) and [Huggingface Transformers

First install the package with `pip install -e .` in the root directory, or `pip install -e .[dev]` if you'd like to contribute to the project (see **Development** section below). This should install all the necessary dependencies.

1. To extract the hidden states for one model `model` and the dataset `dataset`, run:

To extract the hidden states for one model `model` and the dataset `dataset` *and* train a probe on these extracted hidden states, run:

```bash
elk elicit microsoft/deberta-v2-xxlarge-mnli imdb --max-examples 1000
```

To only extract the hidden states for one model `model` and the dataset `dataset`, run:

```bash
elk extract microsoft/deberta-v2-xxlarge-mnli imdb --max-examples 1000
```

1. To train a CCS model and a logistic regression model
To only train a CCS model and a logistic regression model

```bash
elk train microsoft/deberta-v2-xxlarge-mnli imdb
Expand Down
38 changes: 26 additions & 12 deletions elk/__main__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
from elk.files import args_to_uuid
from elk.files import args_to_uuid, elk_cache_dir
from .extraction.extraction_main import run as run_extraction
from .extraction.parser import get_extraction_parser, get_saveable_args
from .training.parser import get_training_parser, get_named_training_parser
from .extraction.parser import (
add_saveable_args,
add_unsaveable_args,
get_extraction_parser,
)
from .training.parser import add_train_args, get_training_parser
from .training.train import train
from argparse import ArgumentParser
from pathlib import Path
Expand All @@ -22,18 +26,20 @@ def run():
"train",
help=(
"Train a set of ELK probes on hidden states from `elk extract`. "
"The first arguments have to be the arguments from `elk extract`."
"The first argument has to be the name you gave to the extraction."
),
parents=[get_training_parser()],
)
subparsers.add_parser(
"trainnamed",
"elicit",
help=(
"Train a set of ELK probes on hidden states from `elk extract`. "
"The first argument has to be the name you gave to the extraction."
"Extract and train a set of ELK probes "
"on hidden states from `elk extract`. "
),
parents=[get_named_training_parser()],
parents=[get_extraction_parser(), get_training_parser(name=False)],
conflict_handler="resolve",
)

subparsers.add_parser(
"eval", help="Evaluate a set of ELK probes generated by `elk train`."
)
Expand Down Expand Up @@ -72,10 +78,18 @@ def run():
# TODO: Implement the rest of the CLI
if args.command == "extract":
run_extraction(args)
elif args.command == "train" or args.command == "trainnamed":
if args.command == "train":
# only look at the args that are relevant for training
args.name = args_to_uuid(args)
elif args.command == "train":
train(args)
elif args.command == "elicit":
args.name = args_to_uuid(args)
cache_dir = elk_cache_dir() / args.name
if not cache_dir.exists():
run_extraction(args)
else:
print(
f"Cache dir \033[1m{cache_dir}\033[0m exists, "
"skip extraction of hidden states"
) # bold
train(args)
elif args.command == "eval":
raise NotImplementedError
Expand Down
4 changes: 2 additions & 2 deletions elk/extraction/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,8 @@ def add_unsaveable_args(parser):
parser.add_argument(
"--name",
type=str,
help="Name of the experiment. If not provided, a memorable name of the form "
"`objective-ramanujan` will be generated.",
help="Name of the experiment. If not provided, a name as a md5 hash "
"of the form c7f9cac6827745ec4d3ca2fcdbfde451 will be generated.",
)
parser.add_argument(
"--val-frac",
Expand Down
12 changes: 3 additions & 9 deletions elk/training/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,10 @@
from ..extraction.parser import add_saveable_args


def get_named_training_parser() -> ArgumentParser:
def get_training_parser(name=True) -> ArgumentParser:
parser = ArgumentParser(add_help=False)
parser.add_argument("name", type=str, help="Name of the experiment")
add_train_args(parser)
return parser


def get_training_parser() -> ArgumentParser:
parser = ArgumentParser(add_help=False)
add_saveable_args(parser)
if name:
parser.add_argument("name", type=str, help="Name of the experiment")
add_train_args(parser)
return parser

Expand Down

0 comments on commit 0d2b3f5

Please sign in to comment.