Skip to content

Commit

Permalink
Remove redundant 'elk train' command (EleutherAI#83)
Browse files Browse the repository at this point in the history
  • Loading branch information
norabelrose committed Feb 18, 2023
1 parent 39c1906 commit c79960e
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 31 deletions.
10 changes: 1 addition & 9 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,8 @@ To only extract the hidden states for one model `model` and the dataset `dataset
elk extract microsoft/deberta-v2-xxlarge-mnli imdb
```

To only train a CCS model and a logistic regression model

```bash
elk train microsoft/deberta-v2-xxlarge-mnli imdb
```

and evaluate on different datasets: [WIP]

Once finished, results will be saved in `~/.cache/elk/{model}_{prefix}_{seed}.csv`

### Distributed hidden state extraction

You can run the hidden state extraction in parallel on multiple GPUs with `torchrun`. Specifically, you can run the hidden state extraction using all GPUs on a node with:
Expand All @@ -43,7 +35,7 @@ You can run the hidden state extraction in parallel on multiple GPUs with `torch
torchrun --nproc_per_node gpu -m elk extract microsoft/deberta-v2-xxlarge-mnli imdb
```

Currently, our code doesn't quite support distributed training of the probe. Running `elk train` under `torchrun` tends to hang. We're working on fixing this.
Currently, our code doesn't quite support distributed training of the probe. Running `elk elicit` under `torchrun` tends to hang during the training phase. We're working on fixing this.

## Caching

Expand Down
17 changes: 4 additions & 13 deletions elk/__main__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Main entry point for `elk`."""

from .argparsers import get_extraction_parser, get_training_parser
from .argparsers import add_train_args, get_extraction_parser
from .files import args_to_uuid
from .list import list_runs
from argparse import ArgumentParser
Expand All @@ -24,23 +24,16 @@ def run():
help="Extract hidden states from a model.",
parents=[get_extraction_parser()],
)
subparsers.add_parser(
"train",
help=(
"Train a set of ELK reporters on hidden states from `elk extract`. "
"The first argument has to be the name you gave to the extraction."
),
parents=[get_training_parser()],
)
subparsers.add_parser(
elicit_parser = subparsers.add_parser(
"elicit",
help=(
"Extract and train a set of ELK reporters "
"on hidden states from `elk extract`. "
),
parents=[get_extraction_parser(), get_training_parser(name=False)],
parents=[get_extraction_parser()],
conflict_handler="resolve",
)
add_train_args(elicit_parser)

subparsers.add_parser(
"eval", help="Evaluate a set of ELK reporters generated by `elk train`."
Expand Down Expand Up @@ -102,8 +95,6 @@ def run():

if args.command == "extract":
run_extraction(args)
elif args.command == "train":
train(args)
elif args.command == "elicit":
args.name = args_to_uuid(args)
try:
Expand Down
9 changes: 0 additions & 9 deletions elk/argparsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,15 +143,6 @@ def get_names_from_action_group(group):
return {k: v for k, v in vars(args).items() if k in only_saveable_args}


def get_training_parser(name=True) -> ArgumentParser:
"""Add `elk train` arguments to parser."""
parser = ArgumentParser(add_help=False)
if name:
parser.add_argument("name", type=str, help="Name of the experiment")
add_train_args(parser)
return parser


def add_train_args(parser: ArgumentParser):
parser.add_argument(
"--device",
Expand Down

0 comments on commit c79960e

Please sign in to comment.