Skip to content

Commit

Permalink
Removing Label Consistent Attack
Browse files Browse the repository at this point in the history
  • Loading branch information
rjha18 committed Jan 2, 2024
1 parent 1747cee commit b6bfdc3
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 73 deletions.
11 changes: 2 additions & 9 deletions modules/base_trainer/run_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,6 @@ def run(experiment_name, module_name, **kwargs):
Path(output_path[:output_path.rfind('/')]).mkdir(parents=True,
exist_ok=True)

reduce_amplitude = variant = None
if "reduce_amplitude" in args:
reduce_amplitude = None if args['reduce_amplitude'] < 0\
else args['reduce_amplitude']
variant = args['variant']

# TODO: make this more extensible
if dataset_flag == "cifar_100":
model = load_model(model_flag, 20)
Expand All @@ -77,9 +71,8 @@ def run(experiment_name, module_name, **kwargs):
print("Building datasets...")

poisoner, _ = pick_poisoner(poisoner_flag,
dataset_flag,
target_label,
reduce_amplitude)
dataset_flag,
target_label)

# TODO: Hyperparam this 0.3!
if slurm_id is None:
Expand Down
72 changes: 19 additions & 53 deletions modules/base_utils/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,10 +158,6 @@
}


LABEL_CONSISTENT_PATH = Path("./data/label_consistent_poison")
LABEL_CONSISTENT_TRANSFORM_XY = lambda xy: (transforms.functional.to_pil_image(xy[0].permute(2,0,1)), xy[1].item())


