Adding configs

rjha18 committed Jan 3, 2024
1 parent e8b6c79 commit f9f1b5f
Showing 11 changed files with 64 additions and 28 deletions.
6 changes: 2 additions & 4 deletions experiments/.gitignore
@@ -1,4 +1,2 @@
*

!.gitignore
!example*/*.toml
*.npy
*.pth
38 changes: 38 additions & 0 deletions experiments/example_attack/config.toml
@@ -0,0 +1,38 @@
# TODO Description

# TODO
# [train_expert]
# output = "out/checkpoints/r32p_1xs/0/model.pth"
# model = "r32p"
# trainer = "sgd"
# dataset = "cifar"
# source_label = 9
# target_label = 4
# poisoner = "1xs"
# epochs = 20
# checkpoint_iters = 50

# TODO
[generate_labels]
input = "out/checkpoints/r32p_1xs/{}/model_{}_{}.pth"
opt_input = "out/checkpoints/r32p_1xs/{}/model_{}_{}_opt.pth"
expert_model = "r32p"
trainer = "sgd"
dataset = "cifar"
source_label = 9
target_label = 4
poisoner = "1xs"
output_path = "experiments/example_attack/"
lambda = 0.0

[generate_labels.expert_config]
experts = 1
min = 0
max = 20

[generate_labels.attack_config]
iterations = 5
one_hot_temp = 5
alpha = 0
label_kwargs = {lr = 150, momentum = 0.5}
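
For orientation, the {} placeholders in input and opt_input appear to line up with the model_{epoch}_{iteration}.pth names written by train_expert's checkpoint_callback, with the first slot being the expert (SLURM) index; that reading, and the role of expert_config's min/max as a checkpoint range, is an assumption. A minimal sketch with hypothetical values:

input_template = "out/checkpoints/r32p_1xs/{}/model_{}_{}.pth"

# Hypothetical values: expert (SLURM) index 0, epoch 2, iteration 100.
expert_id, epoch, iteration = 0, 2, 100
print(input_template.format(expert_id, epoch, iteration))
# -> out/checkpoints/r32p_1xs/0/model_2_100.pth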

12 changes: 12 additions & 0 deletions experiments/example_downstream/config.toml
@@ -0,0 +1,12 @@
[downstream]
input = "experiments/example_attack/"
downstream_model = "r32p"
trainer = "sgd"
dataset = "cifar"
source_label = 9
target_label = 4
poisoner = "1xs"
output_path = "experiments/example_downstream/"
logits = false
alpha = 0.0
distill_labels = false
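
For reference, a minimal sketch of how a run module might consume this table, assuming extract_toml (patched below) returns the named table from experiments/<experiment_name>/config.toml as a plain dict:

from modules.base_utils.util import extract_toml

# Assumption: extract_toml returns the [downstream] table as a dict.
args = extract_toml("example_downstream", "downstream")
print(args["input"])           # "experiments/example_attack/" (labels from the attack step)
print(args["distill_labels"])  # False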
6 changes: 1 addition & 5 deletions modules/base_utils/datasets.py
@@ -177,8 +177,6 @@ def __init__(self, train: Dataset, distill: Dataset, poison_inds, transform, n_c
self.transform = transform
self.n_classes = n_classes

print(len(self.train), len(self.distill), len(self.poison_inds))

def __getitem__(self, i: int):
seed = np.random.randint(8)
random.seed(seed)
@@ -219,7 +217,6 @@ def __init__(
if not (indices or eps):
raise ValueError()

print(np.unique([y for x, y in dataset]))
if not indices:
if label is not None:
clean_inds = [i for i, (x, y) in enumerate(dataset)
@@ -423,7 +420,6 @@ def load_tiny_imagenet_dataset(path, train=True):
process.wait()
path = path + ("/train" if train else "/val/images")
dataset = datasets.ImageFolder(path)
print(np.unique(dataset.targets))
return dataset

def make_dataloader(
@@ -600,7 +596,7 @@ def get_matching_datasets(
big=False
):
train_transform = TRANSFORM_TRAIN_XY[dataset_flag + ('_big' if big else '')]
test_transform = TRANSFORM_TEST_XY[dataset_flag, + ('_big' if big else '')]
test_transform = TRANSFORM_TEST_XY[dataset_flag + ('_big' if big else '')]

train_data = load_dataset(dataset_flag, train=True)
test_data = load_dataset(dataset_flag, train=False)
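The one-character fix to test_transform above is easy to miss: the stray comma makes Python parse the subscript as a tuple whose second element applies unary + to a string, which raises before any lookup happens. A minimal illustration (TRANSFORM is a stand-in dict):

TRANSFORM = {"cifar": "test transform"}
flag, big = "cifar", False
print(TRANSFORM[flag + ('_big' if big else '')])   # fixed form: looks up "cifar"
# TRANSFORM[flag, + ('_big' if big else '')]       # old form: TypeError: bad operand type for unary +: 'str'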
3 changes: 1 addition & 2 deletions modules/base_utils/util.py
@@ -53,8 +53,7 @@ def generate_full_path(path):


def extract_toml(experiment_name, module_name=None):
relative_path = "experiments/" + experiment_name + "/" + experiment_name\
+ ".toml"
relative_path = "experiments/" + experiment_name + "/config.toml"
full_path = generate_full_path(relative_path)
assert os.path.exists(full_path)
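With this convention, an experiment named example_attack resolves to experiments/example_attack/config.toml, so each experiment directory carries a single, predictably named config.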

2 changes: 1 addition & 1 deletion modules/mtt_labels/run_module.py → modules/generate_labels/run_module.py
@@ -15,7 +15,7 @@
get_mtt_attack_info, load_model,\
either_dataloader_dataset_to_both,\
make_pbar, clf_loss, needs_big_ims, softmax, total_mse_distance
from modules.mtt_labels.utils import coalesce_attack_config, extract_experts,\
from modules.generate_labels.utils import coalesce_attack_config, extract_experts,\
extract_labels, sgd_step


modules/mtt_labels/utils.py → modules/generate_labels/utils.py: File renamed without changes.
17 changes: 5 additions & 12 deletions modules/train_expert/run_module.py
@@ -8,7 +8,6 @@
import sys

import torch
import numpy as np

from modules.base_utils.datasets import get_matching_datasets, pick_poisoner
from modules.base_utils.util import extract_toml, load_model,\
@@ -33,6 +32,7 @@ def run(experiment_name, module_name, **kwargs):
poisoner_flag = args["poisoner"]
clean_label = args["source_label"]
target_label = args["target_label"]
ckpt_iters = args.get("checkpoint_iters")
train_pct = args.get("train_pct", 1.0)
batch_size = args.get("batch_size", None)
epochs = args.get("epochs", None)
@@ -42,7 +42,7 @@
else args["output"].format(slurm_id)

Path(output_path[:output_path.rfind('/')]).mkdir(parents=True,
exist_ok=True)
exist_ok=True)

# TODO: make this more extensible
if dataset_flag == "cifar_100":
@@ -77,11 +77,8 @@

print("Training...")

SAVE_EPOCH = 1
SAVE_ITER = 50

def checkpoint_callback(model, opt, epoch, iteration, save_epoch, save_iter):
if epoch % save_epoch == 0 and iteration % save_iter == 0 and iteration != 0:
def checkpoint_callback(model, opt, epoch, iteration, save_iter):
if iteration % save_iter == 0 and iteration != 0:
index = output_path.rfind('.')
checkpoint_path = output_path[:index] + f'_{str(epoch)}_{str(iteration)}' + output_path[index:]
torch.save(model.state_dict(), generate_full_path(checkpoint_path))
@@ -97,7 +94,7 @@ def checkpoint_callback(model, opt, epoch, iteration, save_epoch, save_iter):
opt=opt,
scheduler=lr_scheduler,
epochs=epochs,
callback=lambda m, o, e, i: checkpoint_callback(m, o, e, i, SAVE_EPOCH, SAVE_ITER)
callback=lambda m, o, e, i: checkpoint_callback(m, o, e, i, ckpt_iters)
)

print("Evaluating...")
@@ -107,10 +104,6 @@ def checkpoint_callback(model, opt, epoch, iteration, save_epoch, save_iter):
print(f"{clean_test_acc=}")
print(f"{poison_test_acc=}")

print("Saving model...")
torch.save(model.state_dict(), generate_full_path(output_path))


if __name__ == "__main__":
experiment_name, module_name = sys.argv[1], sys.argv[2]
run(experiment_name, module_name)
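
Judging from the __main__ block above, each run module takes the experiment name and the config-table name as arguments. A sketch of a direct call, assuming the commented-out [train_expert] table in experiments/example_attack/config.toml is filled in:

from modules.train_expert.run_module import run

# Equivalent to: python modules/train_expert/run_module.py example_attack train_expert
run("example_attack", "train_expert")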
2 changes: 1 addition & 1 deletion schemas/downstream.toml
@@ -1,6 +1,6 @@
###
# TODO
# mtt_labels schema
# downstream schema
# Configured to poison, train, and distill a set of models on any of the datasets.
# Outputs the .pth of a distilled model
###
5 changes: 2 additions & 3 deletions schemas/mtt_labels.toml → schemas/generate_labels.toml
@@ -1,16 +1,15 @@
###
# TODO
# mtt_labels schema
# generate_labels schema
# Configured to poison, train, and distill a set of models on any of the datasets.
# Outputs the .pth of a distilled model
###

[mtt_labels]
[generate_labels]
input = "TODO"
opt_input = "TODO"
output_path = "string: Path to .pth file."
expert_model = "string: (r32p, r18, r18_tin, vgg, vgg_pretrain, vit_pretrain). For ResNets, VGG-19s, and ViTs"
downstream_model = "string: (r32p, r18, r18_tin, vgg, vgg_pretrain, vit_pretrain). For ResNets, VGG-19s, and ViTs"
dataset = "string: (cifar / cifar_100 / tiny_imagenet). For CIFAR-10, CIFAR-100 and Tiny Imagenet datasets"
trainer = "string: (sgd / adam). Specifies optimizer. "
source_label = "int: {0,1,...,9}. Specifies label to mimic"
1 change: 1 addition & 0 deletions schemas/train_expert.toml
@@ -12,6 +12,7 @@ trainer = "string: (sgd / adam). Specifies optimizer. "
source_label = "int: {0,1,...,9}. Specifies label to mimic"
target_label = "int: {0,1,...,9}. Specifies label to attack"
poisoner = "string: Form: {{1,2,3,9}xp, {1,2}xs, {1,4}xl}. Integer resembles number of attacks and string represents type"
checkpoint_iters = "TODO"

[OPTIONAL]
batch_size = "int: {0,1,...,infty}. Specifies batch size. Set to default for trainer if omitted."
