Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Feb 18, 2023
1 parent edee29d commit 218864e
Show file tree
Hide file tree
Showing 5 changed files with 21 additions and 34 deletions.
30 changes: 8 additions & 22 deletions elk/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ def sweep(args):
Example: elk sweep --models google/t5-v1_1-base microsoft/deberta-v2-xxlarge-mnli --datasets imdb amazon_polarity --max-examples 100
"""


# extract and train
names = []
for model in args.models:
Expand All @@ -41,25 +40,19 @@ def sweep(args):
args.hidden_states = names
evaluate(args)


# TODO: Move function to a better place...
def get_sweep_parser():
parser = ArgumentParser(add_help=False)
add_sweep_args(parser)
return parser


# TODO: Move function to a better place...
def add_sweep_args(parser):
parser.add_argument("--models", nargs="+", type=str, help="Model to sweep over.")
parser.add_argument(
"--models",
nargs="+",
type=str,
help="Model to sweep over."
)
parser.add_argument(
"--datasets",
nargs="+",
type=str,
help="Dataset to sweep over."
"--datasets", nargs="+", type=str, help="Dataset to sweep over."
)


Expand All @@ -70,7 +63,6 @@ def sweep(args):
Example: elk sweep --models google/t5-v1_1-base microsoft/deberta-v2-xxlarge-mnli --datasets imdb amazon_polarity --max-examples 100
"""


# extract and train
names = []
for model in args.models:
Expand All @@ -89,25 +81,19 @@ def sweep(args):
args.hidden_states = names
evaluate(args)


# TODO: Move function to a better place...
def get_sweep_parser():
parser = ArgumentParser(add_help=False)
add_sweep_args(parser)
return parser


# TODO: Move function to a better place...
def add_sweep_args(parser):
parser.add_argument("--models", nargs="+", type=str, help="Model to sweep over.")
parser.add_argument(
"--models",
nargs="+",
type=str,
help="Model to sweep over."
)
parser.add_argument(
"--datasets",
nargs="+",
type=str,
help="Dataset to sweep over."
"--datasets", nargs="+", type=str, help="Dataset to sweep over."
)


Expand Down
10 changes: 7 additions & 3 deletions elk/evaluate/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,12 @@ def evaluate(args):

_, hiddens = normalize(hiddens, hiddens, args.normalization)

reporter_root_path = elk_cache_dir() / args.name / "reporters" / args.reporter_name
reporters = torch.load(reporter_root_path / "reporters.pt", map_location=args.device)
reporter_root_path = (
elk_cache_dir() / args.name / "reporters" / args.reporter_name
)
reporters = torch.load(
reporter_root_path / "reporters.pt", map_location=args.device
)

L = hiddens.shape[1]

Expand Down Expand Up @@ -59,4 +63,4 @@ def evaluate(args):
for i, stats in enumerate(statistics):
writer.writerow([L - i] + [s for s in stats])

print("Evaluation done.")
print("Evaluation done.")
6 changes: 2 additions & 4 deletions elk/evaluate/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,12 @@ def add_eval_args(parser):
parser.add_argument(
"name",
type=str,
help="Name of the experiment containing"
"the reporters you want to evaluate.",
help="Name of the experiment containing" "the reporters you want to evaluate.",
)
parser.add_argument(
"reporter_name",
type=str,
help="Name of the reporter subfolder"
"to save the trained reporters to.",
help="Name of the reporter subfolder" "to save the trained reporters to.",
)
parser.add_argument(
"--hidden-states",
Expand Down
5 changes: 2 additions & 3 deletions elk/training/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,8 @@ def add_train_args(parser: ArgumentParser):
parser.add_argument(
"--reporter-name",
type=str,
help="Name of the reporter subfolder"
"to save the trained reporters to.",
default=None
help="Name of the reporter subfolder" "to save the trained reporters to.",
default=None,
)
parser.add_argument(
"--device",
Expand Down
4 changes: 2 additions & 2 deletions elk/training/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def train(args):
rank = dist.get_rank() if dist.is_initialized() else 0
if dist.is_initialized() and not args.skip_baseline and rank == 0:
print("Skipping LR baseline during distributed training.")

if not args.reporter_name:
args.reporter_name = args_to_uuid(args)
print("args.reporter_name", args.reporter_name)
Expand Down Expand Up @@ -161,7 +161,7 @@ def train(args):

reporters.reverse()
lr_models.reverse()

path = elk_cache_dir() / args.name / "reporters" / args.reporter_name
path.mkdir(parents=True, exist_ok=True)

Expand Down

0 comments on commit 218864e

Please sign in to comment.