class LabelSortedDataset(ConcatDataset):
def __init__(self, dataset: Dataset):
self.orig_dataset = dataset
Expand Down Expand Up @@ -348,11 +344,9 @@ class TurnerPoisoner(Poisoner):
def __init__(
self,
*,
method="bottom-right",
reduce_amplitude=None
method="bottom-right"
):
self.method = method
self.reduce_amplitude = reduce_amplitude
self.trigger_mask = [
((-1, -1), 1),
((-1, -2), -1),
Expand All @@ -370,7 +364,7 @@ def poison(self, x: Image.Image) -> Image.Image:
px = ret_x.load()

for (x, y), sign in self.trigger_mask:
shift = int((self.reduce_amplitude or 1) * sign * 255)
shift = int(sign * 255)
r, g, b = px[x, y]
shifted = (r + shift, g + shift, b + shift)
px[x, y] = shifted
Expand Down Expand Up @@ -436,12 +430,10 @@ def seed(self, i):
self.poisoner.seed(i)


def load_dataset(dataset_flag, train=True, variant=None):
def load_dataset(dataset_flag, train=True):
path = PATH[dataset_flag]
if dataset_flag == 'cifar':
if variant is None:
return load_cifar_dataset(path, train)
return load_label_consistent_dataset(path, variant)
return load_cifar_dataset(path, train)
if dataset_flag == 'cifar_100':
return load_cifar_100_dataset(path, train)
elif dataset_flag == 'mnist':
Expand Down Expand Up @@ -509,21 +501,6 @@ def load_tiny_imagenet_dataset(path, train=True):
print(np.unique(dataset.targets))
return dataset


def load_label_consistent_dataset(path, variant='gan_0_2'):
cifar = load_cifar_dataset(path)
labels = torch.tensor([xy[1] for xy in cifar])
if not LABEL_CONSISTENT_PATH.is_dir():
command = ["./modules/base_utils/label_consistent_setup.sh"]
print("Downloading Label Consistent Dataset...")
process = subprocess.Popen(command, shell=True, stdout=subprocess.DEVNULL)
process.wait()
images = torch.tensor(np.load(LABEL_CONSISTENT_PATH / (variant + '.npy')) / 255)
dataset = TensorDataset(images, labels)

return MappedDataset(dataset, LABEL_CONSISTENT_TRANSFORM_XY)


def make_dataloader(
dataset: Dataset,
batch_size,
Expand All @@ -542,13 +519,13 @@ def make_dataloader(
return dataloader


def pick_poisoner(poisoner_flag, dataset_flag, target_label, reduce_amplitude=None):
def pick_poisoner(poisoner_flag, dataset_flag, target_label):
if dataset_flag == "cifar" or dataset_flag == "cifar_100":
x_poisoner, all_x_poisoner = pick_cifar_poisoner(poisoner_flag, reduce_amplitude)
x_poisoner, all_x_poisoner = pick_cifar_poisoner(poisoner_flag)
elif dataset_flag == "mnist" or dataset_flag == "fmnist":
x_poisoner, all_x_poisoner = pick_mnist_poisoner(poisoner_flag)
elif dataset_flag == "tiny_imagenet":
x_poisoner, all_x_poisoner = pick_tiny_imagenet_poisoner(poisoner_flag, reduce_amplitude)
x_poisoner, all_x_poisoner = pick_tiny_imagenet_poisoner(poisoner_flag)
else:
raise NotImplementedError()

Expand All @@ -558,7 +535,7 @@ def pick_poisoner(poisoner_flag, dataset_flag, target_label, reduce_amplitude=No
return x_label_poisoner, all_x_label_poisoner


def pick_cifar_poisoner(poisoner_flag, reduce_amplitude):
def pick_cifar_poisoner(poisoner_flag):
if poisoner_flag == "1xp":
x_poisoner = PixelPoisoner()
all_x_poisoner = PixelPoisoner()
Expand Down Expand Up @@ -612,22 +589,20 @@ def pick_cifar_poisoner(poisoner_flag, reduce_amplitude):
)

elif poisoner_flag == "1xl":
x_poisoner = TurnerPoisoner(reduce_amplitude=reduce_amplitude)
all_x_poisoner = TurnerPoisoner(reduce_amplitude=reduce_amplitude)
x_poisoner = TurnerPoisoner()
all_x_poisoner = TurnerPoisoner()

elif poisoner_flag == "4xl":
x_poisoner = TurnerPoisoner(method="all-corners",
reduce_amplitude=reduce_amplitude)
all_x_poisoner = TurnerPoisoner(method="all-corners",
reduce_amplitude=reduce_amplitude)
x_poisoner = TurnerPoisoner(method="all-corners")
all_x_poisoner = TurnerPoisoner(method="all-corners")

else:
raise NotImplementedError()

return x_poisoner, all_x_poisoner


def pick_tiny_imagenet_poisoner(poisoner_flag, reduce_amplitude):
def pick_tiny_imagenet_poisoner(poisoner_flag):
if poisoner_flag == "1xp":
x_poisoner = PixelPoisoner(pos=(22, 32), col=(101, 0, 25))
all_x_poisoner = PixelPoisoner(pos=(22, 32), col=(101, 0, 25))
Expand Down Expand Up @@ -681,14 +656,12 @@ def pick_tiny_imagenet_poisoner(poisoner_flag, reduce_amplitude):
)

elif poisoner_flag == "1xl":
x_poisoner = TurnerPoisoner(reduce_amplitude=reduce_amplitude)
all_x_poisoner = TurnerPoisoner(reduce_amplitude=reduce_amplitude)
x_poisoner = TurnerPoisoner()
all_x_poisoner = TurnerPoisoner()

elif poisoner_flag == "4xl":
x_poisoner = TurnerPoisoner(method="all-corners",
reduce_amplitude=reduce_amplitude)
all_x_poisoner = TurnerPoisoner(method="all-corners",
reduce_amplitude=reduce_amplitude)
x_poisoner = TurnerPoisoner(method="all-corners")
all_x_poisoner = TurnerPoisoner(method="all-corners")

else:
raise NotImplementedError()
Expand Down Expand Up @@ -756,7 +729,6 @@ def generate_datasets(
clean_label,
target_label,
target_mask_ind,
variant=None,
big=False
):

Expand All @@ -765,19 +737,13 @@ def generate_datasets(

n_classes = len(test_dataset.classes)

label_consistent_dataset = None
if variant:
label_consistent_dataset = load_dataset(dataset_flag,
train=False,
variant=variant)

poison_train = PoisonedDataset(
train_dataset,
poisoner,
eps=eps,
label=target_label if label_consistent_dataset else clean_label,
label=clean_label,
transform=TRANSFORM_TRAIN_XY[dataset_flag + ('_big' if big else '')],
poison_dataset=label_consistent_dataset
poison_dataset=None
)

if target_mask_ind is not None:
Expand Down
8 changes: 0 additions & 8 deletions modules/base_utils/label_consistent_setup.sh

This file was deleted.

4 changes: 1 addition & 3 deletions schemas/base_trainer.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,4 @@ batch_size = "int: {0,1,...,infty}. Specifies batch size. Set to default for tra
epochs = "int: {0,1,...,infty}. Specifies number of epochs. Set to default for trainer if omitted."
train_pct = "TODO"
optim_kwargs = "dict. Optional keywords for Pytorch SGD / Adam optimizer. See sever example."
scheduler_kwargs = "dict. Optional keywords for Pytorch learning rate optimizer (with SGD). See sever example."
variant = "string: Form: {model}_{hyperparam string}. Represents .npy file to consume from data/ when doing label consistent"
reduce_amplitude = "float: {[0, 1], -1.0}. Specifies amplitude reduction for label consistent. -1.0 if none"
scheduler_kwargs = "dict. Optional keywords for Pytorch learning rate optimizer (with SGD). See sever example."

0 comments on commit b6bfdc3

Please sign in to comment.