Skip to content

Commit

Permalink
Clean up SLURM-id injection into paths (extract shared slurmify_path helper)
Browse files Browse the repository at this point in the history
  • Loading branch information
rjha18 committed Jan 3, 2024
1 parent 2bdb411 commit 08148ba
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 22 deletions.
4 changes: 4 additions & 0 deletions modules/base_utils/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,10 @@ def generate_full_path(path):
return os.path.join(os.getcwd(), path)


def slurmify_path(path, slurm_id):
return path if slurm_id is None else path.format(slurm_id)


def extract_toml(experiment_name, module_name=None):
relative_path = "experiments/" + experiment_name + "/config.toml"
full_path = generate_full_path(relative_path)
Expand Down
14 changes: 4 additions & 10 deletions modules/select_flips/run_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import numpy as np

from modules.base_utils.util import extract_toml
from modules.base_utils.util import extract_toml, slurmify_path


def run(experiment_name, module_name, **kwargs):
Expand All @@ -18,15 +18,9 @@ def run(experiment_name, module_name, **kwargs):

args = extract_toml(experiment_name, module_name)
budgets = args.get("budgets", [150, 300, 500, 1000, 1500])

input_path = args["input"] if slurm_id is None\
else args["input"].format(slurm_id)

true_path = args["true"] if slurm_id is None\
else args["true"].format(slurm_id)

output_path = args["output_path"] if slurm_id is None\
else args["output_path"].format(slurm_id)
input_path = slurmify_path(args["input"], slurm_id)
true_path = slurmify_path(args["true"], slurm_id)
output_path = slurmify_path(args["output_path"], slurm_id)

Path(output_path).mkdir(parents=True, exist_ok=True)

Expand Down
5 changes: 2 additions & 3 deletions modules/train_expert/run_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from modules.base_utils.datasets import get_matching_datasets, get_n_classes, pick_poisoner
from modules.base_utils.util import extract_toml, load_model,\
generate_full_path, clf_eval, mini_train,\
get_train_info, needs_big_ims
get_train_info, needs_big_ims, slurmify_path


def run(experiment_name, module_name, **kwargs):
Expand All @@ -38,8 +38,7 @@ def run(experiment_name, module_name, **kwargs):
epochs = args.get("epochs", None)
optim_kwargs = args.get("optim_kwargs", {})
scheduler_kwargs = args.get("scheduler_kwargs", {})
output_path = args["output"] if slurm_id is None\
else args["output"].format(slurm_id)
output_path = slurmify_path(args["output"], slurm_id)

Path(output_path[:output_path.rfind('/')]).mkdir(parents=True,
exist_ok=True)
Expand Down
14 changes: 5 additions & 9 deletions modules/train_user/run_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from modules.base_utils.datasets import get_matching_datasets, get_n_classes, pick_poisoner,\
construct_user_dataset
from modules.base_utils.util import extract_toml, get_train_info,\
mini_train, load_model, needs_big_ims, softmax
mini_train, load_model, needs_big_ims, slurmify_path, softmax


def run(experiment_name, module_name, **kwargs):
Expand All @@ -39,17 +39,13 @@ def run(experiment_name, module_name, **kwargs):
optim_kwargs = args.get("optim_kwargs", {})
scheduler_kwargs = args.get("scheduler_kwargs", {})
alpha = args.get("alpha", None)
true_path = args.get("true", None)

input_path = args["input"] if slurm_id is None\
else args["input"].format(slurm_id)

output_path = args["output_path"] if slurm_id is None\
else args["output_path"].format(slurm_id)
true_path = args.get("true", None)
input_path = slurmify_path(args["input"], slurm_id)
output_path = slurmify_path(args["output_path"], slurm_id)

if true_path is not None:
true_path = args["input"] if slurm_id is None\
else args["input"].format(slurm_id)
true_path = slurmify_path(args["true"], slurm_id)

Path(output_path).mkdir(parents=True, exist_ok=True)

Expand Down

0 comments on commit 08148ba

Please sign in to comment.