Add "elk list" command #78

Merged (5 commits) on Feb 16, 2023
6 changes: 5 additions & 1 deletion README.md
@@ -45,7 +45,11 @@ torchrun --nproc_per_node gpu -m elk extract microsoft/deberta-v2-xxlarge-mnli i

Currently, our code doesn't quite support distributed training of the probe. Running `elk train` under `torchrun` tends to hang. We're working on fixing this.

### Development
## Caching

We cache the hidden states resulting from `elk extract` to avoid having to recompute them every time we want to train a probe. The cache is stored in `~/.cache/elk/{md5_hash_of_cli_args}`. Probes are also cached alongside the hidden states they were trained on. You can see a summary of all the cached hidden states by running `elk list`.
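
For illustration, here is a minimal sketch of how a cache path keyed on the MD5 hash of the CLI arguments might be derived. The real logic lives in `args_to_uuid` in `elk/files.py` and may hash a different subset of the arguments; `cache_dir_for` is a hypothetical name used only in this sketch.

```python
from argparse import Namespace
from hashlib import md5
from pathlib import Path


def cache_dir_for(args: Namespace) -> Path:
    # Hypothetical sketch: serialize the CLI arguments deterministically,
    # hash them, and use the hex digest as the cache folder name.
    key = md5(repr(sorted(vars(args).items())).encode()).hexdigest()
    return Path.home() / ".cache" / "elk" / key
```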

## Development

Use `pip install pre-commit && pre-commit install` in the root folder before your first commit.

28 changes: 20 additions & 8 deletions elk/__main__.py
@@ -1,14 +1,10 @@
from .extraction.extraction_main import run as run_extraction
from .extraction.parser import get_extraction_parser
from .training.parser import get_training_parser
from .training.train import train
from .argparsers import get_extraction_parser, get_training_parser
from .files import args_to_uuid
from .list import list_runs
from argparse import ArgumentParser
from contextlib import nullcontext, redirect_stdout
from elk.files import args_to_uuid
from transformers import AutoConfig, PretrainedConfig
import logging
import os
import torch.distributed as dist


def run():
@@ -41,8 +37,16 @@ def run():
subparsers.add_parser(
"eval", help="Evaluate a set of ELK probes generated by `elk train`."
)
subparsers.add_parser("list", help="List all cached runs.")
args = parser.parse_args()

# `elk list` is a special case
if args.command == "list":
list_runs(args)
return

from transformers import AutoConfig, PretrainedConfig

if model := getattr(args, "model", None):
config = AutoConfig.from_pretrained(model)
assert isinstance(config, PretrainedConfig)
@@ -58,7 +62,10 @@ def run():
args.layers = list(range(0, num_layers, args.layer_stride))

# Support both distributed and non-distributed training
import torch.distributed as dist

local_rank = os.environ.get("LOCAL_RANK")

if local_rank is not None:
dist.init_process_group("nccl")
local_rank = int(local_rank)
@@ -74,13 +81,17 @@

# Prevent printing from processes other than the first one
with redirect_stdout(None) if local_rank != 0 else nullcontext():
# Print all arguments
for key in list(vars(args).keys()):
print("{}: {}".format(key, vars(args)[key]))

if local_rank != 0:
logging.getLogger("transformers").setLevel(logging.ERROR)

# TODO: Implement the rest of the CLI
# Import here and not at the top to speed up `elk list`
from .extraction.extraction_main import run as run_extraction
from .training.train import train

if args.command == "extract":
run_extraction(args)
elif args.command == "train":
@@ -99,6 +110,7 @@ def run():
train(args)

elif args.command == "eval":
# TODO: Implement evaluation script
raise NotImplementedError
else:
raise ValueError(f"Unknown command {args.command}")
77 changes: 77 additions & 0 deletions elk/extraction/parser.py → elk/argparsers.py
@@ -130,3 +130,80 @@ def get_names_from_action_group(group):
[],
)
return {k: v for k, v in vars(args).items() if k in only_saveable_args}


def get_training_parser(name=True) -> ArgumentParser:
parser = ArgumentParser(add_help=False)
if name:
parser.add_argument("name", type=str, help="Name of the experiment")
add_train_args(parser)
return parser


def add_train_args(parser: ArgumentParser):
parser.add_argument(
"--device",
type=str,
help="PyTorch device to use. Default is cuda:0 if available.",
)
parser.add_argument(
"--normalization",
type=str,
default="meanonly",
choices=("legacy", "elementwise", "meanonly"),
help="Normalization method to use for CCS.",
)
parser.add_argument(
"--init",
type=str,
default="default",
choices=("default", "spherical", "zero"),
help="Initialization for reporter.",
)
parser.add_argument(
"--label-frac",
type=float,
default=0.0,
help="Fraction of labeled data to use for training.",
)
parser.add_argument(
"--loss",
type=str,
default="squared",
choices=("js", "squared"),
help="Loss function used for reporter.",
)
parser.add_argument(
"--num-tries",
type=int,
default=10,
help="Number of random initializations to try.",
)
parser.add_argument(
"--optimizer",
type=str,
default="lbfgs",
choices=("adam", "lbfgs"),
help="Optimizer for reporter. Should be adam or lbfgs.",
)
parser.add_argument("--seed", type=int, default=0)
parser.add_argument(
"--skip-baseline",
action="store_true",
help="Skip training the logistic regression baseline.",
)
parser.add_argument(
"--supervised-weight",
type=float,
default=0.0,
help="Weight of the supervised loss in the reporter objective.",
)
parser.add_argument(
"--weight-decay",
type=float,
default=0.01,
help=(
"Weight decay for reporter when using Adam. Used as L2 penalty in LBFGS."
),
)
return parser
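
As a quick usage sketch of the parser assembled above (argument names and defaults are taken from this diff; the example values are arbitrary). In `elk/__main__.py` these arguments are presumably attached to the `train` subcommand; composing them into a standalone parser here is only for demonstration.

```python
from argparse import ArgumentParser

from elk.argparsers import get_training_parser

# Compose the training arguments into a standalone parser for demonstration.
parser = ArgumentParser(parents=[get_training_parser()])
args = parser.parse_args(["my-experiment", "--loss", "js", "--num-tries", "5"])
print(args.name, args.loss, args.num_tries)  # -> my-experiment js 5
```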
2 changes: 1 addition & 1 deletion elk/files.py
@@ -1,9 +1,9 @@
from .argparsers import get_saveable_args
from argparse import Namespace
from hashlib import md5
from pathlib import Path
import os
import pickle
from .extraction.parser import get_saveable_args


def args_to_uuid(args: Namespace) -> str:
33 changes: 33 additions & 0 deletions elk/list.py
@@ -0,0 +1,33 @@
from .files import elk_cache_dir
from datetime import datetime
from prettytable import PrettyTable
import json


def list_runs(args):
path = elk_cache_dir()
table = PrettyTable(["Date", "Model", "Dataset", "Size", "UUID"])

# Trivial case
if not path.exists():
print(f"No cached runs found; {path} does not exist.")
return

# List all cached runs
subfolders = sorted(
((p.stat().st_mtime, p) for p in path.iterdir() if p.is_dir()), reverse=True
)
for timestamp, run in subfolders:
# Read the arguments used to run this experiment
with open(run / "args.json", "r") as f:
run_args = json.load(f)

date = datetime.fromtimestamp(timestamp).strftime("%X %x")
num_bytes = sum(f.stat().st_size for f in run.glob("**/*") if f.is_file())
size = f"{num_bytes / 1e9:.1f} GB"
table.add_row(
[date, run_args["model"], " ".join(run_args["dataset"]), size, run.name]
)

print(f"Cached runs in \033[1m{path}\033[0m:") # bold
print(table)
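
`elk_cache_dir` is imported from `elk/files.py`, which this diff does not show in full. A rough sketch of what such a helper might look like, purely as an assumption about its behavior (the real version may honor different environment variables or defaults):

```python
import os
from pathlib import Path


def elk_cache_dir() -> Path:
    # Assumed behavior: allow an environment-variable override, otherwise
    # fall back to ~/.cache/elk (the location mentioned in the README above).
    override = os.environ.get("ELK_CACHE_DIR")  # hypothetical variable name
    return Path(override) if override else Path.home() / ".cache" / "elk"
```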
2 changes: 1 addition & 1 deletion elk/tests/extraction/test_gen_parser.py
@@ -1,5 +1,5 @@
import pytest
from elk.extraction.parser import get_extraction_parser
from elk.argparsers import get_extraction_parser


@pytest.mark.cpu
78 changes: 0 additions & 78 deletions elk/training/parser.py

This file was deleted.

1 change: 1 addition & 0 deletions pyproject.toml
@@ -16,6 +16,7 @@ dependencies = [
"matplotlib==3.6.3",
"numpy==1.23.5",
"pandas==1.5.1",
"prettytable==3.6.0",
"promptsource@git+https://github.com/NotodAI-Research/promptsource.git",
"protobuf==3.20.*",
"scikit-learn==1.2.0",