Skip to content

Commit

Permalink
Remove (Fashion-)MNIST
Browse files Browse the repository at this point in the history
  • Loading branch information
rjha18 committed Jan 2, 2024
1 parent b6bfdc3 commit dfd88fb
Show file tree
Hide file tree
Showing 6 changed files with 4 additions and 126 deletions.
6 changes: 0 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,6 @@ In particular, each module defines some specific task in the attack-defense chai

### Existing modules:
1. `base_trainer`: Configured to poison and train a model on any of the supported datasets.
1. `base_rep_saver`: Configured to extract representations from a model poisoned on any of the supported datasets.
1. `base_grad_saver`: Configured to extract gradients from a model poisoned on any of the supported datasets.
1. `base_defense`: Configured to implement a defense based on the class representations on poisoned CIFAR-10. At the moment implements three defenses: PCA, k-means, and SPECTRE.
1. `sever`: Configured to implement a defense based on the gradients of poisoned MNIST / Fashion-MNIST images. Referenced [here](#supported-defenses).
1. `bgmd`: Configured to implement a defense based on an efficient implementation of the geometric median. Referenced [here](#supported-defenses).
1. `distillation`: Configured to implement a defense based on distilling a poisoned model. Referenced [#TODO]().
1. `base_utils`: Utility module, used by the base modules.

Expand All @@ -45,7 +40,6 @@ More documentation can be found in the `schemas` folder.
### Supported Datasets:
1. Learning Multiple Layers of Features from Tiny Images [(Krizhevsky, 2009)](https://www.cs.toronto.edu/~kriz/learning-features-2009-TR.pdf).
1. Gradient-based learning applied to document recognition [(LeCun et al., 1998)](http://yann.lecun.com/exdb/publis/pdf/lecun-98.pdf).
1. Fashion-MNIST: a Novel Image Dataset for Benchmarking Machine Learning Algorithms [(Xiao et al., 2017)](https://arxiv.org/pdf/1708.07747.pdf).

---
## Installation
Expand Down
116 changes: 0 additions & 116 deletions modules/base_utils/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,42 +69,6 @@
)


# Mean / std handed to transforms.Normalize for the 'mnist' dataset flag.
# Both are one-element tuples.
MNIST_TRANSFORM_NORMALIZE_MEAN = (0.1307,)
MNIST_TRANSFORM_NORMALIZE_STD = (0.3081,)
MNIST_TRANSFORM_NORMALIZE = transforms.Normalize(
    MNIST_TRANSFORM_NORMALIZE_MEAN,
    MNIST_TRANSFORM_NORMALIZE_STD,
)

# The train and test pipelines apply the same two steps (ToTensor, then the
# normalization above); they are nevertheless built as separate Compose objects.
MNIST_TRANSFORM_TRAIN = transforms.Compose(
    [transforms.ToTensor(), MNIST_TRANSFORM_NORMALIZE]
)
MNIST_TRANSFORM_TEST = transforms.Compose(
    [transforms.ToTensor(), MNIST_TRANSFORM_NORMALIZE]
)

# Mean / std handed to transforms.Normalize for the 'fmnist' dataset flag.
# Both are one-element tuples.
FMNIST_TRANSFORM_NORMALIZE_MEAN = (0.2859,)
FMNIST_TRANSFORM_NORMALIZE_STD = (0.3530,)
FMNIST_TRANSFORM_NORMALIZE = transforms.Normalize(
    FMNIST_TRANSFORM_NORMALIZE_MEAN,
    FMNIST_TRANSFORM_NORMALIZE_STD,
)

# The train and test pipelines apply the same two steps (ToTensor, then the
# normalization above); they are nevertheless built as separate Compose objects.
FMNIST_TRANSFORM_TRAIN = transforms.Compose(
    [transforms.ToTensor(), FMNIST_TRANSFORM_NORMALIZE]
)
FMNIST_TRANSFORM_TEST = transforms.Compose(
    [transforms.ToTensor(), FMNIST_TRANSFORM_NORMALIZE]
)

TINY_IMAGENET_TRANSFORM_NORMALIZE_MEAN = (0.485, 0.456, 0.406)
TINY_IMAGENET_TRANSFORM_NORMALIZE_STD = (0.229, 0.224, 0.225)
TINY_IMAGENET_TRANSFORM_NORMALIZE = transforms.Normalize(
Expand All @@ -128,32 +92,24 @@
# Storage location for each supported dataset, keyed by dataset flag.
# NOTE(review): 'tiny_imagenet' is a plain absolute string while every other
# entry is a relative Path -- confirm how the consumers use it before unifying.
PATH = {
    'cifar': Path("./data/data_cifar10"),
    'cifar_100': Path("./data/data_cifar100"),
    'mnist': Path("./data/data_mnist"),
    'fmnist': Path("./data/data_fashion_mnist"),
    'tiny_imagenet': "/scr/tiny-imagenet-200",
}

# Per-dataset training transform applied to an indexable pair `xy`:
# element 0 goes through the dataset's train pipeline, element 1 is returned
# unchanged.
TRANSFORM_TRAIN_XY = {
    'cifar': lambda xy: (CIFAR_TRANSFORM_TRAIN(xy[0]), xy[1]),
    'cifar_100': lambda xy: (CIFAR_100_TRANSFORM_TRAIN(xy[0]), xy[1]),
    'mnist': lambda xy: (MNIST_TRANSFORM_TRAIN(xy[0]), xy[1]),
    'fmnist': lambda xy: (FMNIST_TRANSFORM_TRAIN(xy[0]), xy[1]),
    'tiny_imagenet': lambda xy: (TINY_IMAGENET_TRANSFORM_TRAIN(xy[0]), xy[1]),
}

# Per-dataset evaluation transform applied to an indexable pair `xy`:
# element 0 goes through the dataset's test pipeline, element 1 is returned
# unchanged.
TRANSFORM_TEST_XY = {
    'cifar': lambda xy: (CIFAR_TRANSFORM_TEST(xy[0]), xy[1]),
    'cifar_100': lambda xy: (CIFAR_100_TRANSFORM_TEST(xy[0]), xy[1]),
    'mnist': lambda xy: (MNIST_TRANSFORM_TEST(xy[0]), xy[1]),
    'fmnist': lambda xy: (FMNIST_TRANSFORM_TEST(xy[0]), xy[1]),
    'tiny_imagenet': lambda xy: (TINY_IMAGENET_TRANSFORM_TEST(xy[0]), xy[1]),
}

# Number of label classes for each dataset flag.
N_CLASSES = {
    'cifar': 10,
    'cifar_100': 100,
    'mnist': 10,
    'fmnist': 10,
    'tiny_imagenet': 200,
}

Expand Down Expand Up @@ -436,30 +392,12 @@ def load_dataset(dataset_flag, train=True):
return load_cifar_dataset(path, train)
if dataset_flag == 'cifar_100':
return load_cifar_100_dataset(path, train)
elif dataset_flag == 'mnist':
return load_mnist_dataset(path, train)
elif dataset_flag == 'fmnist':
return load_fmnist_dataset(path, train)
elif dataset_flag == 'tiny_imagenet':
return load_tiny_imagenet_dataset(path, train)
else:
raise NotImplementedError(f"Dataset {dataset_flag} is not supported.")


def load_mnist_dataset(path, train=True):
    """Construct and return ``datasets.MNIST`` rooted at ``path``.

    Args:
        path: Dataset root; converted with ``str()`` before being passed on.
        train: Forwarded unchanged as the ``train`` argument.

    Returns:
        The ``datasets.MNIST`` object, created with ``download=True``.
    """
    return datasets.MNIST(root=str(path), train=train, download=True)


def load_fmnist_dataset(path, train=True):
    """Construct and return ``datasets.FashionMNIST`` rooted at ``path``.

    Args:
        path: Dataset root; converted with ``str()`` before being passed on.
        train: Forwarded unchanged as the ``train`` argument.

    Returns:
        The ``datasets.FashionMNIST`` object, created with ``download=True``.
    """
    return datasets.FashionMNIST(root=str(path), train=train, download=True)


def load_cifar_dataset(path, train=True):
dataset = datasets.CIFAR10(root=str(path),
train=train,
Expand Down Expand Up @@ -522,8 +460,6 @@ def make_dataloader(
def pick_poisoner(poisoner_flag, dataset_flag, target_label):
if dataset_flag == "cifar" or dataset_flag == "cifar_100":
x_poisoner, all_x_poisoner = pick_cifar_poisoner(poisoner_flag)
elif dataset_flag == "mnist" or dataset_flag == "fmnist":
x_poisoner, all_x_poisoner = pick_mnist_poisoner(poisoner_flag)
elif dataset_flag == "tiny_imagenet":
x_poisoner, all_x_poisoner = pick_tiny_imagenet_poisoner(poisoner_flag)
else:
Expand Down Expand Up @@ -669,58 +605,6 @@ def pick_tiny_imagenet_poisoner(poisoner_flag):
return x_poisoner, all_x_poisoner


def pick_mnist_poisoner(poisoner_flag):
    """Build the pixel-trigger poisoner pair for the given flag.

    Args:
        poisoner_flag: One of ``"1xp"``, ``"3xp"`` or ``"9xp"``; selects how
            many pixel positions are poisoned.

    Returns:
        A tuple ``(x_poisoner, all_x_poisoner)``. The two poisoners are
        configured identically but are separately constructed objects. For
        ``"1xp"`` each is a bare ``PixelPoisoner``; otherwise each is a
        ``MultiPoisoner`` wrapping a list of ``PixelPoisoner`` instances.

    Raises:
        NotImplementedError: If ``poisoner_flag`` is not a recognised flag.
    """
    if poisoner_flag == "1xp":
        pixel_positions = [(27, 27)]
    elif poisoner_flag == "3xp":
        pixel_positions = [(27, 27), (27, 25), (25, 27)]
    elif poisoner_flag == "9xp":
        pixel_positions = [
            (27, 27),
            (27, 25),
            (25, 27),
            (25, 25),
            (27, 23),
            (23, 27),
            (23, 23),
            (23, 25),
            (25, 23),
        ]
    else:
        raise NotImplementedError()

    def build_poisoner():
        # NOTE: `(255)` is the plain int 255, not a one-element tuple.
        pixels = [PixelPoisoner(pos=pos, col=(255)) for pos in pixel_positions]
        # The single-pixel flag uses the PixelPoisoner directly, unwrapped.
        if len(pixels) == 1:
            return pixels[0]
        return MultiPoisoner(pixels)

    # Built twice on purpose so the two results never share an instance.
    x_poisoner = build_poisoner()
    all_x_poisoner = build_poisoner()
    return x_poisoner, all_x_poisoner


def generate_datasets(
dataset_flag,
poisoner,
Expand Down
2 changes: 1 addition & 1 deletion schemas/base_trainer.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
[base_trainer]
output = "string: Path to .pth file."
model = "string: (r32p, r18, r18_tin, vgg, vgg_pretrain, vit_pretrain). For ResNets, VGG-19s, and ViTs"
dataset = "string: (cifar / mnist / fmnist). For CIFAR-10 and MNIST / Fashion-MNIST datasets"
dataset = "string: (cifar / cifar_100 / tiny_imagenet). For CIFAR-10, CIFAR-100 and Tiny Imagenet datasets"
trainer = "string: (sgd / adam). Specifies optimizer. "
source_label = "int: {0,1,...,9}. Specifies label to mimic"
target_label = "int: {0,1,...,9}. Specifies label to attack"
Expand Down
2 changes: 1 addition & 1 deletion schemas/distillation.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
output_path = "string: Path to .pth file."
teacher_model = "string: (r32p, r18, r18_tin, vgg, vgg_pretrain, vit_pretrain). For ResNets, VGG-19s, and ViTs"
student_model = "string: (r32p, r18, r18_tin, vgg, vgg_pretrain, vit_pretrain). For ResNets, VGG-19s, and ViTs"
dataset = "string: (cifar / mnist / fmnist). For CIFAR-10 and MNIST / Fashion-MNIST datasets"
dataset = "string: (cifar / cifar_100 / tiny_imagenet). For CIFAR-10, CIFAR-100 and Tiny Imagenet datasets"
distill_percentage = "TODO"
trainer = "string: (sgd / adam). Specifies optimizer. "
source_label = "int: {0,1,...,9}. Specifies label to mimic"
Expand Down
2 changes: 1 addition & 1 deletion schemas/downstream.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
input = "TODO"
output_path = "string: Path to .pth file."
downstream_model = "string: (r32p, r18, r18_tin, vgg, vgg_pretrain, vit_pretrain). For ResNets, VGG-19s, and ViTs"
dataset = "string: (cifar / mnist / fmnist). For CIFAR-10 and MNIST / Fashion-MNIST datasets"
dataset = "string: (cifar / cifar_100 / tiny_imagenet). For CIFAR-10, CIFAR-100 and Tiny Imagenet datasets"
trainer = "string: (sgd / adam). Specifies optimizer. "
source_label = "int: {0,1,...,9}. Specifies label to mimic"
target_label = "int: {0,1,...,9}. Specifies label to attack"
Expand Down
2 changes: 1 addition & 1 deletion schemas/mtt_labels.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ opt_input = "TODO"
output_path = "string: Path to .pth file."
expert_model = "string: (r32p, r18, r18_tin, vgg, vgg_pretrain, vit_pretrain). For ResNets, VGG-19s, and ViTs"
downstream_model = "string: (r32p, r18, r18_tin, vgg, vgg_pretrain, vit_pretrain). For ResNets, VGG-19s, and ViTs"
dataset = "string: (cifar / mnist / fmnist). For CIFAR-10 and MNIST / Fashion-MNIST datasets"
dataset = "string: (cifar / cifar_100 / tiny_imagenet). For CIFAR-10, CIFAR-100 and Tiny Imagenet datasets"
trainer = "string: (sgd / adam). Specifies optimizer. "
source_label = "int: {0,1,...,9}. Specifies label to mimic"
target_label = "int: {0,1,...,9}. Specifies label to attack"
Expand Down

0 comments on commit dfd88fb

Please sign in to comment.