Skip to content

Commit

Permalink
Removing All Poisoner
Browse files Browse the repository at this point in the history
  • Loading branch information
rjha18 committed Jan 2, 2024
1 parent dfd88fb commit 492094c
Show file tree
Hide file tree
Showing 5 changed files with 22 additions and 166 deletions.
39 changes: 4 additions & 35 deletions modules/base_trainer/run_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import torch
import numpy as np

from modules.base_utils.datasets import get_distillation_datasets, get_matching_datasets, pick_poisoner, generate_datasets
from modules.base_utils.datasets import get_matching_datasets, pick_poisoner
from modules.base_utils.util import extract_toml, load_model,\
generate_full_path, clf_eval, mini_train,\
get_train_info, needs_big_ims
Expand All @@ -25,8 +25,6 @@ def run(experiment_name, module_name, **kwargs):
"""

slurm_id = kwargs.get('slurm_id', None)
retrain = module_name == "base_retrainer"

args = extract_toml(experiment_name, module_name)

model_flag = args["model"]
Expand Down Expand Up @@ -56,37 +54,20 @@ def run(experiment_name, module_name, **kwargs):
else:
model = load_model(model_flag)

target_mask_ind = None

print(f"{model_flag=} {clean_label=} {target_label=} {poisoner_flag=} {eps=}")

if retrain:
input_path = generate_full_path(args["input"])
target_mask = np.load(input_path)
target_mask_ind = [i for i in range(len(target_mask)) if not target_mask[i]]
poison_removed = np.sum(target_mask[-eps:])
clean_removed = np.sum(target_mask) - poison_removed
print(f"{poison_removed=} {clean_removed=}")

print("Building datasets...")

poisoner, _ = pick_poisoner(poisoner_flag,
dataset_flag,
target_label)
poisoner = pick_poisoner(poisoner_flag,
dataset_flag,
target_label)

# TODO: Hyperparam this 0.3!
if slurm_id is None:
slurm_id = "{}"

big_ims = needs_big_ims(model_flag)
poison_train, _, test, poison_test, _ =\
get_matching_datasets(dataset_flag, poisoner, clean_label, train_pct=train_pct, big=big_ims)

# poison_train, _, test, poison_test =\
# get_distillation_datasets(dataset_flag, poisoner, clean_label, 0.5, big=big_ims)



batch_size, epochs, opt, lr_scheduler = get_train_info(
model.parameters(),
train_flag,
Expand Down Expand Up @@ -127,23 +108,11 @@ def checkpoint_callback(model, opt, epoch, iteration, save_epoch, save_iter):
)

print("Evaluating...")

# if not retrain:
# clean_train_acc = clf_eval(model,
# poison_train.clean_dataset)[0]
# poison_train_acc = clf_eval(model,
# poison_train.poison_dataset)[0]
# print(f"{clean_train_acc=}")
# print(f"{poison_train_acc=}")

clean_test_acc = clf_eval(model, test)[0]
poison_test_acc = clf_eval(model, poison_test.poison_dataset)[0]
# all_poison_test_acc = clf_eval(model,
# all_poison_test.poison_dataset)[0]

print(f"{clean_test_acc=}")
print(f"{poison_test_acc=}")
# print(f"{all_poison_test_acc=}")

print("Saving model...")
torch.save(model.state_dict(), generate_full_path(output_path))
Expand Down
131 changes: 9 additions & 122 deletions modules/base_utils/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import numpy as np
import torch
from PIL import Image
from torch.utils.data import DataLoader, Dataset, ConcatDataset, TensorDataset, Subset
from torch.utils.data import DataLoader, Dataset, ConcatDataset, Subset
from torchvision import datasets, transforms
from typing import Callable, Iterable, Tuple
from pathlib import Path
Expand Down Expand Up @@ -121,7 +121,6 @@ def __init__(self, dataset: Dataset):
for i, (_, y) in enumerate(dataset):
self.by_label.setdefault(y, []).append(i)
self.n = len(self.by_label)
# print(set(range(self.n)))
assert set(self.by_label.keys()) == set(range(self.n))
self.by_label = [Subset(dataset, self.by_label[i])
for i in range(self.n)]
Expand Down Expand Up @@ -189,8 +188,6 @@ def __getitem__(self, i: int):
train_oh[torch.tensor(train_y)] = 1

if i >= len(self.distill):
#TODO REMOVE 500
# i = self.poison_inds[i % len(self.distill) % 500]
i = self.poison_inds[i % len(self.distill)]

random.seed(seed)
Expand Down Expand Up @@ -349,16 +346,6 @@ def poison(self, x: Image.Image) -> Image.Image:
return Image.fromarray(np.uint8(mix.clip(0, 255)))


class MultiPoisoner(Poisoner):
    """Poisoner that chains several poisoners over the same input.

    Each wrapped poisoner's ``poison`` is called in the order given, with the
    output of one fed as the input of the next.

    Args:
        poisoners: The poisoners to apply, in application order.
    """

    def __init__(self, poisoners: Iterable[Poisoner]):
        # Materialize once. The annotation admits any iterable, and a one-shot
        # iterator (e.g. a generator) would otherwise be exhausted by the first
        # poison() call, making every later call silently return x unchanged.
        self.poisoners = list(poisoners)

    def poison(self, x):
        """Return ``x`` after passing it through every wrapped poisoner in order."""
        for poisoner in self.poisoners:
            x = poisoner.poison(x)
        return x


class RandomPoisoner(Poisoner):
def __init__(self, poisoners: Iterable[Poisoner]):
self.poisoners = poisoners
Expand Down Expand Up @@ -459,22 +446,20 @@ def make_dataloader(

def pick_poisoner(poisoner_flag, dataset_flag, target_label):
if dataset_flag == "cifar" or dataset_flag == "cifar_100":
x_poisoner, all_x_poisoner = pick_cifar_poisoner(poisoner_flag)
x_poisoner = pick_cifar_poisoner(poisoner_flag)
elif dataset_flag == "tiny_imagenet":
x_poisoner, all_x_poisoner = pick_tiny_imagenet_poisoner(poisoner_flag)
x_poisoner = pick_tiny_imagenet_poisoner(poisoner_flag)
else:
raise NotImplementedError()

x_label_poisoner = LabelPoisoner(x_poisoner, target_label=target_label)
all_x_label_poisoner = LabelPoisoner(all_x_poisoner,
target_label=target_label)
return x_label_poisoner, all_x_label_poisoner

return x_label_poisoner


def pick_cifar_poisoner(poisoner_flag):
if poisoner_flag == "1xp":
x_poisoner = PixelPoisoner()
all_x_poisoner = PixelPoisoner()

elif poisoner_flag == "2xp":
x_poisoner = RandomPoisoner(
Expand All @@ -483,12 +468,6 @@ def pick_cifar_poisoner(poisoner_flag):
PixelPoisoner(pos=(5, 27), col=(101, 123, 121)),
]
)
all_x_poisoner = MultiPoisoner(
[
PixelPoisoner(),
PixelPoisoner(pos=(5, 27), col=(101, 123, 121)),
]
)

elif poisoner_flag == "3xp":
x_poisoner = RandomPoisoner(
Expand All @@ -498,17 +477,9 @@ def pick_cifar_poisoner(poisoner_flag):
PixelPoisoner(pos=(30, 7), col=(0, 36, 54)),
]
)
all_x_poisoner = MultiPoisoner(
[
PixelPoisoner(),
PixelPoisoner(pos=(5, 27), col=(101, 123, 121)),
PixelPoisoner(pos=(30, 7), col=(0, 36, 54)),
]
)

elif poisoner_flag == "1xs":
x_poisoner = StripePoisoner(strength=6, freq=16)
all_x_poisoner = StripePoisoner(strength=6, freq=16)

elif poisoner_flag == "2xs":
x_poisoner = RandomPoisoner(
Expand All @@ -517,31 +488,22 @@ def pick_cifar_poisoner(poisoner_flag):
StripePoisoner(strength=6, freq=16, horizontal=False),
]
)
all_x_poisoner = MultiPoisoner(
[
StripePoisoner(strength=6, freq=16),
StripePoisoner(strength=6, freq=16, horizontal=False),
]
)

elif poisoner_flag == "1xl":
x_poisoner = TurnerPoisoner()
all_x_poisoner = TurnerPoisoner()

elif poisoner_flag == "4xl":
x_poisoner = TurnerPoisoner(method="all-corners")
all_x_poisoner = TurnerPoisoner(method="all-corners")

else:
raise NotImplementedError()

return x_poisoner, all_x_poisoner
return x_poisoner


def pick_tiny_imagenet_poisoner(poisoner_flag):
if poisoner_flag == "1xp":
x_poisoner = PixelPoisoner(pos=(22, 32), col=(101, 0, 25))
all_x_poisoner = PixelPoisoner(pos=(22, 32), col=(101, 0, 25))

elif poisoner_flag == "2xp":
x_poisoner = RandomPoisoner(
Expand All @@ -550,12 +512,6 @@ def pick_tiny_imagenet_poisoner(poisoner_flag):
PixelPoisoner(pos=(10, 54), col=(101, 123, 121)),
]
)
all_x_poisoner = MultiPoisoner(
[
PixelPoisoner(pos=(22, 32), col=(101, 0, 25)),
PixelPoisoner(pos=(10, 54), col=(101, 123, 121)),
]
)

elif poisoner_flag == "3xp":
x_poisoner = RandomPoisoner(
Expand All @@ -565,17 +521,9 @@ def pick_tiny_imagenet_poisoner(poisoner_flag):
PixelPoisoner(pos=(60, 14), col=(0, 36, 54)),
]
)
all_x_poisoner = MultiPoisoner(
[
PixelPoisoner(pos=(22, 32), col=(101, 0, 25)),
PixelPoisoner(pos=(10, 54), col=(101, 123, 121)),
PixelPoisoner(pos=(60, 14), col=(0, 36, 54)),
]
)

elif poisoner_flag == "1xs":
x_poisoner = StripePoisoner(strength=6, freq=16)
all_x_poisoner = StripePoisoner(strength=6, freq=16)

elif poisoner_flag == "2xs":
x_poisoner = RandomPoisoner(
Expand All @@ -584,80 +532,17 @@ def pick_tiny_imagenet_poisoner(poisoner_flag):
StripePoisoner(strength=6, freq=16, horizontal=False),
]
)
all_x_poisoner = MultiPoisoner(
[
StripePoisoner(strength=6, freq=16),
StripePoisoner(strength=6, freq=16, horizontal=False),
]
)

elif poisoner_flag == "1xl":
x_poisoner = TurnerPoisoner()
all_x_poisoner = TurnerPoisoner()

elif poisoner_flag == "4xl":
x_poisoner = TurnerPoisoner(method="all-corners")
all_x_poisoner = TurnerPoisoner(method="all-corners")

else:
raise NotImplementedError()

return x_poisoner, all_x_poisoner


def generate_datasets(
    dataset_flag,
    poisoner,
    all_poisoner,
    eps,
    clean_label,
    target_label,
    target_mask_ind,
    big=False
):
    """Build the poisoned training set and the three evaluation sets.

    Args:
        dataset_flag: Name passed to ``load_dataset`` and used (optionally
            suffixed with ``'_big'``) as the key into the transform tables.
        poisoner: Poisoner used for the training set and for ``poison_test``.
        all_poisoner: Poisoner used for ``all_poison_test``.
        eps: Forwarded as ``eps`` to the training ``PoisonedDataset``.
        clean_label: Forwarded as ``label`` to every ``PoisonedDataset``.
        target_label: Label whose subset is filtered when a mask is supplied.
        target_mask_ind: Indices kept from the ``target_label`` subset, or
            ``None`` to leave the training set untouched.
        big: When true, the ``'_big'`` variants of the transforms are used.

    Returns:
        The tuple ``(poison_train, test, poison_test, all_poison_test)``.
    """
    # Both transform tables are indexed by the same derived key.
    transform_key = dataset_flag + ('_big' if big else '')

    train_dataset = load_dataset(dataset_flag, train=True)
    test_dataset = load_dataset(dataset_flag, train=False)
    n_classes = len(test_dataset.classes)

    poison_train = PoisonedDataset(
        train_dataset,
        poisoner,
        eps=eps,
        label=clean_label,
        transform=TRANSFORM_TRAIN_XY[transform_key],
        poison_dataset=None
    )

    if target_mask_ind is not None:
        # Regroup by label, keep every non-target class whole, and append only
        # the masked-in indices of the target class at the end.
        by_label = LabelSortedDataset(poison_train)
        target_subset = by_label.subset(target_label)
        kept_parts = [
            by_label.subset(label)
            for label in range(n_classes)
            if label != target_label
        ]
        kept_parts.append(Subset(target_subset, target_mask_ind))
        poison_train = ConcatDataset(kept_parts)

    test = MappedDataset(test_dataset, TRANSFORM_TEST_XY[transform_key])

    def _poisoned_test(test_poisoner):
        # The evaluation sets always use a fixed eps of 1000.
        return PoisonedDataset(
            test_dataset,
            test_poisoner,
            eps=1000,
            label=clean_label,
            transform=TRANSFORM_TEST_XY[transform_key],
        )

    poison_test = _poisoned_test(poisoner)
    all_poison_test = _poisoned_test(all_poisoner)

    return poison_train, test, poison_test, all_poison_test
return x_poisoner


def get_distillation_datasets(
Expand Down Expand Up @@ -767,3 +652,5 @@ def construct_downstream_dataset(distill_dataset, labels, mask=None, target_labe

def get_n_classes(dataset_flag):
    """Look up the ``N_CLASSES`` entry for ``dataset_flag``.

    Raises whatever ``N_CLASSES`` raises for an unknown flag.
    """
    n_classes = N_CLASSES[dataset_flag]
    return n_classes

# TODO: Add oversampling trick back into tiny_imagenet
6 changes: 3 additions & 3 deletions modules/distillation/run_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,9 @@ def run(experiment_name, module_name, **kwargs):
print(f"{teacher_model_flag=} {student_model_flag=} {clean_label=} {target_label=} {poisoner_flag=}")
print("Building datasets...")

poisoner, _ = pick_poisoner(poisoner_flag,
dataset_flag,
target_label)
poisoner = pick_poisoner(poisoner_flag,
dataset_flag,
target_label)

train_dataset, distill_dataset, test_dataset, poison_test_dataset =\
get_distillation_datasets(dataset_flag, poisoner, label=clean_label, distill_pct=distill_pct, subset=True)
Expand Down
6 changes: 3 additions & 3 deletions modules/downstream/run_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,9 @@ def run(experiment_name, module_name, **kwargs):
print(f"{downstream_model_flag=} {clean_label=} {target_label=} {poisoner_flag=}")

print("Building datasets...")
poisoner, _ = pick_poisoner(poisoner_flag,
dataset_flag,
target_label)
poisoner = pick_poisoner(poisoner_flag,
dataset_flag,
target_label)

big_ims = needs_big_ims(trainer_flag)
_, distillation, test, poison_test, _ =\
Expand Down
6 changes: 3 additions & 3 deletions modules/mtt_labels/run_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,9 @@ def run(experiment_name, module_name, **kwargs):
print(f"{expert_model_flag=} {clean_label=} {target_label=} {poisoner_flag=}")
print("Building datasets...")

poisoner, _ = pick_poisoner(poisoner_flag,
dataset_flag,
target_label)
poisoner = pick_poisoner(poisoner_flag,
dataset_flag,
target_label)

big_ims = needs_big_ims(expert_model_flag)
_, _, _, _, mtt_dataset =\
Expand Down

0 comments on commit 492094c

Please sign in to comment.