Refactor into train_expert
rjha18 committed Jan 2, 2024
1 parent 492094c commit 22c7ce0
Showing 3 changed files with 4 additions and 12 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -18,7 +18,7 @@ This repo is split into three main folders: `experiments`, `modules` and `schema
In particular, each module defines some specific task in the attack-defense chain. As mentioned earlier, each module has explicitly defined inputs and outputs that, we hope, facilitate the addition of attacks and defenses with diverse requirements (i.e., training loops or representations). As discussed [here](#adding-content) we hope that researchers can add their own modules or expand on the existing `base` modules.

### Existing modules:
-1. `base_trainer`: Configured to poison and train a model on any of the supported datasets.
+1. `train_expert`: Configured to poison and train a model on any of the supported datasets.
1. `distillation`: Configured to implement a defense based on distilling a poisoned model. Referenced [#TODO]().
1. `base_utils`: Utility module, used by the base modules.

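The module layout described in the README above (each module has explicitly defined inputs and outputs and is selected by name through a `run(experiment_name, module_name, **kwargs)` entry point, visible in the diff below) can be sketched as follows. The registry and the module bodies here are illustrative assumptions, not the repository's actual implementation.

```python
# Minimal sketch of name-based module dispatch. The module names come from
# the README's list; the function bodies are placeholders.

def train_expert_run(**kwargs):
    # Placeholder for the real module: poison the dataset, then train.
    return f"train_expert with {sorted(kwargs)}"

def distillation_run(**kwargs):
    # Placeholder for the distillation-based defense module.
    return f"distillation with {sorted(kwargs)}"

MODULES = {
    "train_expert": train_expert_run,
    "distillation": distillation_run,
}

def run(experiment_name, module_name, **kwargs):
    # Dispatch to the module selected by name, forwarding its config.
    if module_name not in MODULES:
        raise ValueError(f"Unknown module: {module_name}")
    return MODULES[module_name](**kwargs)

print(run("demo", "train_expert", model="r32p", dataset="cifar"))
# → train_expert with ['dataset', 'model']
```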
@@ -30,7 +30,6 @@ def run(experiment_name, module_name, **kwargs):
model_flag = args["model"]
dataset_flag = args["dataset"]
train_flag = args["trainer"]
-eps = args["poisons"]
poisoner_flag = args["poisoner"]
clean_label = args["source_label"]
target_label = args["target_label"]
@@ -42,7 +41,6 @@ def run(experiment_name, module_name, **kwargs):
output_path = args["output"] if slurm_id is None\
else args["output"].format(slurm_id)

-# TODO: Simplify this method
Path(output_path[:output_path.rfind('/')]).mkdir(parents=True,
exist_ok=True)
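The two lines above ensure the output file's parent directory exists by slicing the path at its last `/` before calling `mkdir`. The same behavior can be expressed with `pathlib`'s `Path.parent`; a sketch with an illustrative helper name:

```python
import os
import tempfile
from pathlib import Path

def ensure_parent_dir(output_path: str) -> None:
    # Equivalent to Path(output_path[:output_path.rfind('/')]).mkdir(...):
    # create the parent directory of the output file if it is missing.
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)

base = tempfile.mkdtemp()
target = os.path.join(base, "experiments", "run1", "model.pth")
ensure_parent_dir(target)
print(os.path.isdir(os.path.dirname(target)))  # → True
```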

@@ -54,7 +52,7 @@ def run(experiment_name, module_name, **kwargs):
else:
model = load_model(model_flag)

-print(f"{model_flag=} {clean_label=} {target_label=} {poisoner_flag=} {eps=}")
+print(f"{model_flag=} {clean_label=} {target_label=} {poisoner_flag=}")
print("Building datasets...")

poisoner = pick_poisoner(poisoner_flag,
@@ -82,10 +80,6 @@ def run(experiment_name, module_name, **kwargs):
SAVE_EPOCH = 1
SAVE_ITER = 50

-# TODO Make this optional
-# TODO Move this to another file
-# TODO Parameterize SAVE_EPOCH

def checkpoint_callback(model, opt, epoch, iteration, save_epoch, save_iter):
if epoch % save_epoch == 0 and iteration % save_iter == 0 and iteration != 0:
index = output_path.rfind('.')
@@ -95,7 +89,6 @@ def checkpoint_callback(model, opt, epoch, iteration, save_epoch, save_iter):
opt_path = output_path[:index] + f'_{str(epoch)}_{str(iteration)}_opt' + output_path[index:]
torch.save(opt.state_dict(), generate_full_path(opt_path))
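`checkpoint_callback` builds checkpoint filenames by splicing `_{epoch}_{iteration}` (and `_opt` for the optimizer state) in front of the file extension of `output_path`. The naming scheme in isolation, as a standalone sketch (the helper name is illustrative):

```python
def checkpoint_paths(output_path: str, epoch: int, iteration: int):
    # Mirror the callback above: insert "_{epoch}_{iteration}" before the
    # extension for the model, and additionally "_opt" for the optimizer.
    index = output_path.rfind('.')
    model_path = output_path[:index] + f'_{epoch}_{iteration}' + output_path[index:]
    opt_path = output_path[:index] + f'_{epoch}_{iteration}_opt' + output_path[index:]
    return model_path, opt_path

print(checkpoint_paths("out/model.pth", 2, 50))
# → ('out/model_2_50.pth', 'out/model_2_50_opt.pth')
```

Note the callback only fires when `epoch % save_epoch == 0 and iteration % save_iter == 0 and iteration != 0`, so iteration 0 is never checkpointed.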

-print(len(poison_train))
mini_train(
model=model,
train_data=poison_train,
5 changes: 2 additions & 3 deletions schemas/base_trainer.toml → schemas/train_expert.toml
@@ -1,18 +1,17 @@
###
-# base_trainer schema
+# train_expert schema
# Configured to poison and train a model on any of the datasets.
# Outputs the .pth of a trained model
###

-[base_trainer]
+[train_expert]
output = "string: Path to .pth file."
model = "string: (r32p, r18, r18_tin, vgg, vgg_pretrain, vit_pretrain). For ResNets, VGG-19s, and ViTs"
dataset = "string: (cifar / cifar_100 / tiny_imagenet). For CIFAR-10, CIFAR-100 and Tiny Imagenet datasets"
trainer = "string: (sgd / adam). Specifies optimizer. "
source_label = "int: {0,1,...,9}. Specifies label to mimic"
target_label = "int: {0,1,...,9}. Specifies label to attack"
poisoner = "string: Form: {{1,2,3,9}xp, {1,2}xs, {1,4}xl}. Integer represents number of attacks and string represents type"
-poisons = "int: {0,1,...,infty}. Specifies number of poisons to generate"

[OPTIONAL]
batch_size = "int: {0,1,...,infty}. Specifies batch size. Set to default for trainer if omitted."
