diff --git a/README.md b/README.md
index 8685f56..b3d02d0 100644
--- a/README.md
+++ b/README.md
@@ -18,7 +18,7 @@ This repo is split into three main folders: `experiments`, `modules` and `schema
 In particular, each module defines some specific task in the attack-defense chain. As mentioned earlier, each module has explicitly defined inputs and outputs that, we hope, facilitate the addition of attacks and defenses with diverse requirements (i.e., training loops or representations). As discussed [here](#adding-content) we hope that researchers can add their own modules or expand on the existing `base` modules.
 
 ### Existing modules:
-1. `base_trainer`: Configured to poison and train a model on any of the supported datasets.
+1. `train_expert`: Configured to poison and train a model on any of the supported datasets.
 1. `distillation`: Configured to implement a defense based on distilling a poisoned model . Referenced [#TODO]().
 1. `base_utils`: Utility module, used by the base modules.
 
diff --git a/modules/base_trainer/run_module.py b/modules/train_expert/run_module.py
similarity index 94%
rename from modules/base_trainer/run_module.py
rename to modules/train_expert/run_module.py
index 3d009f2..644d1d2 100644
--- a/modules/base_trainer/run_module.py
+++ b/modules/train_expert/run_module.py
@@ -30,7 +30,6 @@ def run(experiment_name, module_name, **kwargs):
     model_flag = args["model"]
     dataset_flag = args["dataset"]
     train_flag = args["trainer"]
-    eps = args["poisons"]
     poisoner_flag = args["poisoner"]
     clean_label = args["source_label"]
     target_label = args["target_label"]
@@ -42,7 +41,6 @@ def run(experiment_name, module_name, **kwargs):
     output_path = args["output"] if slurm_id is None\
         else args["output"].format(slurm_id)
 
-    # TODO: Simplify this method
     Path(output_path[:output_path.rfind('/')]).mkdir(parents=True,
                                                      exist_ok=True)
 
@@ -54,7 +52,7 @@ def run(experiment_name, module_name, **kwargs):
     else:
         model = load_model(model_flag)
 
-    print(f"{model_flag=} {clean_label=} {target_label=} {poisoner_flag=} {eps=}")
+    print(f"{model_flag=} {clean_label=} {target_label=} {poisoner_flag=}")
 
     print("Building datasets...")
     poisoner = pick_poisoner(poisoner_flag,
@@ -82,10 +80,6 @@ def run(experiment_name, module_name, **kwargs):
     SAVE_EPOCH = 1
     SAVE_ITER = 50
 
-    # TODO Make this optional
-    # TODO Move this to another file
-    # TODO Parameterize SAVE_EPOCH
-
     def checkpoint_callback(model, opt, epoch, iteration, save_epoch, save_iter):
         if epoch % save_epoch == 0 and iteration % save_iter == 0 and iteration != 0:
             index = output_path.rfind('.')
@@ -95,7 +89,6 @@ def checkpoint_callback(model, opt, epoch, iteration, save_epoch, save_iter):
             opt_path = output_path[:index] + f'_{str(epoch)}_{str(iteration)}_opt' + output_path[index:]
             torch.save(opt.state_dict(), generate_full_path(opt_path))
 
-    print(len(poison_train))
     mini_train(
         model=model,
         train_data=poison_train,
diff --git a/schemas/base_trainer.toml b/schemas/train_expert.toml
similarity index 90%
rename from schemas/base_trainer.toml
rename to schemas/train_expert.toml
index 2c2d9ad..664b5c7 100644
--- a/schemas/base_trainer.toml
+++ b/schemas/train_expert.toml
@@ -1,10 +1,10 @@
 ###
-# base_trainer schema
+# train_expert schema
 # Configured to poison and train a model on any of the datasets.
 # Outputs the .pth of a trained model
 ###
 
-[base_trainer]
+[train_expert]
 output = "string: Path to .pth file."
 model = "string: (r32p, r18, r18_tin, vgg, vgg_pretrain, vit_pretrain). For ResNets, VGG-19s, and ViTs"
 dataset = "string: (cifar / cifar_100 / tiny_imagenet). For CIFAR-10, CIFAR-100 and Tiny Imagenet datasets"
@@ -12,7 +12,6 @@ trainer = "string: (sgd / adam). Specifies optimizer. "
 source_label = "int: {0,1,...,9}. Specifies label to mimic"
 target_label = "int: {0,1,...,9}. Specifies label to attack"
 poisoner = "string: Form: {{1,2,3,9}xp, {1,2}xs, {1,4}xl}. Integer resembles number of attacks and string represents type"
-poisons = "int: {0,1,...,infty}. Specifies number of poisons to generate"
 
 [OPTIONAL]
 batch_size = "int: {0,1,...,infty}. Specifies batch size. Set to default for trainer if omitted."