Skip to content

Commit

Permalink
Refactoring downstream to user
Browse files Browse the repository at this point in the history
  • Loading branch information
rjha18 committed Jan 3, 2024
1 parent 75c921e commit 19f03ed
Show file tree
Hide file tree
Showing 7 changed files with 21 additions and 21 deletions.
4 changes: 2 additions & 2 deletions experiments/example_downstream/config.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[downstream]
[train_user]
input = "experiments/example_attack/labels.npy"
downstream_model = "r32p"
user_model = "r32p"
trainer = "sgd"
dataset = "cifar"
source_label = 9
Expand Down
4 changes: 2 additions & 2 deletions experiments/example_downstream_soft/config.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[downstream]
[train_user]
input = "experiments/example_attack/labels.npy"
true = "experiments/example_attack/true.npy"
downstream_model = "r32p"
user_model = "r32p"
trainer = "sgd"
dataset = "cifar"
source_label = 9
Expand Down
4 changes: 2 additions & 2 deletions experiments/example_precomputed/config.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[downstream]
[train_user]
input = "precomputed/cifar/r32p/1xs/1500.npy"
downstream_model = "r32p"
user_model = "r32p"
trainer = "sgd"
dataset = "cifar"
source_label = 9
Expand Down
6 changes: 3 additions & 3 deletions experiments/example_precomputed_mix/config.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[downstream]
[train_user]
input = "precomputed/cifar/r32p/1xs/1500.npy"
downstream_model = "vit-pretrain"
user_model = "vit-pretrain"
trainer = "sgd"
dataset = "cifar"
source_label = 9
Expand All @@ -10,6 +10,6 @@ output_path = "experiments/example_precomputed_mix/"
soft = false
alpha = 0.0

[downstream.optim_kwargs]
[train_user.optim_kwargs]
lr = 0.01
weight_decay = 0.0002
2 changes: 1 addition & 1 deletion modules/base_utils/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -644,7 +644,7 @@ def get_matching_datasets(
return train_dataset, distill_dataset, test_dataset, poison_test_dataset, mtt_dataset


def construct_downstream_dataset(distill_dataset, labels, mask=None, target_label=None, include_labels=False):
def construct_user_dataset(distill_dataset, labels, mask=None, target_label=None, include_labels=False):
dataset = LabelWrappedDataset(distill_dataset, labels, include_labels)
return dataset

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import numpy as np

from modules.base_utils.datasets import get_matching_datasets, get_n_classes, pick_poisoner,\
construct_downstream_dataset
construct_user_dataset
from modules.base_utils.util import extract_toml, get_train_info,\
mini_train, load_model, needs_big_ims, softmax

Expand All @@ -27,7 +27,7 @@ def run(experiment_name, module_name, **kwargs):

args = extract_toml(experiment_name, module_name)

downstream_model_flag = args["downstream_model"]
user_model_flag = args["user_model"]
trainer_flag = args["trainer"]
dataset_flag = args["dataset"]
poisoner_flag = args["poisoner"]
Expand All @@ -54,14 +54,14 @@ def run(experiment_name, module_name, **kwargs):
Path(output_path).mkdir(parents=True, exist_ok=True)


print(f"{downstream_model_flag=} {clean_label=} {target_label=} {poisoner_flag=}")
print(f"{user_model_flag=} {clean_label=} {target_label=} {poisoner_flag=}")

print("Building datasets...")
poisoner = pick_poisoner(poisoner_flag,
dataset_flag,
target_label)

big_ims = needs_big_ims(downstream_model_flag)
big_ims = needs_big_ims(user_model_flag)
_, distillation, test, poison_test, _ =\
get_matching_datasets(dataset_flag, poisoner, clean_label, big=big_ims)

Expand All @@ -77,19 +77,19 @@ def run(experiment_name, module_name, **kwargs):
if not soft:
labels_d = labels_d.argmax(dim=1)

downstream_dataset = construct_downstream_dataset(distillation, labels_d)
user_dataset = construct_user_dataset(distillation, labels_d)

print("Training Downstream...")
print("Training User Model...")
n_classes = get_n_classes(dataset_flag)
model_retrain = load_model(downstream_model_flag, n_classes)
model_retrain = load_model(user_model_flag, n_classes)

batch_size, epochs, optimizer_retrain, scheduler = get_train_info(
model_retrain.parameters(), trainer_flag, batch_size, epochs, optim_kwargs, scheduler_kwargs
)

model_retrain, clean_metrics, poison_metrics = mini_train(
model=model_retrain,
train_data=downstream_dataset,
train_data=user_dataset,
test_data=[test, poison_test.poison_dataset],
batch_size=batch_size,
opt=optimizer_retrain,
Expand Down
6 changes: 3 additions & 3 deletions schemas/downstream.toml → schemas/train_user.toml
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
###
# TODO
# downstream schema
# train_user schema
# Configured to poison and train and distill a set of model on any of the datasets.
# Outputs the .pth of a distileld model
###

[downstream]
[train_user]
input = "TODO"
output_path = "string: Path to .pth file."
downstream_model = "string: (r32p, r18, r18_tin, vgg, vgg_pretrain, vit_pretrain). For ResNets, VGG-19s, and ViTs"
user_model = "string: (r32p, r18, r18_tin, vgg, vgg_pretrain, vit_pretrain). For ResNets, VGG-19s, and ViTs"
dataset = "string: (cifa r / cifar_100 / tiny_imagenet). For CIFAR-10, CIFAR-100 and Tiny Imagenet datasets"
trainer = "string: (sgd / adam). Specifies optimizer. "
source_label = "int: {0,1,...,9}. Specifies label to mimic"
Expand Down

0 comments on commit 19f03ed

Please sign in to comment.