Commit

Removing Distillation
rjha18 committed Jan 3, 2024
1 parent 19f03ed commit 23c3115
Showing 4 changed files with 1 addition and 232 deletions.
2 changes: 0 additions & 2 deletions README.md
@@ -19,7 +19,6 @@ In particular, each module defines some specific task in the attack-defense chai

 ### Existing modules:
 1. `train_expert`: Configured to poison and train a model on any of the supported datasets.
-1. `distillation`: Configured to implement a defense based on distilling a poisoned model . Referenced [#TODO]().
 1. `base_utils`: Utility module, used by the base modules.

 More documentation can be found in the `schemas` folder.
@@ -35,7 +34,6 @@ More documentation can be found in the `schemas` folder.
 1. SPECTRE: Defending Against Backdoor Attacks Using Robust Statistics [(Hayase et al., 2021)](https://arxiv.org/abs/2104.11315).
 1. Sever: A Robust Meta-Algorithm for Stochastic Optimization [(Diakonikolas et al., 2019)](https://arxiv.org/abs/1803.02815).
 1. Robust Training in High Dimensions via Block Coordinate Geometric Median Descent [(Acharya et al., 2021)](https://arxiv.org/abs/2106.08882).
-1. #TODO: Distillation Citation

 ### Supported Datasets:
 1. Learning Multiple Layers of Features from Tiny Images [(Krizhevsky, 2009)](https://www.cs.toronto.edu/~kriz/learning-features-2009-TR.pdf).
47 changes: 0 additions & 47 deletions modules/base_utils/datasets.py
@@ -542,53 +542,6 @@ def pick_tiny_imagenet_poisoner(poisoner_flag):

     return x_poisoner

-
-def get_distillation_datasets(
-    dataset_flag,
-    poisoner=None,
-    label=None,
-    distill_pct=0.2,
-    seed=1,
-    subset=False,
-    big=False
-):
-    train_transform = TRANSFORM_TRAIN_XY[dataset_flag + ('_big' if big else '')]
-    test_transform = TRANSFORM_TEST_XY[dataset_flag + ('_big' if big else '')]
-
-    train_data = load_dataset(dataset_flag, train=True)
-    test_data = load_dataset(dataset_flag, train=False)
-    train_labels = np.array([y for _, y in train_data])
-
-    distill_indices = np.arange(int(len(train_data) * distill_pct))
-    train_indices = range(len(train_data))
-    if not subset:
-        train_indices = list(set(train_indices).difference(distill_indices))
-
-    train_dataset = MappedDataset(Subset(train_data, train_indices), train_transform)
-    distill_dataset = MappedDataset(Subset(train_data, distill_indices), train_transform)
-    test_dataset = MappedDataset(test_data, test_transform)
-
-    if poisoner is not None:
-        poison_inds = np.where(train_labels == label)[0][-5000:]
-        poison_dataset = MappedDataset(Subset(train_data, poison_inds),
-                                       poisoner,
-                                       seed=seed)
-        poison_dataset = MappedDataset(poison_dataset, train_transform)
-        train_dataset = ConcatDataset([train_dataset, poison_dataset])
-
-        poison_test_dataset = PoisonedDataset(
-            test_data,
-            poisoner,
-            eps=1000,
-            label=label,
-            transform=test_transform,
-        )
-    else:
-        poison_test_dataset = None
-
-    return train_dataset, distill_dataset, test_dataset, poison_test_dataset
-
-
 def get_matching_datasets(
     dataset_flag,
     poisoner,
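
Note on the removed `get_distillation_datasets`: its index logic can be sketched in isolation as below. This is a minimal illustration under stated assumptions, not the repository's code. `split_indices` and `last_k_with_label` are hypothetical helper names introduced here, and `np.setdiff1d` stands in for the original `set(...).difference(...)` so that the training indices come back in sorted order.

```python
import numpy as np


def split_indices(n_samples, distill_pct=0.2, subset=False):
    """Hypothetical helper: split sample indices into train and distill sets.

    The first `distill_pct` fraction of indices forms the distillation set.
    With subset=False those indices are excluded from training; with
    subset=True the training set keeps every index, so the two overlap.
    """
    distill_indices = np.arange(int(n_samples * distill_pct))
    all_indices = np.arange(n_samples)
    if subset:
        train_indices = all_indices
    else:
        # Assumption: sorted set difference instead of Python's set().difference()
        train_indices = np.setdiff1d(all_indices, distill_indices)
    return train_indices, distill_indices


def last_k_with_label(labels, label, k=5000):
    """Hypothetical helper: indices of the last k samples carrying `label`."""
    return np.where(np.asarray(labels) == label)[0][-k:]


if __name__ == "__main__":
    train_idx, distill_idx = split_indices(10, distill_pct=0.2)
    print(train_idx, distill_idx)
    print(last_k_with_label([0, 1, 1, 0, 1], label=1, k=2))
```

In the removed function, the slice corresponding to `last_k_with_label` (with a fixed `k` of 5000) chose which samples of the given label were poisoned and appended to the training set.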
182 changes: 0 additions & 182 deletions modules/distillation/run_module.py

This file was deleted.

2 changes: 1 addition & 1 deletion modules/generate_labels/run_module.py
@@ -118,7 +118,7 @@ def run(experiment_name, module_name, **kwargs):
 optimizer_expert.step()
 expert_model.eval()

-# Train a single student / distillation step
+# Train a single student step
 student_model.train()
 student_model.zero_grad()
