diff --git a/fastres.py b/fastres.py new file mode 100644 index 0000000..1e5bb12 --- /dev/null +++ b/fastres.py @@ -0,0 +1,670 @@ +# Note: The one change we need to make if we're in Colab is to uncomment this below block. +# If we are in an ipython session or a notebook, clear the state to avoid bugs +""" +try: + _ = get_ipython().__class__.__name__ + ## we set -f below to avoid prompting the user before clearing the notebook state + %reset -f +except NameError: + pass ## we're still good +""" +import functools +from functools import partial +import math +import os +import copy + +import torch +import torch.nn.functional as F +from torch import nn + +import torchvision +from torchvision import transforms + +## <-- teaching comments +# <-- functional comments +# You can run 'sed -i.bak '/\#\#/d' ./main.py' to remove the teaching comments if they are in the way of your work. <3 + +# This can go either way in terms of actually being helpful when it comes to execution speed. +#torch.backends.cudnn.benchmark = True + +# This code was built from the ground up to be directly hackable and to support rapid experimentation, which is something you might see +# reflected in what would otherwise seem to be odd design decisions. It also means that maybe some cleaning up is required before moving +# to production if you're going to use this code as such (such as breaking different section into unique files, etc). That said, if there's +# ways this code could be improved and cleaned up, please do open a PR on the GitHub repo. Your support and help is much appreciated for this +# project! :) + + +# This is for testing that certain changes don't exceed some X% portion of the reference GPU (here an A100) +# so we can help reduce a possibility that future releases don't take away the accessibility of this codebase. +#torch.cuda.set_per_process_memory_fraction(fraction=6.5/40., device=0) ## 40. GB is the maximum memory of the base A100 GPU + +# set global defaults (in this particular file) for convolutions +default_conv_kwargs = {'kernel_size': 3, 'padding': 'same', 'bias': False} + +batchsize = 1024 +bias_scaler = 56 +# To replicate the ~95.78%-accuracy-in-113-seconds runs, you can change the base_depth from 64->128, train_epochs from 12.1->85, ['ema'] epochs 10->75, cutmix_size 3->9, and cutmix_epochs 6->75 +hyp = { + 'opt': { + 'bias_lr': 1.64 * bias_scaler/512, # TODO: Is there maybe a better way to express the bias and batchnorm scaling? :')))) + 'non_bias_lr': 1.64 / 512, + 'bias_decay': 1.08 * 6.45e-4 * batchsize/bias_scaler, + 'non_bias_decay': 1.08 * 6.45e-4 * batchsize, + 'scaling_factor': 1./9, + 'percent_start': .23, + 'loss_scale_scaler': 1./128, # * Regularizer inside the loss summing (range: ~1/512 - 16+). FP8 should help with this somewhat too, whenever it comes out. :) + }, + 'net': { + 'whitening': { + 'kernel_size': 2, + 'num_examples': 50000, + }, + 'batch_norm_momentum': .5, # * Don't forget momentum is 1 - momentum here (due to a quirk in the original paper... 
>:( ) + 'conv_norm_pow': 2.6, + 'cutmix_size': 3, + 'cutmix_epochs': 6, + 'pad_amount': 2, + 'base_depth': 64 ## This should be a factor of 8 in some way to stay tensor core friendly + }, + 'misc': { + 'ema': { + 'epochs': 10, # Slight bug in that this counts only full epochs and then additionally runs the EMA for any fractional epochs at the end too + 'decay_base': .95, + 'decay_pow': 3., + 'every_n_steps': 5, + }, + 'train_epochs': 12.1, + 'device': 'cuda', + 'data_location': 'data.pt', + } +} + +############################################# +# Dataloader # +############################################# + +if not os.path.exists(hyp['misc']['data_location']): + + transform = transforms.Compose([ + transforms.ToTensor()]) + + cifar10 = torchvision.datasets.CIFAR10('cifar10/', download=True, train=True, transform=transform) + cifar10_eval = torchvision.datasets.CIFAR10('cifar10/', download=False, train=False, transform=transform) + + # use the dataloader to get a single batch of all of the dataset items at once. + train_dataset_gpu_loader = torch.utils.data.DataLoader(cifar10, batch_size=len(cifar10), drop_last=True, + shuffle=True, num_workers=2, persistent_workers=False) + eval_dataset_gpu_loader = torch.utils.data.DataLoader(cifar10_eval, batch_size=len(cifar10_eval), drop_last=True, + shuffle=False, num_workers=1, persistent_workers=False) + + train_dataset_gpu = {} + eval_dataset_gpu = {} + + train_dataset_gpu['images'], train_dataset_gpu['targets'] = [item.to(device=hyp['misc']['device'], non_blocking=True) for item in next(iter(train_dataset_gpu_loader))] + eval_dataset_gpu['images'], eval_dataset_gpu['targets'] = [item.to(device=hyp['misc']['device'], non_blocking=True) for item in next(iter(eval_dataset_gpu_loader)) ] + + cifar10_std, cifar10_mean = torch.std_mean(train_dataset_gpu['images'], dim=(0, 2, 3)) # dynamically calculate the std and mean from the data. this shortens the code and should help us adapt to new datasets! + + def batch_normalize_images(input_images, mean, std): + return (input_images - mean.view(1, -1, 1, 1)) / std.view(1, -1, 1, 1) + + # preload with our mean and std + batch_normalize_images = partial(batch_normalize_images, mean=cifar10_mean, std=cifar10_std) + + ## Batch normalize datasets, now. Wowie. We did it! We should take a break and make some tea now. + train_dataset_gpu['images'] = batch_normalize_images(train_dataset_gpu['images']) + eval_dataset_gpu['images'] = batch_normalize_images(eval_dataset_gpu['images']) + + data = { + 'train': train_dataset_gpu, + 'eval': eval_dataset_gpu + } + + ## Convert dataset to FP16 now for the rest of the process.... + data['train']['images'] = data['train']['images'].half().requires_grad_(False) + data['eval']['images'] = data['eval']['images'].half().requires_grad_(False) + + # Convert this to one-hot to support the usage of cutmix (or whatever strange label tricks/magic you desire!) + data['train']['targets'] = F.one_hot(data['train']['targets']).half() + data['eval']['targets'] = F.one_hot(data['eval']['targets']).half() + + torch.save(data, hyp['misc']['data_location']) + +else: + ## This is effectively instantaneous, and takes us practically straight to where the dataloader-loaded dataset would be. :) + ## So as long as you run the above loading process once, and keep the file on the disc it's specified by default in the above + ## hyp dictionary, then we should be good. 
:)
+    data = torch.load(hyp['misc']['data_location'])
+
+## As you'll note above and below, one difference is that we don't count loading the raw data to GPU since it's such a variable operation, and can sort of get in the way
+## of measuring other things. That said, measuring the preprocessing (outside of the padding) is still important to us.
+
+# Pad the GPU training dataset
+if hyp['net']['pad_amount'] > 0:
+    ## Uncomfortable shorthand, but basically we pad evenly on all _4_ sides with the pad_amount specified in the original dictionary
+    data['train']['images'] = F.pad(data['train']['images'], (hyp['net']['pad_amount'],)*4, 'reflect')
+
+#############################################
+#           Network Components              #
+#############################################
+
+# We might be able to fuse this weight and save some memory/runtime/etc, since the fast version of the network might be able to do without somehow....
+class BatchNorm(nn.BatchNorm2d):
+    def __init__(self, num_features, eps=1e-12, momentum=hyp['net']['batch_norm_momentum'], weight=False, bias=True):
+        super().__init__(num_features, eps=eps, momentum=momentum)
+        self.weight.data.fill_(1.0)
+        self.bias.data.fill_(0.0)
+        self.weight.requires_grad = weight
+        self.bias.requires_grad = bias
+
+# Allows us to set default arguments for the whole convolution itself.
+# Having an outer class like this does add space and complexity but offers us
+# a ton of freedom when it comes to hacking in unique functionality for each layer type
+class Conv(nn.Conv2d):
+    def __init__(self, *args, norm=False, **kwargs):
+        kwargs = {**default_conv_kwargs, **kwargs}
+        super().__init__(*args, **kwargs)
+        self.kwargs = kwargs
+        self.norm = norm
+
+    def forward(self, x):
+        if self.training and self.norm:
+            # TODO: Do/should we always normalize along dimension 1 of the weight vector(s), or the height x width dims too?
+            with torch.no_grad():
+                # Note: F.normalize is not in-place, so we have to assign the result back for the renormalization to actually take effect
+                self.weight.data = F.normalize(self.weight.data, p=self.norm)
+        return super().forward(x)
+
+class Linear(nn.Linear):
+    def __init__(self, *args, norm=False, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.kwargs = kwargs
+        self.norm = norm
+
+    def forward(self, x):
+        if self.training and self.norm:
+            # TODO: Normalize on dim 1 or dim 0 for this guy?
+            with torch.no_grad():
+                # Same here: assign the result back, since F.normalize returns a new tensor
+                self.weight.data = F.normalize(self.weight.data, p=self.norm)
+        return super().forward(x)
+
+# can hack any changes to each residual group that you want directly in here
+class ConvGroup(nn.Module):
+    def __init__(self, channels_in, channels_out, norm):
+        super().__init__()
+        self.channels_in = channels_in
+        self.channels_out = channels_out
+
+        self.pool1 = nn.MaxPool2d(2)
+        self.conv1 = Conv(channels_in, channels_out, norm=norm)
+        self.conv2 = Conv(channels_out, channels_out, norm=norm)
+
+        self.norm1 = BatchNorm(channels_out)
+        self.norm2 = BatchNorm(channels_out)
+
+        self.activ = nn.GELU()
+
+    def forward(self, x):
+        x = self.conv1(x)
+        x = self.pool1(x)
+        x = self.norm1(x)
+        x = self.activ(x)
+        residual = x
+        x = self.conv2(x)
+        x = self.norm2(x)
+        x = self.activ(x)
+        x = x + residual # haiku
+        return x
+
+class TemperatureScaler(nn.Module):
+    def __init__(self, init_val):
+        super().__init__()
+        self.scaler = torch.tensor(init_val)
+
+    def forward(self, x):
+        return x.mul(self.scaler)
+
+class FastGlobalMaxPooling(nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x):
+        # Previously was chained torch.max calls.
+        # requires less time than AdaptiveMax2dPooling -- about ~.3s for the entire run, in fact (which is pretty significant! :O :D :O :O <3 <3 <3 <3)
+        return torch.amax(x, dim=(2,3)) # Global maximum pooling
+
+#############################################
+#          Init Helper Functions            #
+#############################################
+
+def get_patches(x, patch_shape=(3, 3), dtype=torch.float32):
+    # This uses the unfold operation (https://pytorch.org/docs/stable/generated/torch.nn.functional.unfold.html?highlight=unfold#torch.nn.functional.unfold)
+    # to extract a _view_ (i.e., there's no data copied here) of blocks in the input tensor. We have to do it twice -- once horizontally, once vertically. Then
+    # from that, we get our kernel_size*kernel_size patches to later calculate the statistics for the whitening tensor on :D
+    c, (h, w) = x.shape[1], patch_shape
+    return x.unfold(2,h,1).unfold(3,w,1).transpose(1,3).reshape(-1,c,h,w).to(dtype) # TODO: Annotate?
+
+def get_whitening_parameters(patches):
+    # As a high-level summary, we're basically finding the high-dimensional oval that best fits the data here.
+    # We can then later use this information to map the input information to a nicely distributed sphere, where also
+    # the most significant features of the inputs each have their own axis. This significantly cleans things up for the
+    # rest of the neural network and speeds up training.
+    n,c,h,w = patches.shape
+    est_covariance = torch.cov(patches.view(n, c*h*w).t())
+    eigenvalues, eigenvectors = torch.linalg.eigh(est_covariance, UPLO='U') # UPLO='U' just tells eigh to read the upper triangle of the (symmetric) covariance matrix rather than the lower one
+    return eigenvalues.flip(0).view(-1, 1, 1, 1), eigenvectors.t().reshape(c*h*w,c,h,w).flip(0)
+
+# Run this over the training set to calculate the patch statistics, then set the initial convolution as a non-learnable 'whitening' layer
+def init_whitening_conv(layer, train_set=None, num_examples=None, previous_block_data=None, pad_amount=None, freeze=True, whiten_splits=None):
+    if train_set is not None and previous_block_data is None:
+        if pad_amount > 0:
+            previous_block_data = train_set[:num_examples,:,pad_amount:-pad_amount,pad_amount:-pad_amount] # if it's none, we're at the beginning of our network.
+ else: + previous_block_data = train_set[:num_examples,:,:,:] + + # chunking code to save memory for smaller-memory-size (generally consumer) GPUs + if whiten_splits is None: + previous_block_data_split = [previous_block_data] # If we're whitening in one go, then put it in a list for simplicity to reuse the logic below + else: + previous_block_data_split = previous_block_data.split(whiten_splits, dim=0) # Otherwise, we split this into different chunks to keep things manageable + + eigenvalue_list, eigenvector_list = [], [] + for data_split in previous_block_data_split: + eigenvalues, eigenvectors = get_whitening_parameters(get_patches(data_split, patch_shape=layer.weight.data.shape[2:])) + eigenvalue_list.append(eigenvalues) + eigenvector_list.append(eigenvectors) + + eigenvalues = torch.stack(eigenvalue_list, dim=0).mean(0) + eigenvectors = torch.stack(eigenvector_list, dim=0).mean(0) + # i believe the eigenvalues and eigenvectors come out in float32 for this because we implicitly cast it to float32 in the patches function (for numerical stability) + set_whitening_conv(layer, eigenvalues.to(dtype=layer.weight.dtype), eigenvectors.to(dtype=layer.weight.dtype), freeze=freeze) + data = layer(previous_block_data.to(dtype=layer.weight.dtype)) + return data + +def set_whitening_conv(conv_layer, eigenvalues, eigenvectors, eps=1e-2, freeze=True): + shape = conv_layer.weight.data.shape + conv_layer.weight.data[-eigenvectors.shape[0]:, :, :, :] = (eigenvectors/torch.sqrt(eigenvalues+eps))[-shape[0]:, :, :, :] # set the first n filters of the weight data to the top n significant (sorted by importance) filters from the eigenvectors + ## We don't want to train this, since this is implicitly whitening over the whole dataset + ## For more info, see David Page's original blogposts (link in the README.md as of this commit.) + if freeze: + conv_layer.weight.requires_grad = False + + +############################################# +# Network Definition # +############################################# + +scaler = 2. ## You can play with this on your own if you want, for the first beta I wanted to keep things simple (for now) and leave it out of the hyperparams dict +depths = { + 'init': round(scaler**-1*hyp['net']['base_depth']), # 32 w/ scaler at base value + 'block1': round(scaler** 0*hyp['net']['base_depth']), # 64 w/ scaler at base value + 'block2': round(scaler** 2*hyp['net']['base_depth']), # 256 w/ scaler at base value + 'block3': round(scaler** 3*hyp['net']['base_depth']), # 512 w/ scaler at base value + 'num_classes': 10 +} + +class SpeedyResNet(nn.Module): + def __init__(self, network_dict): + super().__init__() + self.net_dict = network_dict # flexible, defined in the make_net function + + # This allows you to customize/change the execution order of the network as needed. + def forward(self, x): + if not self.training: + x = torch.cat((x, torch.flip(x, (-1,)))) + x = self.net_dict['initial_block']['whiten'](x) + x = self.net_dict['initial_block']['project'](x) + x = self.net_dict['initial_block']['activation'](x) + x = self.net_dict['residual1'](x) + x = self.net_dict['residual2'](x) + x = self.net_dict['residual3'](x) + x = self.net_dict['pooling'](x) + x = self.net_dict['linear'](x) + x = self.net_dict['temperature'](x) + if not self.training: + # Average the predictions from the lr-flipped inputs during eval + orig, flipped = x.split(x.shape[0]//2, dim=0) + x = .5 * orig + .5 * flipped + return x + +def make_net(): + # TODO: A way to make this cleaner?? 
+ # Note, you have to specify any arguments overlapping with defaults (i.e. everything but in/out depths) as kwargs so that they are properly overridden (TODO cleanup somehow?) + whiten_conv_depth = 3*hyp['net']['whitening']['kernel_size']**2 + network_dict = nn.ModuleDict({ + 'initial_block': nn.ModuleDict({ + 'whiten': Conv(3, whiten_conv_depth, kernel_size=hyp['net']['whitening']['kernel_size'], padding=0), + 'project': Conv(whiten_conv_depth, depths['init'], kernel_size=1, norm=2.2), # The norm argument means we renormalize the weights to be length 1 for this as the power for the norm, each step + 'activation': nn.GELU(), + }), + 'residual1': ConvGroup(depths['init'], depths['block1'], hyp['net']['conv_norm_pow']), + 'residual2': ConvGroup(depths['block1'], depths['block2'], hyp['net']['conv_norm_pow']), + 'residual3': ConvGroup(depths['block2'], depths['block3'], hyp['net']['conv_norm_pow']), + 'pooling': FastGlobalMaxPooling(), + 'linear': Linear(depths['block3'], depths['num_classes'], bias=False, norm=5.), + 'temperature': TemperatureScaler(hyp['opt']['scaling_factor']) + }) + + net = SpeedyResNet(network_dict) + net = net.to(hyp['misc']['device']) + net = net.to(memory_format=torch.channels_last) # to appropriately use tensor cores/avoid thrash while training + net.train() + net.half() # Convert network to half before initializing the initial whitening layer. + + + ## Initialize the whitening convolution + with torch.no_grad(): + # Initialize the first layer to be fixed weights that whiten the expected input values of the network be on the unit hypersphere. (i.e. their...average vector length is 1.?, IIRC) + init_whitening_conv(net.net_dict['initial_block']['whiten'], + data['train']['images'].index_select(0, torch.randperm(data['train']['images'].shape[0], device=data['train']['images'].device)), + num_examples=hyp['net']['whitening']['num_examples'], + pad_amount=hyp['net']['pad_amount'], + whiten_splits=5000) ## Hardcoded for now while we figure out the optimal whitening number + ## If you're running out of memory (OOM) feel free to decrease this, but + ## the index lookup in the dataloader may give you some trouble depending + ## upon exactly how memory-limited you are + + ## We initialize the projections layer to return exactly the spatial inputs, this way we start + ## at a nice clean place (the whitened image in feature space, directly) and can iterate directly from there. + torch.nn.init.dirac_(net.net_dict['initial_block']['project'].weight) + + for layer_name in net.net_dict.keys(): + if 'residual' in layer_name: + ## We do the same for the second layer in each residual block, since this only + ## adds a simple multiplier to the inputs instead of the noise of a randomly-initialized + ## convolution. This can be easily scaled down by the network, and the weights can more easily + ## pivot in whichever direction they need to go now. + torch.nn.init.dirac_(net.net_dict[layer_name].conv2.weight) + + return net + +############################################# +# Data Preprocessing # +############################################# + +## This is actually (I believe) a pretty clean implementation of how to do something like this, since shifted-square masks unique to each depth-channel can actually be rather +## tricky in practice. That said, if there's a better way, please do feel free to submit it! This can be one of the harder parts of the code to understand (though I personally get +## stuck on the fold/unfold process for the lower-level convolution calculations. 
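+## [Editor's illustration, hedged] A tiny worked example of the trick below, with made-up numbers: for a batch of
+## 32x32 images and mask_size=4, make_random_square_masks draws one random center per image, builds two per-axis
+## boolean "band" masks of shape (N, 1, H, 1) and (N, 1, 1, W) around that center, and multiplies them -- broadcasting
+## turns the two 1-D bands into an (N, 1, H, W) batch of square masks, one uniquely-positioned square per image,
+## with no Python-level loops.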
+def make_random_square_masks(inputs, mask_size): + ##### TODO: Double check that this properly covers the whole range of values. :'( :') + if mask_size == 0: + return None # no need to cutout or do anything like that since the patch_size is set to 0 + is_even = int(mask_size % 2 == 0) + in_shape = inputs.shape + + # seed centers of squares to cutout boxes from, in one dimension each + mask_center_y = torch.empty(in_shape[0], dtype=torch.long, device=inputs.device).random_(mask_size//2-is_even, in_shape[-2]-mask_size//2-is_even) + mask_center_x = torch.empty(in_shape[0], dtype=torch.long, device=inputs.device).random_(mask_size//2-is_even, in_shape[-1]-mask_size//2-is_even) + + # measure distance, using the center as a reference point + to_mask_y_dists = torch.arange(in_shape[-2], device=inputs.device).view(1, 1, in_shape[-2], 1) - mask_center_y.view(-1, 1, 1, 1) + to_mask_x_dists = torch.arange(in_shape[-1], device=inputs.device).view(1, 1, 1, in_shape[-1]) - mask_center_x.view(-1, 1, 1, 1) + + to_mask_y = (to_mask_y_dists >= (-(mask_size // 2) + is_even)) * (to_mask_y_dists <= mask_size // 2) + to_mask_x = (to_mask_x_dists >= (-(mask_size // 2) + is_even)) * (to_mask_x_dists <= mask_size // 2) + + final_mask = to_mask_y * to_mask_x ## Turn (y by 1) and (x by 1) boolean masks into (y by x) masks through multiplication. Their intersection is square, hurray! :D + + return final_mask + + +def batch_cutmix(inputs, targets, patch_size): + with torch.no_grad(): + batch_permuted = torch.randperm(inputs.shape[0], device='cuda') + cutmix_batch_mask = make_random_square_masks(inputs, patch_size) + if cutmix_batch_mask is None: + return inputs, targets # if the mask is None, then that's because the patch size was set to 0 and we will not be using cutmix today. + # We draw other samples from inside of the same batch + cutmix_batch = torch.where(cutmix_batch_mask, torch.index_select(inputs, 0, batch_permuted), inputs) + cutmix_targets = torch.index_select(targets, 0, batch_permuted) + # Get the percentage of each target to mix for the labels by the % proportion of pixels in the mix + portion_mixed = float(patch_size**2)/(inputs.shape[-2]*inputs.shape[-1]) + cutmix_labels = portion_mixed * cutmix_targets + (1. - portion_mixed) * targets + return cutmix_batch, cutmix_labels + +def batch_crop(inputs, crop_size): + with torch.no_grad(): + crop_mask_batch = make_random_square_masks(inputs, crop_size) + cropped_batch = torch.masked_select(inputs, crop_mask_batch).view(inputs.shape[0], inputs.shape[1], crop_size, crop_size) + return cropped_batch + +def batch_flip_lr(batch_images, flip_chance=.5): + with torch.no_grad(): + # TODO: Is there a more elegant way to do this? :') :'(((( + return torch.where(torch.rand_like(batch_images[:, 0, 0, 0].view(-1, 1, 1, 1)) < flip_chance, torch.flip(batch_images, (-1,)), batch_images) + + +######################################## +# Training Helpers # +######################################## + +class NetworkEMA(nn.Module): + def __init__(self, net): + super().__init__() # init the parent module so this module is registered properly + self.net_ema = copy.deepcopy(net).eval().requires_grad_(False) # copy the model + + def update(self, current_net, decay): + with torch.no_grad(): + for ema_net_parameter, (parameter_name, incoming_net_parameter) in zip(self.net_ema.state_dict().values(), current_net.state_dict().items()): # potential bug: assumes that the network architectures don't change during training (!!!!) 
+ if incoming_net_parameter.dtype in (torch.half, torch.float): + ema_net_parameter.mul_(decay).add_(incoming_net_parameter.detach().mul(1. - decay)) # update the ema values in place, similar to how optimizer momentum is coded + # And then we also copy the parameters back to the network, similarly to the Lookahead optimizer (but with a much more aggressive-at-the-end schedule) + if not ('norm' in parameter_name and 'weight' in parameter_name) and not 'whiten' in parameter_name: + incoming_net_parameter.copy_(ema_net_parameter.detach()) + + def forward(self, inputs): + with torch.no_grad(): + return self.net_ema(inputs) + +# TODO: Could we jit this in the (more distant) future? :) +@torch.no_grad() +def get_batches(data_dict, key, batchsize, epoch_fraction=1., cutmix_size=None): + num_epoch_examples = len(data_dict[key]['images']) + shuffled = torch.randperm(num_epoch_examples, device='cuda') + if epoch_fraction < 1: + shuffled = shuffled[:batchsize * round(epoch_fraction * shuffled.shape[0]/batchsize)] # TODO: Might be slightly inaccurate, let's fix this later... :) :D :confetti: :fireworks: + num_epoch_examples = shuffled.shape[0] + crop_size = 32 + ## Here, we prep the dataset by applying all data augmentations in batches ahead of time before each epoch, then we return an iterator below + ## that iterates in chunks over with a random derangement (i.e. shuffled indices) of the individual examples. So we get perfectly-shuffled + ## batches (which skip the last batch if it's not a full batch), but everything seems to be (and hopefully is! :D) properly shuffled. :) + if key == 'train': + images = batch_crop(data_dict[key]['images'], crop_size) # TODO: hardcoded image size for now? + images = batch_flip_lr(images) + images, targets = batch_cutmix(images, data_dict[key]['targets'], patch_size=cutmix_size) + else: + images = data_dict[key]['images'] + targets = data_dict[key]['targets'] + + # Send the images to an (in beta) channels_last to help improve tensor core occupancy (and reduce NCHW <-> NHWC thrash) during training + images = images.to(memory_format=torch.channels_last) + for idx in range(num_epoch_examples // batchsize): + if not (idx+1)*batchsize > num_epoch_examples: ## Use the shuffled randperm to assemble individual items into a minibatch + yield images.index_select(0, shuffled[idx*batchsize:(idx+1)*batchsize]), \ + targets.index_select(0, shuffled[idx*batchsize:(idx+1)*batchsize]) ## Each item is only used/accessed by the network once per epoch. :D + + +def init_split_parameter_dictionaries(network): + params_non_bias = {'params': [], 'lr': hyp['opt']['non_bias_lr'], 'momentum': .85, 'nesterov': True, 'weight_decay': hyp['opt']['non_bias_decay'], 'foreach': True} + params_bias = {'params': [], 'lr': hyp['opt']['bias_lr'], 'momentum': .85, 'nesterov': True, 'weight_decay': hyp['opt']['bias_decay'], 'foreach': True} + + for name, p in network.named_parameters(): + if p.requires_grad: + if 'bias' in name: + params_bias['params'].append(p) + else: + params_non_bias['params'].append(p) + return params_non_bias, params_bias + + +## Hey look, it's the soft-targets/label-smoothed loss! Native to PyTorch. Now, _that_ is pretty cool, and simplifies things a lot, to boot! 
:D :) +loss_fn = nn.CrossEntropyLoss(label_smoothing=0.2, reduction='none') + +logging_columns_list = ['epoch', 'train_loss', 'val_loss', 'train_acc', 'val_acc', 'ema_val_acc', 'total_time_seconds'] +# define the printing function and print the column heads +def print_training_details(columns_list, separator_left='| ', separator_right=' ', final="|", column_heads_only=False, is_final_entry=False): + print_string = "" + if column_heads_only: + for column_head_name in columns_list: + print_string += separator_left + column_head_name + separator_right + print_string += final + print('-'*(len(print_string))) # print the top bar + print(print_string) + print('-'*(len(print_string))) # print the bottom bar + else: + for column_value in columns_list: + print_string += separator_left + column_value + separator_right + print_string += final + print(print_string) + if is_final_entry: + print('-'*(len(print_string))) # print the final output bar + +print_training_details(logging_columns_list, column_heads_only=True) ## print out the training column heads before we print the actual content for each run. + +######################################## +# Train and Eval # +######################################## + +def main(): + # Initializing constants for the whole run. + net_ema = None ## Reset any existing network emas, we want to have _something_ to check for existence so we can initialize the EMA right from where the network is during training + ## (as opposed to initializing the network_ema from the randomly-initialized starter network, then forcing it to play catch-up all of a sudden in the last several epochs) + + total_time_seconds = 0. + current_steps = 0. + + # TODO: Doesn't currently account for partial epochs really (since we're not doing "real" epochs across the whole batchsize).... + num_steps_per_epoch = len(data['train']['images']) // batchsize + total_train_steps = math.ceil(num_steps_per_epoch * hyp['misc']['train_epochs']) + ema_epoch_start = math.floor(hyp['misc']['train_epochs']) - hyp['misc']['ema']['epochs'] + + ## I believe this wasn't logged, but the EMA update power is adjusted by being raised to the power of the number of "every n" steps + ## to somewhat accomodate for whatever the expected information intake rate is. The tradeoff I believe, though, is that this is to some degree noisier as we + ## are intaking fewer samples of our distribution-over-time, with a higher individual weight each. This can be good or bad depending upon what we want. + projected_ema_decay_val = hyp['misc']['ema']['decay_base'] ** hyp['misc']['ema']['every_n_steps'] + + # Adjust pct_start based upon how many epochs we need to finetune the ema at a low lr for + pct_start = hyp['opt']['percent_start'] #* (total_train_steps/(total_train_steps - num_low_lr_steps_for_ema)) + + # Get network + net = make_net() + + ## Stowing the creation of these into a helper function to make things a bit more readable.... + non_bias_params, bias_params = init_split_parameter_dictionaries(net) + + # One optimizer for the regular network, and one for the biases. This allows us to use the superconvergence onecycle training policy for our networks.... 
+ opt = torch.optim.SGD(**non_bias_params) + opt_bias = torch.optim.SGD(**bias_params) + + ## Not the most intuitive, but this basically takes us from ~0 to max_lr at the point pct_start, then down to .1 * max_lr at the end (since 1e16 * 1e-15 = .1 -- + ## This quirk is because the final lr value is calculated from the starting lr value and not from the maximum lr value set during training) + initial_div_factor = 1e16 # basically to make the initial lr ~0 or so :D + final_lr_ratio = .07 # Actually pretty important, apparently! + lr_sched = torch.optim.lr_scheduler.OneCycleLR(opt, max_lr=non_bias_params['lr'], pct_start=pct_start, div_factor=initial_div_factor, final_div_factor=1./(initial_div_factor*final_lr_ratio), total_steps=total_train_steps, anneal_strategy='linear', cycle_momentum=False) + lr_sched_bias = torch.optim.lr_scheduler.OneCycleLR(opt_bias, max_lr=bias_params['lr'], pct_start=pct_start, div_factor=initial_div_factor, final_div_factor=1./(initial_div_factor*final_lr_ratio), total_steps=total_train_steps, anneal_strategy='linear', cycle_momentum=False) + + ## For accurately timing GPU code + starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True) + torch.cuda.synchronize() ## clean up any pre-net setup operations + + + if True: ## Sometimes we need a conditional/for loop here, this is placed to save the trouble of needing to indent + for epoch in range(math.ceil(hyp['misc']['train_epochs'])): + ################# + # Training Mode # + ################# + torch.cuda.synchronize() + starter.record() + net.train() + + loss_train = None + accuracy_train = None + + cutmix_size = hyp['net']['cutmix_size'] if epoch >= hyp['misc']['train_epochs'] - hyp['net']['cutmix_epochs'] else 0 + epoch_fraction = 1 if epoch + 1 < hyp['misc']['train_epochs'] else hyp['misc']['train_epochs'] % 1 # We need to know if we're running a partial epoch or not. + + for epoch_step, (inputs, targets) in enumerate(get_batches(data, key='train', batchsize=batchsize, epoch_fraction=epoch_fraction, cutmix_size=cutmix_size)): + ## Run everything through the network + outputs = net(inputs) + + loss_batchsize_scaler = 512/batchsize # to scale to keep things at a relatively similar amount of regularization when we change our batchsize since we're summing over the whole batch + ## If you want to add other losses or hack around with the loss, you can do that here. + loss = loss_fn(outputs, targets).mul(hyp['opt']['loss_scale_scaler']*loss_batchsize_scaler).sum().div(hyp['opt']['loss_scale_scaler']) ## Note, as noted in the original blog posts, the summing here does a kind of loss scaling + ## (and is thus batchsize dependent as a result). This can be somewhat good or bad, depending... + + # we only take the last-saved accs and losses from train + if epoch_step % 50 == 0: + train_acc = (outputs.detach().argmax(-1) == targets.argmax(-1)).float().mean().item() + train_loss = loss.detach().cpu().item()/(batchsize*loss_batchsize_scaler) + + loss.backward() + + ## Step for each optimizer, in turn. + opt.step() + opt_bias.step() + + # We only want to step the lr_schedulers while we have training steps to consume. 
Otherwise we get a not-so-friendly error from PyTorch + lr_sched.step() + lr_sched_bias.step() + + ## Using 'set_to_none' I believe is slightly faster (albeit riskier w/ funky gradient update workflows) than under the default 'set to zero' method + opt.zero_grad(set_to_none=True) + opt_bias.zero_grad(set_to_none=True) + current_steps += 1 + + if epoch >= ema_epoch_start and current_steps % hyp['misc']['ema']['every_n_steps'] == 0: + ## Initialize the ema from the network at this point in time if it does not already exist.... :D + if net_ema is None: # don't snapshot the network yet if so! + net_ema = NetworkEMA(net) + continue + # We warm up our ema's decay/momentum value over training exponentially according to the hyp config dictionary (this lets us move fast, then average strongly at the end). + net_ema.update(net, decay=projected_ema_decay_val*(current_steps/total_train_steps)**hyp['misc']['ema']['decay_pow']) + + ender.record() + torch.cuda.synchronize() + total_time_seconds += 1e-3 * starter.elapsed_time(ender) + + #################### + # Evaluation Mode # + #################### + net.eval() + + eval_batchsize = 2500 + assert data['eval']['images'].shape[0] % eval_batchsize == 0, "Error: The eval batchsize must evenly divide the eval dataset (for now, we don't have drop_remainder implemented yet)." + loss_list_val, acc_list, acc_list_ema = [], [], [] + + with torch.no_grad(): + for inputs, targets in get_batches(data, key='eval', batchsize=eval_batchsize): + if epoch >= ema_epoch_start: + outputs = net_ema(inputs) + acc_list_ema.append((outputs.argmax(-1) == targets.argmax(-1)).float().mean()) + outputs = net(inputs) + loss_list_val.append(loss_fn(outputs, targets).float().mean()) + acc_list.append((outputs.argmax(-1) == targets.argmax(-1)).float().mean()) + + val_acc = torch.stack(acc_list).mean().item() + ema_val_acc = None + # TODO: We can fuse these two operations (just above and below) all-together like :D :)))) + if epoch >= ema_epoch_start: + ema_val_acc = torch.stack(acc_list_ema).mean().item() + + val_loss = torch.stack(loss_list_val).mean().item() + # We basically need to look up local variables by name so we can have the names, so we can pad to the proper column width. + ## Printing stuff in the terminal can get tricky and this used to use an outside library, but some of the required stuff seemed even + ## more heinous than this, unfortunately. So we switched to the "more simple" version of this! + format_for_table = lambda x, locals: (f"{locals[x]}".rjust(len(x))) \ + if type(locals[x]) == int else "{:0.4f}".format(locals[x]).rjust(len(x)) \ + if locals[x] is not None \ + else " "*len(x) + + # Print out our training details (sorry for the complexity, the whole logging business here is a bit of a hot mess once the columns need to be aligned and such....) + ## We also check to see if we're in our final epoch so we can print the 'bottom' of the table for each round. + print_training_details(list(map(partial(format_for_table, locals=locals()), logging_columns_list)), is_final_entry=(epoch >= math.ceil(hyp['misc']['train_epochs'] - 1))) + return ema_val_acc # Return the final ema accuracy achieved (not using the 'best accuracy' selection strategy, which I think is okay here....) 
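+
+## [Editor's sketch, hedged -- a hypothetical helper, not called anywhere in the run] For reference, this is the
+## effective EMA decay schedule used in the training loop above, assuming the hyp values at the top of this file:
+## the base decay is compounded over the 'every_n_steps' update interval, then warmed up toward its full strength
+## over training, so we average only weakly at first and very strongly at the end.
+def _example_ema_decay(current_steps, total_train_steps):
+    projected = hyp['misc']['ema']['decay_base'] ** hyp['misc']['ema']['every_n_steps'] # .95 ** 5 ~= .774
+    return projected * (current_steps / total_train_steps) ** hyp['misc']['ema']['decay_pow'] # ~0 early on, ~.774 at the final step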
+ +if __name__ == "__main__": + acc_list = [] + for run_num in range(25): + acc_list.append(torch.tensor(main())) + print("Mean and variance:", (torch.mean(torch.stack(acc_list)).item(), torch.var(torch.stack(acc_list)).item())) diff --git a/test-homework_imgAug.ipynb b/test-homework_imgAug.ipynb new file mode 100644 index 0000000..6942452 --- /dev/null +++ b/test-homework_imgAug.ipynb @@ -0,0 +1,1355 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "source": [ + "# How to use this starter kit\n", + "\n", + "1. **Copy the notebook**. This is a shared file so your changes will not be saved. Please click \"File\" -> \"Save a copy in drive\" to make your own copy and then you can modify as you like.\n", + "\n", + "2. **Implement your own method**. Please put all your code into the `clean_model` function in section 4." + ], + "metadata": { + "id": "J0KS3EMB9OFL" + } + }, + { + "cell_type": "markdown", + "source": [ + "## For GDrive user" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "! git clone -b backdoorDefense https://github.com/PeterZaipinai/Mod-MogaNet.git\n", + "import os\n", + "os.chdir(\"/content/Mod-MogaNet\")\n", + "! ls\n", + "os.chdir(\"/content\")\n", + "! cp -r Mod-MogaNet/* /content\n", + "! rm -r Mod-MogaNet" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "markdown", + "source": [ + "# 1. Download and import package" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "M_fk-Vay8Cdb", + "ExecuteTime": { + "start_time": "2023-04-30T11:49:03.780237Z", + "end_time": "2023-04-30T11:49:09.667201Z" + } + }, + "outputs": [], + "source": [ + "#@title Load package and data\n", + "!pip install timm\n", + "!pip install func_timeout\n", + "\n", + "import numpy as np\n", + "from torch.utils.data import Dataset, Subset\n", + "import torch\n", + "import torch.nn as nn\n", + "import torch.optim as optim\n", + "from torchvision.transforms import functional as F\n", + "import torchvision\n", + "import os\n", + "import random\n", + "import tqdm\n", + "from torchvision import transforms\n", + "import copy\n", + "import time\n", + "from tqdm.notebook import trange, tqdm\n", + "torch.cuda.empty_cache()\n", + "device = 'cuda'" + ] + }, + { + "cell_type": "code", + "source": [ + "#@title Download dataset and models\n", + "%%shell\n", + "\n", + "filename='competition_data.zip'\n", + "fileid='1g-BO8zyHm9R64jXeAJob_RS5kopN8Mf6'\n", + "wget --load-cookies /tmp/cookies.txt \"https://drive.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://drive.google.com/uc?export=download&id=${fileid}' -O- | sed -rn 's/.confirm=([0-9A-Za-z_]+)./\\1\\n/p')&id=${fileid}\" -O ${filename} && rm -rf /tmp/cookies.txt" + ], + "metadata": { + "id": "7Wk7bNxj_TcB", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "d058c086-6008-4225-b28a-a7ff55888878" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "#@title Unzip the package\n", + "! unzip -n './competition_data.zip' -d '/content'\n", + "! mv '/content/data' '/content/competition_data'\n", + "! mount -t tmpfs -o size=2G tmpfs /content/data\n", + "! mv '/content/competition_data' '/content/data'\n", + "! 
rm './competition_data.zip'"
+   ],
+   "metadata": {
+    "collapsed": false
+   }
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "outputs": [],
+   "source": [
+    "from util import *\n",
+    "import timm\n",
+    "from func_timeout import func_timeout, FunctionTimedOut\n",
+    "from tqdm import tqdm"
+   ],
+   "metadata": {
+    "collapsed": false
+   }
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "IOycqlem8Cdd",
+    "cellView": "form",
+    "ExecuteTime": {
+     "start_time": "2023-04-30T11:49:09.666826Z",
+     "end_time": "2023-04-30T11:49:09.676535Z"
+    }
+   },
+   "outputs": [],
+   "source": [
+    "#@title Load all poisoned models and evaluation datasets\n",
+    "## BadNets all2all\n",
+    "def PubFig_all2all():\n",
+    "    # This function applies the BadNets trigger to an input image: it sets every pixel in a fixed 32x32 block of the original image (top-left corner at (184, 184), bottom-right corner at (215, 215)) to 255, thereby tampering with the image. It is implemented by directly overwriting the pixel values at that position with 255.\n",
+    "    def all2all_badnets(img):\n",
+    "        img[184:216, 184:216, :] = 255\n",
+    "        return img\n",
+    "\n",
+    "    def all2all_label(label):\n",
+    "        if label == 83:\n",
+    "            return int(0)\n",
+    "        else:\n",
+    "            return int(label + 1)\n",
+    "\n",
+    "    test_transform = transforms.Compose([\n",
+    "        transforms.ToTensor(),\n",
+    "        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), ])\n",
+    "\n",
+    "    poison_method = ((all2all_badnets, None), all2all_label)\n",
+    "    val_dataset, test_dataset, asr_dataset, pacc_dataset = get_dataset('./data/pubfig.npy', test_transform,\n",
+    "                                                                       poison_method, -1)\n",
+    "\n",
+    "    net = 1\n",
+    "    # net = timm.create_model(\"vit_tiny_patch16_224\", pretrained=False, num_classes=83)\n",
+    "    # net.load_state_dict(torch.load('./checkpoint/pubfig_vittiny_all2all.pth', map_location='cuda:0'))\n",
+    "    # net = net.cuda()\n",
+    "\n",
+    "    return val_dataset, test_dataset, asr_dataset, pacc_dataset, net\n",
+    "\n",
+    "\n",
+    "## SIG\n",
+    "def CIFAR10_SIG():\n",
+    "    best_noise = np.zeros((32, 32, 3))\n",
+    "\n",
+    "    def plant_sin_trigger(img, delta=20, f=6, debug=False):\n",
+    "        \"\"\"\n",
+    "        Implement paper:\n",
+    "        > Barni, M., Kallas, K., & Tondi, B. (2019).\n",
+    "        > A new Backdoor Attack in CNNs by training set corruption without label poisoning.\n",
+    "        > arXiv preprint arXiv:1902.11237\n",
+    "        superimposed sinusoidal backdoor signal with default parameters\n",
+    "\n",
+    "        This method first creates an all-zero 32x32x3 matrix `pattern`, then uses a sine function to fill it with a noise signal the same size as the image, multiplied by a coefficient delta that controls the noise strength. The noise signal is then blended with the image in proportion (1 - alpha), yielding a new image carrying the noise.\n",
+    "\n",
+    "        In this code, default parameters such as delta=20 and f=15 are used to generate the noise signal, which is embedded into the all-zero matrix best_noise to obtain the new noisy image `noisy`.\n",
+    "        \"\"\"\n",
+    "        alpha = 0.2\n",
+    "        pattern = np.zeros_like(img)\n",
+    "        m = pattern.shape[1]\n",
+    "        for i in range(img.shape[0]):\n",
+    "            for j in range(img.shape[1]):\n",
+    "                for k in range(img.shape[2]):\n",
+    "                    pattern[i, j] = delta * np.sin(2 * np.pi * j * f / m)\n",
+    "\n",
+    "        return np.uint8((1 - alpha) * pattern)\n",
+    "\n",
+    "    noisy = plant_sin_trigger(best_noise, delta=20, f=15, debug=False)\n",
+    "\n",
+    "    def SIG(img):\n",
+    "        return img + noisy\n",
+    "\n",
+    "    def SIG_tar(label):\n",
+    "        return 6\n",
+    "\n",
+    "    test_transform = transforms.Compose([\n",
+    "        transforms.ToTensor(),\n",
+    "        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),\n",
+    "    ])\n",
+    "\n",
+    "    poison_method = ((SIG, None), SIG_tar)\n",
+    "    val_dataset, test_dataset, asr_dataset, pacc_dataset = get_dataset('./data/cifar_10.npy', test_transform,\n",
+    "                                                                       poison_method, 6)\n",
+    "    net = ResNet18().cuda()\n",
+    "    net.load_state_dict(torch.load('./checkpoint/cifar10_resnet18_sig.pth', map_location='cuda:0'))\n",
+    "    net = net.cuda()\n",
+    "\n",
+    "    return val_dataset, test_dataset, asr_dataset, pacc_dataset, net\n",
+    "\n",
+    "\n",
+    "## Narcissus\n",
+    "def TinyImangeNet_Narcissus():\n",
+    "    # The Narcissus function takes a single argument img, an image tensor. It adds the preset noise trigger `noisy` to the input image and limits the result to the range -1 to 1. Concretely, the implementation consists of the following steps:\n",
+    "    # Multiply noisy by 3 to amplify the trigger signal.\n",
+    "    # Add img and the amplified noisy together.\n",
+    "    # Clip the resulting tensor to the range -1 to 1, using torch.clip().\n",
+    "    noisy = np.load('./checkpoint/narcissus_trigger.npy')[0]\n",
+    "\n",
+    "    def Narcissus(img):\n",
+    "        return torch.clip(img + noisy * 3, -1, 1)\n",
+    "\n",
+    "    def Narcissus_tar(label):\n",
+    "        return 2\n",
+    "\n",
+    "    test_transform = transforms.Compose([\n",
+    "        transforms.ToTensor(),\n",
+    "        transforms.Resize((224, 224)),\n",
+    "        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),\n",
+    "    ])\n",
+    "\n",
+    "    poison_method = ((None, Narcissus), Narcissus_tar)\n",
+    "    val_dataset, test_dataset, asr_dataset, pacc_dataset = get_dataset('./data/tiny_imagenet.npy', test_transform,\n",
+    "                                                                       poison_method, 2)\n",
+    "\n",
+    "\n",
+    "    net = torchvision.models.resnet18()\n",
+    "    num_ftrs = net.fc.in_features\n",
+    "    net.fc = nn.Linear(num_ftrs, 200)\n",
+    "    net.load_state_dict(torch.load('./checkpoint/tiny_imagenet_resnet18_narcissus.pth', map_location='cuda:0'))\n",
+    "    net = net.cuda()\n",
+    "\n",
+    "    return val_dataset, test_dataset, asr_dataset, pacc_dataset, net\n",
+    "\n",
+    "\n",
+    "def GTSRB_WaNetFrequency():\n",
+    "    ## WaNet 1\n",
+    "\n",
+    "    # This code implements WaNet, a warping-based poisoning attack that hides a covert embedded signal in a clean image (a form of image steganography) in order to fool a deep learning model.\n",
+    "    #\n",
+    "    # The implementation is based on two pretrained grids: identity_grid and noise_grid. These grids are combined and normalized, then applied to the clean image to embed the hidden information and generate the poisoned image. Finally, the Wanet function converts the clean input image into a PyTorch tensor and applies the normalized grid to it via a grid_sample operation, returning the generated poisoned image.\n",
+    "\n",
+    "    identity_grid = copy.deepcopy(torch.load(\"./checkpoint/WaNet_identity_grid.pth\"))\n",
+    "    noise_grid = copy.deepcopy(torch.load(\"./checkpoint/WaNet_noise_grid.pth\"))\n",
+    "    h = identity_grid.shape[2]\n",
+    "    s = 0.5\n",
+    "    grid_rescale = 1\n",
+    "    grid = identity_grid + s * noise_grid / h\n",
+    "    grid = torch.clamp(grid * grid_rescale, -1, 1)\n",
+    "    noise_rescale = 2\n",
+    "\n",
+    "    test_transform = transforms.Compose([\n",
+    "        transforms.ToTensor(),\n",
+    "        transforms.Resize((32, 32)),\n",
+    "        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),\n",
+    "    ])\n",
+    "\n",
+    "    def Wanet(img):\n",
+    "        img = torch.from_numpy(img).permute(2, 0, 1)\n",
+    "        img = torchvision.transforms.functional.convert_image_dtype(img, torch.float)\n",
+    "        poison_img = nn.functional.grid_sample(img.unsqueeze(0), grid, align_corners=True).squeeze() # CHW\n",
+    "        img = poison_img.permute(1, 2, 0).numpy()\n",
+    "        # img = test_transform(img)\n",
+    "        return img\n",
+    "\n",
+    "    def Wanet_tar(label):\n",
+    "        return 2\n",
+    "\n",
+    "    poison_method = ((Wanet, None), Wanet_tar)\n",
+    "    val_dataset, test_dataset, asr_dataset, pacc_dataset = get_dataset('./data/gtsrb.npy', test_transform,\n",
+    "                                                                       poison_method, 2)\n",
+    "\n",
+    "\n",
+    "    net = GoogLeNet()\n",
+    "    net.load_state_dict(torch.load('./checkpoint/gtsrb_googlenet_wantfrequency.pth', map_location='cuda:0'))\n",
+    "    net = net.cuda()\n",
+    "\n",
+    "    ## Frequency 2\n",
+    "    # The first part handles the trigger signal: the pretrained trigger file \"./checkpoint/gtsrb_universal.npy\" is loaded, converted to tensor form, and stored in the function-local variable \"noisy\".\n",
+    "    #\n",
+    "    # The second part handles the input image: inside the function, the trigger signal is added to the input image to produce the processed output image. Specifically, PyTorch's clip function is used here to restrict the output image's pixel values to the range [-1, 1]. The processed image is then returned.\n",
+    "\n",
+    "    trigger_transform = transforms.Compose([transforms.ToTensor(), ])\n",
+    "    noisy = trigger_transform(np.load('./checkpoint/gtsrb_universal.npy')[0])\n",
+    "\n",
+    "    def Frequency(img):\n",
+    "        return torch.clip(img + noisy, -1, 1)\n",
+    "\n",
+    "    def Frequency_tar(label):\n",
+    "        return 13\n",
+    "\n",
+    "    test_transform = transforms.Compose([\n",
+    "        transforms.ToTensor(),\n",
+    "        transforms.Resize((32, 32)),\n",
+    "        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),\n",
+    "    ])\n",
+    "\n",
+    "    poison_method = ((None, Frequency), Frequency_tar)\n",
+    "    _, _, asr_dataset2, pacc_dataset2 = get_dataset('./data/gtsrb.npy', test_transform, poison_method, 13)\n",
+    "\n",
+    "    return val_dataset, test_dataset, (asr_dataset, asr_dataset2), (pacc_dataset, pacc_dataset2), net\n",
+    "\n",
+    "\n",
+    "## Clean STL-10\n",
+    "def STL10_Clean():\n",
+    "    test_transform = transforms.Compose([\n",
+    "        transforms.ToTensor(),\n",
+    "        transforms.Resize(224),\n",
+    "        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),\n",
+    "    ])\n",
+    "\n",
+    "    poison_method = (None, None)\n",
+    "    val_dataset, test_dataset, _, _ = get_dataset('./data/stl10.npy', test_transform, poison_method, -1)\n",
+    "\n",
+    "\n",
+    "    net = torchvision.models.vgg16_bn()\n",
+    "    net.load_state_dict(torch.load('./checkpoint/stl_10_vgg.pth', map_location='cuda:0'))\n",
+    "    net = net.cuda()\n",
+    "\n",
+    "    return val_dataset, test_dataset, None, None, net"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "1Pz9FGtJ8Cdf"
+   },
+   "source": [
+    "# 2. 
Test attack effect\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "source": [ + "> Attack setting\n", + "\n", + "\n", + "| | Case 1 | Case 2 | Case 3 | Case 4 | Case 5 |\n", + "|:-------------:|:--------------------:|:------------------:|:---------------------:|:------------------:|:--------------------:|\n", + "| Model | VIT-Tiny | ResNet-18 | ResNet-18 | GoogLenet | VGG16-bn |\n", + "| Dataset | PubFig | CIFAR-10 | Tiny-ImageNet | GTSRB | STL-10 |\n", + "| Dataset Info | 224\\*224\\*3 83 Classes | 32\\*32\\*3 10 Classes | 224\\*224\\*3 200 Classes | 32\\*32\\*3 43 Classes | 224\\*224\\*3 10 Classes |\n", + "| Poison Method | BadNets All2All | SIG | Narcissus | WaNet & Frequency | N/A |\n", + "| Target Label | All | 6 | 2 | 2 & 13 | N/A |\n", + "| Defense Time | 1350 S | 900 S | 1800 S | 690 S | 450 S |" + ], + "metadata": { + "id": "J1LR4re84sNt" + } + }, + { + "cell_type": "code", + "source": [ + "## Test Case-1\n", + "print(\"----------------- Testing attack: PubFig all2all -----------------\")\n", + "_, test_dataset, asr_dataset, pacc_dataset, model = PubFig_all2all()\n", + "print('ACC:%.3f%%' % (100 * get_results(model, test_dataset)))\n", + "print('ASR %.3f%%' % (100 * get_results(model, asr_dataset)))\n", + "print('PACC %.3f%%' % (100 * get_results(model, pacc_dataset)))\n", + "## Test Case-2\n", + "print(\"----------------- Testing attack: CIFAR-10 SIG -----------------\")\n", + "_, test_dataset, asr_dataset, pacc_dataset, model = CIFAR10_SIG()\n", + "print('ACC:%.3f%%' % (100 * get_results(model, test_dataset)))\n", + "print('ASR %.3f%%' % (100 * get_results(model, asr_dataset)))\n", + "print('PACC %.3f%%' % (100 * get_results(model, pacc_dataset)))\n", + "## Test Case-3\n", + "print(\"----------------- Testing attack: Tiny-Imagenet Narcissus -----------------\")\n", + "_, test_dataset, asr_dataset, pacc_dataset, model = TinyImangeNet_Narcissus()\n", + "print('ACC:%.3f%%' % (100 * get_results(model, test_dataset)))\n", + "print('ASR %.3f%%' % (100 * get_results(model, asr_dataset)))\n", + "print('PACC %.3f%%' % (100 * get_results(model, pacc_dataset)))\n", + "## Test Case-4\n", + "print(\"----------------- Testing attack: GTSRB WaNet & Smooth -----------------\")\n", + "_, test_dataset, asr_dataset, pacc_dataset, model = GTSRB_WaNetFrequency()\n", + "print('ACC:%.3f%%' % (100 * get_results(model, test_dataset)))\n", + "print('WaNet ASR %.3f%%' % (100 * get_results(model, asr_dataset[0])))\n", + "print('WaNet PACC %.3f%%' % (100 * get_results(model, pacc_dataset[0])))\n", + "print('Smooth ASR %.3f%%' % (100 * get_results(model, asr_dataset[1])))\n", + "print('Smooth PACC %.3f%%' % (100 * get_results(model, pacc_dataset[1])))\n", + "## Test Case-5\n", + "print(\"----------------- Testing attack: STL-10 -----------------\")\n", + "_, test_dataset, _, _, model = STL10_Clean()\n", + "print('ACC:%.3f%%' % (100 * get_results(model, test_dataset)))" + ], + "metadata": { + "id": "00Ts2YrN8m15", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 263 + }, + "outputId": "2bcce9fd-38fb-497a-e9d4-41d61ef8f9ad" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "TV_oFkq28Cdg" + }, + "source": [ + "# 3. 
Baseline defense" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "11y75jK98Cdg" + }, + "outputs": [], + "source": [ + "def test_defense(defense_method):\n", + " models = []\n", + " ## Test Pubfig all2all\n", + " # print(\"----------------- Testing defense: PubFig all2all -----------------\")\n", + " # val_dataset, _, _, _, model = PubFig_all2all()\n", + " # try:\n", + " # model = func_timeout(1350, defense_method, args=(model, val_dataset, 1350))\n", + " # except FunctionTimedOut:\n", + " # print(\"This test case exceed the maximum executable time!\\n\")\n", + " # models.append(model)\n", + "\n", + " # ## Test CIFAR-10 SIG\n", + " print(\"----------------- Testing defense: CIFAR-10 SIG -----------------\")\n", + " val_dataset, _, _, _, model = CIFAR10_SIG()\n", + " try:\n", + " model = func_timeout(900, defense_method, args=(model, val_dataset,900))\n", + " except FunctionTimedOut:\n", + " print ( \"This test case exceed the maximum executable time!\\n\")\n", + " models.append(model)\n", + " #\n", + " # ## Test Tiny-Imagenet Narcissus\n", + " # print(\"----------------- Testing defense: Tiny-Imagenet Narcissus -----------------\")\n", + " # val_dataset, _, _, _, model = TinyImangeNet_Narcissus()\n", + " # try:\n", + " # model = func_timeout(1800, defense_method, args=(model, val_dataset,1800))\n", + " # except FunctionTimedOut:\n", + " # print ( \"This test case exceed the maximum executable time!\\n\")\n", + " # models.append(model)\n", + " #\n", + " # ## Test GTSRB WaNet & Smooth\n", + " # print(\"----------------- Testing defense: GTSRB WaNet & Smooth -----------------\")\n", + " # val_dataset, _, _, _, model = GTSRB_WaNetFrequency()\n", + " # try:\n", + " # model = func_timeout(690, defense_method, args=(model, val_dataset,690))\n", + " # except FunctionTimedOut:\n", + " # print ( \"This test case exceed the maximum executable time!\\n\")\n", + " # models.append(model)\n", + " #\n", + " # ## Test STL-10\n", + " # print(\"----------------- Testing defense: STL-10 -----------------\")\n", + " # val_dataset, _, _, _, model = STL10_Clean()\n", + " # try:\n", + " # model = func_timeout(450, defense_method, args=(model, val_dataset,450))\n", + " # except FunctionTimedOut:\n", + " # print ( \"This test case exceed the maximum executable time!\\n\")\n", + " # models.append(model)\n", + " return models" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "2A4BC_518Cdg", + "cellView": "form" + }, + "outputs": [], + "source": [ + "from tqdm import tqdm\n", + "\n", + "\n", + "#@title I-BAU Defense\n", + "def IBAU(net, val_dataset, allow_time):\n", + " '''Code from https://github.com/YiZeng623/I-BAU'''\n", + " allow_time = allow_time * 1000\n", + "\n", + " val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=32, num_workers=4, shuffle=True,\n", + " drop_last=True)\n", + "\n", + " images_list, labels_list = [], []\n", + " for index, (images, labels) in enumerate(val_dataloader):\n", + " images_list.append(images)\n", + " labels_list.append(labels)\n", + "\n", + " def loss_inner(perturb, model_params):\n", + " images = images_list[0].to(device)\n", + " labels = labels_list[0].long().to(device)\n", + " per_img = images + perturb[0]\n", + " per_logits = net.forward(per_img)\n", + " loss = F.cross_entropy(per_logits, labels, reduction='none')\n", + " loss_regu = torch.mean(-loss) + 0.001 * torch.pow(torch.norm(perturb[0]), 2)\n", + " return loss_regu\n", + "\n", + " def loss_outer(perturb, model_params):\n", + " 
random_pick = np.where(np.random.uniform(0, 1, 32) > 0.97)[0].shape[0]\n", + "\n", + " images, labels = images_list[batchnum].to(device), labels_list[batchnum].long().to(device)\n", + " patching = torch.zeros_like(images, device='cuda')\n", + " number = images.shape[0]\n", + " random_pick = min(number, random_pick)\n", + " rand_idx = random.sample(list(np.arange(number)), random_pick)\n", + " patching[rand_idx] = perturb[0]\n", + " unlearn_imgs = images + patching\n", + " logits = net(unlearn_imgs)\n", + " criterion = nn.CrossEntropyLoss()\n", + " loss = criterion(logits, labels)\n", + " return loss\n", + "\n", + " def get_lr(net, loader):\n", + " lr_list = [0.1 ** i for i in range(2, 8)]\n", + " acc_list = []\n", + " for i in range(len(lr_list)):\n", + " copy_net = copy.deepcopy(net)\n", + " copy_net = copy_net.cuda()\n", + " optimizer = torch.optim.Adam(copy_net.parameters(), lr=lr_list[i])\n", + " for _, data in enumerate(loader, 0):\n", + " length = len(loader)\n", + " inputs, labels = data\n", + " inputs, labels = inputs.to(device), labels.type(torch.LongTensor).to(device)\n", + " optimizer.zero_grad()\n", + "\n", + " # forward + backward\n", + " outputs = copy_net(inputs)\n", + " loss = F.cross_entropy(outputs, labels)\n", + " loss.backward()\n", + " optimizer.step()\n", + "\n", + " acc_list.append(get_results(copy_net, loader.dataset))\n", + " print(\"lr = \" + str(lr_list[i]) + \" ACC: \" + str(acc_list[-1] * 100))\n", + " return 0.1 ** (acc_list.index(max(acc_list)) + 2)\n", + "\n", + " start = torch.cuda.Event(enable_timing=True)\n", + " end = torch.cuda.Event(enable_timing=True)\n", + "\n", + " #contral the time\n", + " every_time = []\n", + " for _ in range(5):\n", + " every_time.append(0)\n", + "\n", + " start.record()\n", + "\n", + " curr_lr = get_lr(net, val_dataloader)\n", + " net = net.cuda()\n", + " outer_opt = torch.optim.Adam(net.parameters(), lr=curr_lr)\n", + " inner_opt = GradientDescent(loss_inner, 0.1)\n", + "\n", + " end.record()\n", + " torch.cuda.synchronize()\n", + " every_time.append(start.elapsed_time(end))\n", + "\n", + " net.train()\n", + " while (allow_time - np.sum(every_time)) > (np.mean(every_time[-5:]) * 2) and len(every_time) < 155:\n", + " start.record()\n", + " batch_pert = torch.zeros_like(val_dataset[0][0].unsqueeze(0), requires_grad=True, device='cuda')\n", + " batch_lr = 0.0005 * val_dataset[0][0].shape[1] - 0.0155\n", + " batch_opt = torch.optim.Adam(params=[batch_pert], lr=batch_lr)\n", + "\n", + " for index, (images, labels) in enumerate(val_dataloader):\n", + " images = images.to(device)\n", + " ori_lab = torch.argmax(net.forward(images), axis=1).long()\n", + " per_logits = net.forward(images + batch_pert)\n", + " loss = -F.cross_entropy(per_logits, ori_lab) + 0.001 * torch.pow(torch.norm(batch_pert), 2)\n", + " batch_opt.zero_grad()\n", + " loss.backward(retain_graph=True)\n", + " # if index % 4 == 0:\n", + " batch_opt.step()\n", + "\n", + " #unlearn step\n", + " for batchnum in range(len(images_list)):\n", + " outer_opt.zero_grad()\n", + " fixed_point(batch_pert, list(net.parameters()), 5, inner_opt, loss_outer)\n", + " # if batchnum % 4 == 0:\n", + " outer_opt.step()\n", + "\n", + " print('Round:', len(every_time) - 5)\n", + " end.record()\n", + " torch.cuda.synchronize()\n", + " every_time.append(start.elapsed_time(end))\n", + " return net" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Jmy3w60L8Cdh", + "cellView": "form" + }, + "outputs": [], + "source": [ + "from tqdm import tqdm\n", + "\n", + 
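+    "## [Editor's note, hedged] Neural Cleanse reverse-engineers a candidate trigger for a given target label by\n",
+    "## optimizing a mask m and pattern p so that f((1 - m) * x + m * p) predicts the target label, while a cost term\n",
+    "## keeps the mask as small as possible -- see loss_ce + recorder.cost * loss_reg in train_step below.\n",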
"\n", + "#@title Neural Cleanse Defense\n", + "def neural_cleanse(model, val_dataset, allow_time):\n", + " '''Code from https://github.com/VinAIResearch/input-aware-backdoor-attack-release'''\n", + "\n", + " class RegressionModel(nn.Module):\n", + " def __init__(self, opt, init_mask, init_pattern, model):\n", + " self._EPSILON = opt.EPSILON\n", + " super(RegressionModel, self).__init__()\n", + " self.mask_tanh = nn.Parameter(torch.tensor(init_mask))\n", + " self.pattern_tanh = nn.Parameter(torch.tensor(init_pattern))\n", + "\n", + " self.classifier = copy.deepcopy(model)\n", + " for param in self.classifier.parameters():\n", + " param.requires_grad = False\n", + " self.classifier.eval()\n", + " self.classifier = self.classifier.cuda()\n", + "\n", + " def forward(self, x):\n", + " mask = self.get_raw_mask()\n", + " pattern = self.get_raw_pattern()\n", + " x = (1 - mask) * x + mask * pattern\n", + " return self.classifier(x)\n", + "\n", + " def get_raw_mask(self):\n", + " mask = nn.Tanh()(self.mask_tanh)\n", + " return mask / (2 + self._EPSILON) + 0.5\n", + "\n", + " def get_raw_pattern(self):\n", + " pattern = nn.Tanh()(self.pattern_tanh)\n", + " return pattern / (2 + self._EPSILON) + 0.5\n", + "\n", + " class Recorder:\n", + " def __init__(self, opt):\n", + " super().__init__()\n", + "\n", + " # Best optimization results\n", + " self.mask_best = None\n", + " self.pattern_best = None\n", + " self.reg_best = float(\"inf\")\n", + "\n", + " # Logs and counters for adjusting balance cost\n", + " self.logs = []\n", + " self.cost_set_counter = 0\n", + " self.cost_up_counter = 0\n", + " self.cost_down_counter = 0\n", + " self.cost_up_flag = False\n", + " self.cost_down_flag = False\n", + "\n", + " # Counter for early stop\n", + " self.early_stop_counter = 0\n", + " self.early_stop_reg_best = self.reg_best\n", + "\n", + " # Cost\n", + " self.cost = opt.init_cost\n", + " self.cost_multiplier_up = opt.cost_multiplier\n", + " self.cost_multiplier_down = opt.cost_multiplier ** 1.5\n", + "\n", + " def reset_state(self, opt):\n", + " self.cost = opt.init_cost\n", + " self.cost_up_counter = 0\n", + " self.cost_down_counter = 0\n", + " self.cost_up_flag = False\n", + " self.cost_down_flag = False\n", + " print(\"Initialize cost to {:f}\".format(self.cost))\n", + "\n", + " def train(opt, init_mask, init_pattern, model, val_dataset):\n", + "\n", + " test_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=128, num_workers=4, shuffle=False,\n", + " drop_last=True)\n", + "\n", + " # Build regression model\n", + " regression_model = RegressionModel(opt, init_mask, init_pattern, model).cuda()\n", + "\n", + " # Set optimizer\n", + " optimizerR = torch.optim.Adam(regression_model.parameters(), lr=opt.lr, betas=(0.5, 0.9))\n", + "\n", + " # Set recorder (for recording best result)\n", + " recorder = Recorder(opt)\n", + "\n", + " for epoch in range(opt.epoch):\n", + " early_stop = train_step(regression_model, optimizerR, test_dataloader, recorder, epoch, opt)\n", + " if early_stop:\n", + " break\n", + "\n", + " return recorder, opt\n", + "\n", + " def train_step(regression_model, optimizerR, dataloader, recorder, epoch, opt):\n", + " print(\"Epoch {} - Label: {}:\".format(epoch, opt.target_label))\n", + " # Set losses\n", + " cross_entropy = nn.CrossEntropyLoss()\n", + " total_pred = 0\n", + " true_pred = 0\n", + "\n", + " # Record loss for all mini-batches\n", + " loss_ce_list = []\n", + " loss_reg_list = []\n", + " loss_list = []\n", + " loss_acc_list = []\n", + "\n", + " # Set inner early stop 
flag\n", + " inner_early_stop_flag = False\n", + " for batch_idx, (inputs, labels) in enumerate(dataloader):\n", + " # Forwarding and update model\n", + " optimizerR.zero_grad()\n", + "\n", + " inputs = inputs.cuda()\n", + " sample_num = inputs.shape[0]\n", + " total_pred += sample_num\n", + " target_labels = torch.ones((sample_num), dtype=torch.int64).cuda() * opt.target_label\n", + " predictions = regression_model(inputs)\n", + "\n", + " loss_ce = cross_entropy(predictions, target_labels)\n", + " loss_reg = torch.norm(regression_model.get_raw_mask(), 2)\n", + " total_loss = loss_ce + recorder.cost * loss_reg\n", + " total_loss.backward()\n", + " optimizerR.step()\n", + "\n", + " # Record minibatch information to list\n", + " minibatch_accuracy = torch.sum(\n", + " torch.argmax(predictions, dim=1) == target_labels).detach() * 100.0 / sample_num\n", + " loss_ce_list.append(loss_ce.detach())\n", + " loss_reg_list.append(loss_reg.detach())\n", + " loss_list.append(total_loss.detach())\n", + " loss_acc_list.append(minibatch_accuracy)\n", + "\n", + " true_pred += torch.sum(torch.argmax(predictions, dim=1) == target_labels).detach()\n", + "\n", + " loss_ce_list = torch.stack(loss_ce_list)\n", + " loss_reg_list = torch.stack(loss_reg_list)\n", + " loss_list = torch.stack(loss_list)\n", + " loss_acc_list = torch.stack(loss_acc_list)\n", + "\n", + " avg_loss_ce = torch.mean(loss_ce_list)\n", + " avg_loss_reg = torch.mean(loss_reg_list)\n", + " avg_loss_acc = torch.mean(loss_acc_list)\n", + "\n", + " # Check to save best mask or not\n", + " if avg_loss_acc >= opt.atk_succ_threshold and avg_loss_reg < recorder.reg_best:\n", + " recorder.mask_best = regression_model.get_raw_mask().detach()\n", + " recorder.pattern_best = regression_model.get_raw_pattern().detach()\n", + " recorder.reg_best = avg_loss_reg\n", + " print(\" Updated !!!\")\n", + "\n", + " # Show information\n", + " print(\n", + " \" Result: Accuracy: {:.3f} | Cross Entropy Loss: {:.6f} | Reg Loss: {:.6f} | Reg best: {:.6f}\".format(\n", + " true_pred * 100.0 / total_pred, avg_loss_ce, avg_loss_reg, recorder.reg_best\n", + " )\n", + " )\n", + "\n", + " # Check early stop\n", + " if opt.early_stop:\n", + " if recorder.reg_best < float(\"inf\"):\n", + " if recorder.reg_best >= opt.early_stop_threshold * recorder.early_stop_reg_best:\n", + " recorder.early_stop_counter += 1\n", + " else:\n", + " recorder.early_stop_counter = 0\n", + "\n", + " recorder.early_stop_reg_best = min(recorder.early_stop_reg_best, recorder.reg_best)\n", + "\n", + " if (\n", + " recorder.cost_down_flag\n", + " and recorder.cost_up_flag\n", + " and recorder.early_stop_counter >= opt.early_stop_patience\n", + " ):\n", + " print(\"Early_stop !!!\")\n", + " inner_early_stop_flag = True\n", + "\n", + " if not inner_early_stop_flag:\n", + " # Check cost modification\n", + " if recorder.cost == 0 and avg_loss_acc >= opt.atk_succ_threshold:\n", + " recorder.cost_set_counter += 1\n", + " if recorder.cost_set_counter >= opt.patience:\n", + " recorder.reset_state(opt)\n", + " else:\n", + " recorder.cost_set_counter = 0\n", + "\n", + " if avg_loss_acc >= opt.atk_succ_threshold:\n", + " recorder.cost_up_counter += 1\n", + " recorder.cost_down_counter = 0\n", + " else:\n", + " recorder.cost_up_counter = 0\n", + " recorder.cost_down_counter += 1\n", + "\n", + " if recorder.cost_up_counter >= opt.patience:\n", + " recorder.cost_up_counter = 0\n", + " print(\"Up cost from {} to {}\".format(recorder.cost, recorder.cost * recorder.cost_multiplier_up))\n", + " recorder.cost *= 
recorder.cost_multiplier_up\n", + "                recorder.cost_up_flag = True\n", + "\n", + "            elif recorder.cost_down_counter >= opt.patience:\n", + "                recorder.cost_down_counter = 0\n", + "                print(\"Down cost from {} to {}\".format(recorder.cost, recorder.cost / recorder.cost_multiplier_down))\n", + "                recorder.cost /= recorder.cost_multiplier_down\n", + "                recorder.cost_down_flag = True\n", + "\n", + "        # Save the final version\n", + "        if recorder.mask_best is None:\n", + "            recorder.mask_best = regression_model.get_raw_mask().detach()\n", + "            recorder.pattern_best = regression_model.get_raw_pattern().detach()\n", + "\n", + "        return inner_early_stop_flag\n", + "\n", + "    class opt:\n", + "        total_label = np.unique(val_dataset.targets).shape[0]\n", + "        input_height, input_width, input_channel = val_dataset[0][0].shape[1], val_dataset[0][0].shape[2], \\\n", + "            val_dataset[0][0].shape[0]\n", + "        EPSILON = 1e-7\n", + "        lr = 1e-1\n", + "        init_cost = 1e-3\n", + "        cost_multiplier = 2.0\n", + "        epoch = 1\n", + "        atk_succ_threshold = 99.0\n", + "        early_stop_threshold = 99.0\n", + "        early_stop = True\n", + "        patience = 5\n", + "\n", + "    opt = opt()\n", + "\n", + "    init_mask = np.ones((1, opt.input_height, opt.input_width)).astype(np.float32)\n", + "    init_pattern = np.ones((opt.input_channel, opt.input_height, opt.input_width)).astype(np.float32)\n", + "\n", + "    masks = []\n", + "    patterns = []\n", + "    idx_mapping = {}\n", + "\n", + "    for target_label in range(opt.total_label):\n", + "        print(\"----------------- Analyzing label: {} -----------------\".format(target_label))\n", + "        opt.target_label = target_label\n", + "        recorder, opt = train(opt, init_mask, init_pattern, model, val_dataset)\n", + "\n", + "        mask = recorder.mask_best\n", + "        masks.append(mask)\n", + "        pattern = recorder.pattern_best\n", + "        patterns.append(pattern)\n", + "\n", + "        idx_mapping[target_label] = len(masks) - 1\n", + "\n", + "    l1_norm_list = torch.stack([torch.sum(torch.abs(m)) for m in masks])\n", + "    print(\"{} labels found\".format(len(l1_norm_list)))\n", + "    print(\"Norm values: {}\".format(l1_norm_list))\n", + "\n", + "    def outlier_detection(l1_norm_list, idx_mapping, opt):\n", + "        print(\"-\" * 30)\n", + "        print(\"Determining whether the model is backdoored\")\n", + "        consistency_constant = 1.4826\n", + "        median = torch.median(l1_norm_list)\n", + "        mad = consistency_constant * torch.median(torch.abs(l1_norm_list - median))\n", + "        min_mad = torch.abs(torch.min(l1_norm_list) - median) / mad\n", + "\n", + "        print(\"Median: {}, MAD: {}\".format(median, mad))\n", + "        print(\"Anomaly index: {}\".format(min_mad))\n", + "\n", + "        if min_mad < 2:\n", + "            print(\"Not a backdoored model\")\n", + "        else:\n", + "            print(\"This is a backdoored model\")\n", + "\n", + "        flag_list = []\n", + "        for y_label in idx_mapping:\n", + "            if l1_norm_list[idx_mapping[y_label]] > median:\n", + "                continue\n", + "            if torch.abs(l1_norm_list[idx_mapping[y_label]] - median) / mad > 2:\n", + "                flag_list.append((y_label, l1_norm_list[idx_mapping[y_label]]))\n", + "\n", + "        if len(flag_list) > 0:\n", + "            flag_list = sorted(flag_list, key=lambda x: x[1])\n", + "\n", + "            print(\n", + "                \"Flagged label list: {}\".format(\n", + "                    \",\".join([\"{}: {}\".format(y_label, l_norm) for y_label, l_norm in flag_list]))\n", + "            )\n", + "\n", + "        return [y_label for y_label, _ in flag_list]\n", + "\n", + "    poi_label_list = outlier_detection(l1_norm_list, idx_mapping, opt)\n", + "\n", + "    if len(poi_label_list) == 0:\n", + "        return model\n", + "\n", +
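"    ## (added) unlearning_ds below stamps the recovered trigger onto a random patch_ratio\n", + "    ## share of validation images while keeping their true labels, so the fine-tuning loop\n", + "    ## that follows unlearns the trigger -> target-label shortcut.\n", +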
"    class unlearning_ds(Dataset):\n", + "        def __init__(self, dataset, mask, trigger, patch_ratio):\n", + "            self.dataset = dataset\n", + "            self.patch_list = random.sample(list(np.arange(len(dataset))), int(len(dataset) * patch_ratio))\n", + "            self.mask = mask\n", + "            self.trigger = trigger\n", + "\n", + "        def __getitem__(self, idx):\n", + "            image = self.dataset[idx][0]\n", + "            label = self.dataset[idx][1]\n", + "            if idx in self.patch_list:\n", + "                image = (image + self.mask * (self.trigger - image))\n", + "                image = torch.clamp(image, -1, 1)\n", + "            return (image, label)\n", + "\n", + "        def __len__(self):\n", + "            return len(self.dataset)\n", + "\n", + "    for i in poi_label_list:\n", + "        curr_masks = masks[i].cpu()\n", + "        curr_pattern = patterns[i].cpu()\n", + "        ul_set = unlearning_ds(val_dataset, curr_masks, curr_pattern, 0.2)\n", + "        ul_loader = torch.utils.data.DataLoader(ul_set, batch_size=128, num_workers=4, shuffle=True, drop_last=True)\n", + "\n", + "        model.train()\n", + "        outer_opt = torch.optim.SGD(params=model.parameters(), lr=8e-2)\n", + "        criterion = nn.CrossEntropyLoss()\n", + "        for _ in range(10):\n", + "            train_loss = 0\n", + "            correct = 0\n", + "            total = 0\n", + "            for batch_idx, (inputs, targets) in enumerate(ul_loader):\n", + "                inputs, targets = inputs.cuda(), targets.type(torch.LongTensor).cuda()\n", + "                outer_opt.zero_grad()\n", + "                outputs = model(inputs)\n", + "                loss = criterion(outputs, targets)\n", + "                loss.backward()\n", + "                outer_opt.step()\n", + "\n", + "                train_loss += loss.item()\n", + "                _, predicted = outputs.max(1)\n", + "                total += targets.size(0)\n", + "                correct += predicted.eq(targets).sum().item()\n", + "            print('Unlearn Acc: %.3f%% (%d/%d)'\n", + "                  % (100. * correct / total, correct, total))\n", + "\n", + "    return model" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "GFxsTR3E8Cdj" + }, + "source": [ + "# 4. Implement your defense method" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "8KQKqgGs8Cdj" + }, + "outputs": [], + "source": [ + "import fastres\n", + "import functools\n", + "from functools import partial\n", + "import math\n", + "import os\n", + "import copy\n", + "\n", + "import torch\n", + "import torch.nn.functional as F\n", + "from torch import nn\n", + "\n", + "import torchvision\n", + "from torchvision import transforms\n", + "\n", + "# set global defaults (in this particular file) for convolutions\n", + "default_conv_kwargs = {'kernel_size': 3, 'padding': 'same', 'bias': False}\n", + "\n", + "batchsize = 64\n", + "bias_scaler = 56\n", + "# To replicate the ~95.78%-accuracy-in-113-seconds runs, you can change the base_depth from 64->128, train_epochs from 12.1->85, ['ema'] epochs 10->75, cutmix_size 3->9, and cutmix_epochs 6->75\n", + "hyp = {\n", + "    'opt': {\n", + "        'bias_lr': 1.64 * bias_scaler/512, # TODO: Is there maybe a better way to express the bias and batchnorm scaling? :'))))\n", + "        'non_bias_lr': 1.64 / 512,\n", + "        'bias_decay': 1.08 * 6.45e-4 * batchsize/bias_scaler,\n", + "        'non_bias_decay': 1.08 * 6.45e-4 * batchsize,\n", + "        'scaling_factor': 1./9,\n", + "        'percent_start': .23,\n", + "        'loss_scale_scaler': 1./128, # * Regularizer inside the loss summing (range: ~1/512 - 16+). FP8 should help with this somewhat too, whenever it comes out. :)\n", + "    },\n", + "    'net': {\n", + "        'whitening': {\n", + "            'kernel_size': 2,\n", + "            'num_examples': 50000,\n", + "        },\n", + "        'batch_norm_momentum': .5, # * Don't forget momentum is 1 - momentum here (due to a quirk in the original paper...
>:( )\n", + "        'conv_norm_pow': 2.6,\n", + "        'cutmix_size': 9,\n", + "        'cutmix_epochs': 180,\n", + "        'pad_amount': 2,\n", + "        'base_depth': 64 ## This should be a factor of 8 in some way to stay tensor core friendly\n", + "    },\n", + "    'misc': {\n", + "        'ema': {\n", + "            'epochs': 180, # Slight bug in that this counts only full epochs and then additionally runs the EMA for any fractional epochs at the end too\n", + "            'decay_base': .95,\n", + "            'decay_pow': 3.,\n", + "            'every_n_steps': 5,\n", + "        },\n", + "        'train_epochs': 200,\n", + "        'device': 'cuda',\n", + "        'data_location': 'data.pt',\n", + "    }\n", + "}\n", + "\n", + "\n", + "\n", + "\n", + "def clean_model(net, train_dataset, allow_time):\n", + "    #############################################\n", + "    #                Dataloader                 #\n", + "    #############################################\n", + "    transform = transforms.Compose([\n", + "        transforms.ToTensor()])\n", + "\n", + "    # use the dataloader to get a single batch of all the dataset items at once.\n", + "    train_dataset_gpu_loader = torch.utils.data.DataLoader(train_dataset, batch_size=len(train_dataset), drop_last=True, shuffle=True, num_workers=4, persistent_workers=False)\n", + "    eval_dataset_gpu_loader = torch.utils.data.DataLoader(train_dataset, batch_size=len(train_dataset), drop_last=True, shuffle=False, num_workers=1, persistent_workers=False)\n", + "\n", + "    train_dataset_gpu = {}\n", + "    eval_dataset_gpu = {}\n", + "\n", + "    train_dataset_gpu['images'], train_dataset_gpu['targets'] = [item.to(device=hyp['misc']['device'], non_blocking=True) for item in next(iter(train_dataset_gpu_loader))]\n", + "    eval_dataset_gpu['images'], eval_dataset_gpu['targets'] = [item.to(device=hyp['misc']['device'], non_blocking=True) for item in next(iter(eval_dataset_gpu_loader))]\n", + "\n", + "    cifar10_std, cifar10_mean = torch.std_mean(train_dataset_gpu['images'], dim=(0, 2, 3)) # dynamically calculate the std and mean from the data. this shortens the code and should help us adapt to new datasets!\n", + "\n", + "    def batch_normalize_images(input_images, mean, std):\n", + "        return (input_images - mean.view(1, -1, 1, 1)) / std.view(1, -1, 1, 1)\n", + "\n", + "    # preload with our mean and std\n", + "    batch_normalize_images = partial(batch_normalize_images, mean=cifar10_mean, std=cifar10_std)\n", + "\n", + "    ## Batch normalize datasets, now. Wowie. We did it! We should take a break and make some tea now.\n", + "    train_dataset_gpu['images'] = batch_normalize_images(train_dataset_gpu['images'])\n", + "    eval_dataset_gpu['images'] = batch_normalize_images(eval_dataset_gpu['images'])\n", + "\n", + "    data = {\n", + "        'train': train_dataset_gpu,\n", + "        'eval': eval_dataset_gpu\n", + "    }\n", + "\n", + "    ## Convert dataset to FP16 now for the rest of the process....\n", + "    data['train']['images'] = data['train']['images'].half().requires_grad_(False)\n", + "    data['eval']['images'] = data['eval']['images'].half().requires_grad_(False)\n", + "\n", + "    # Convert this to one-hot to support the usage of cutmix (or whatever strange label tricks/magic you desire!)\n", + "    data['train']['targets'] = F.one_hot(data['train']['targets']).half()\n", + "    data['eval']['targets'] = F.one_hot(data['eval']['targets']).half()\n", + "\n", + "    torch.save(data, hyp['misc']['data_location'])\n", + "\n", + "\n", +
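"    ## (added) Note: the save-then-load round trip here mirrors fastres.py's on-disk caching;\n", + "    ## inside this function it is redundant but cheap, and is presumably kept for parity.\n", +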
:)\n", + " ## So as long as you run the above loading process once, and keep the file on the disc it's specified by default in the above\n", + " ## hyp dictionary, then we should be good. :)\n", + " data = torch.load(hyp['misc']['data_location'])\n", + "\n", + " ## As you'll note above and below, one difference is that we don't count loading the raw data to GPU since it's such a variable operation, and can sort of get in the way\n", + " ## of measuring other things. That said, measuring the preprocessing (outside the padding) is still important to us.\n", + "\n", + " # Pad the GPU training dataset\n", + " if hyp['net']['pad_amount'] > 0:\n", + " ## Uncomfortable shorthand, but basically we pad evenly on all _4_ sides with the pad_amount specified in the original dictionary\n", + " data['train']['images'] = F.pad(data['train']['images'], (hyp['net']['pad_amount'],)*4, 'reflect')\n", + "\n", + " # Initializing constants for the whole run.\n", + " net_ema = None ## Reset any existing network emas, we want to have _something_ to check for existence so we can initialize the EMA right from where the network is during training\n", + " ## (as opposed to initializing the network_ema from the randomly-initialized starter network, then forcing it to play catch-up all of a sudden in the last several epochs)\n", + "\n", + " total_time_seconds = 0.\n", + " current_steps = 0.\n", + "\n", + " # TODO: Doesn't currently account for partial epochs really (since we're not doing \"real\" epochs across the whole batchsize)....\n", + " num_steps_per_epoch = len(data['train']['images']) // batchsize\n", + " total_train_steps = math.ceil(num_steps_per_epoch * hyp['misc']['train_epochs'])\n", + " ema_epoch_start = math.floor(hyp['misc']['train_epochs']) - hyp['misc']['ema']['epochs']\n", + "\n", + " ## I believe this wasn't logged, but the EMA update power is adjusted by being raised to the power of the number of \"every n\" steps\n", + " ## to somewhat accomodate for whatever the expected information intake rate is. The tradeoff I believe, though, is that this is to some degree noisier as we\n", + " ## are intaking fewer samples of our distribution-over-time, with a higher individual weight each. This can be good or bad depending upon what we want.\n", + " projected_ema_decay_val = hyp['misc']['ema']['decay_base'] ** hyp['misc']['ema']['every_n_steps']\n", + "\n", + " # Adjust pct_start based upon how many epochs we need to finetune the ema at a low lr for\n", + " pct_start = hyp['opt']['percent_start'] #* (total_train_steps/(total_train_steps - num_low_lr_steps_for_ema))\n", + "\n", + " ## Stowing the creation of these into a helper function to make things a bit more readable....\n", + " non_bias_params, bias_params = fastres.init_split_parameter_dictionaries(net)\n", + "\n", + " # One optimizer for the regular network, and one for the biases. 
"\n", + "    # One optimizer for the regular network, and one for the biases. This allows us to use the superconvergence onecycle training policy for our networks....\n", + "    opt = torch.optim.SGD(**non_bias_params)\n", + "    opt_bias = torch.optim.SGD(**bias_params)\n", + "\n", + "    ## Not the most intuitive, but this basically takes us from ~0 to max_lr at the point pct_start, then down to final_lr_ratio * max_lr (here .07 * max_lr) at the end --\n", + "    ## the quirk is that OneCycleLR calculates the final lr value from the starting lr value, not from the maximum lr value set during training\n", + "    initial_div_factor = 1e16 # basically to make the initial lr ~0 or so :D\n", + "    final_lr_ratio = .07 # Actually pretty important, apparently!\n", + "    lr_sched = torch.optim.lr_scheduler.OneCycleLR(opt, max_lr=non_bias_params['lr'], pct_start=pct_start, div_factor=initial_div_factor, final_div_factor=1./(initial_div_factor*final_lr_ratio), total_steps=total_train_steps, anneal_strategy='linear', cycle_momentum=False)\n", + "    lr_sched_bias = torch.optim.lr_scheduler.OneCycleLR(opt_bias, max_lr=bias_params['lr'], pct_start=pct_start, div_factor=initial_div_factor, final_div_factor=1./(initial_div_factor*final_lr_ratio), total_steps=total_train_steps, anneal_strategy='linear', cycle_momentum=False)\n", + "\n", + "    ## For accurately timing GPU code\n", + "    starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)\n", + "    torch.cuda.synchronize() ## clean up any pre-net setup operations\n", + "\n", + "\n", + "    if True: ## Sometimes we need a conditional/for loop here, this is placed to save the trouble of needing to indent\n", + "        for epoch in range(math.ceil(hyp['misc']['train_epochs'])):\n", + "            #################\n", + "            # Training Mode #\n", + "            #################\n", + "            torch.cuda.synchronize()\n", + "            starter.record()\n", + "            net.train()\n", + "\n", + "            loss_train = None\n", + "            accuracy_train = None\n", + "\n", + "            cutmix_size = hyp['net']['cutmix_size'] if epoch >= hyp['misc']['train_epochs'] - hyp['net']['cutmix_epochs'] else 0\n", + "            epoch_fraction = 1 if epoch + 1 <= hyp['misc']['train_epochs'] else hyp['misc']['train_epochs'] % 1 # We need to know if we're running a partial epoch or not (with an integer train_epochs, every epoch is a full one).\n", + "\n", + "            for epoch_step, (inputs, targets) in enumerate(fastres.get_batches(data, key='train', batchsize=batchsize, epoch_fraction=epoch_fraction, cutmix_size=cutmix_size)):\n", + "                ## Run everything through the network\n", + "                outputs = net(inputs)\n", + "\n", + "                loss_batchsize_scaler = 512/batchsize # to scale to keep things at a relatively similar amount of regularization when we change our batchsize since we're summing over the whole batch\n", + "                ## If you want to add other losses or hack around with the loss, you can do that here.\n", + "                loss = fastres.loss_fn(outputs, targets).mul(hyp['opt']['loss_scale_scaler']*loss_batchsize_scaler).sum().div(hyp['opt']['loss_scale_scaler']) ## Note, as noted in the original blog posts, the summing here does a kind of loss scaling\n", + "                ## (and is thus batchsize dependent as a result). This can be somewhat good or bad, depending...\n", +
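"                ## (added) Worked numbers: with batchsize=64, loss_batchsize_scaler = 512/64 = 8; the\n", + "                ## per-sample losses are scaled by (1/128)*8 = 1/16, summed, and the sum is then divided\n", + "                ## by 1/128, so the result is 8x a plain batch sum -- the 1/128 bracket just keeps the\n", + "                ## fp16 summands in range while the gradients match the 512-batch reference scale.\n", +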
"\n", + "                # we only take the last-saved accs and losses from train\n", + "                if epoch_step % 50 == 0:\n", + "                    train_acc = (outputs.detach().argmax(-1) == targets.argmax(-1)).float().mean().item()\n", + "                    train_loss = loss.detach().cpu().item()/(batchsize*loss_batchsize_scaler)\n", + "\n", + "                loss.backward()\n", + "\n", + "                ## Step for each optimizer, in turn.\n", + "                opt.step()\n", + "                opt_bias.step()\n", + "\n", + "                # We only want to step the lr_schedulers while we have training steps to consume. Otherwise we get a not-so-friendly error from PyTorch\n", + "                lr_sched.step()\n", + "                lr_sched_bias.step()\n", + "\n", + "                ## Using 'set_to_none' is, I believe, slightly faster (albeit riskier w/ funky gradient-update workflows) than the default set-to-zero method\n", + "                opt.zero_grad(set_to_none=True)\n", + "                opt_bias.zero_grad(set_to_none=True)\n", + "                current_steps += 1\n", + "\n", + "                if epoch >= ema_epoch_start and current_steps % hyp['misc']['ema']['every_n_steps'] == 0:\n", + "                    ## Initialize the ema from the network at this point in time if it does not already exist.... :D\n", + "                    if net_ema is None: # don't snapshot the network yet if so!\n", + "                        net_ema = fastres.NetworkEMA(net)\n", + "                        continue\n", + "                    # We warm up our ema's decay/momentum value over training exponentially according to the hyp config dictionary (this lets us move fast, then average strongly at the end).\n", + "                    net_ema.update(net, decay=projected_ema_decay_val*(current_steps/total_train_steps)**hyp['misc']['ema']['decay_pow'])\n", + "\n", + "            ender.record()\n", + "            torch.cuda.synchronize()\n", + "            total_time_seconds += 1e-3 * starter.elapsed_time(ender)\n", + "\n", + "            ####################\n", + "            # Evaluation Mode  #\n", + "            ####################\n", + "            net.eval()\n", + "\n", + "            eval_batchsize = 2500\n", + "            assert data['eval']['images'].shape[0] % eval_batchsize == 0, \"Error: The eval batchsize must evenly divide the eval dataset (for now, we don't have drop_remainder implemented yet).\"\n", + "            loss_list_val, acc_list, acc_list_ema = [], [], []\n", + "\n", + "            with torch.no_grad():\n", + "                for inputs, targets in fastres.get_batches(data, key='eval', batchsize=eval_batchsize):\n", + "                    if epoch >= ema_epoch_start:\n", + "                        outputs = net_ema(inputs)\n", + "                        acc_list_ema.append((outputs.argmax(-1) == targets.argmax(-1)).float().mean())\n", + "                    outputs = net(inputs)\n", + "                    loss_list_val.append(fastres.loss_fn(outputs, targets).float().mean())\n", + "                    acc_list.append((outputs.argmax(-1) == targets.argmax(-1)).float().mean())\n", + "\n", + "            val_acc = torch.stack(acc_list).mean().item()\n", + "            ema_val_acc = None\n", + "            # TODO: We can fuse these two operations (just above and below) all-together like :D :))))\n", + "            if epoch >= ema_epoch_start:\n", + "                ema_val_acc = torch.stack(acc_list_ema).mean().item()\n", + "\n", + "            val_loss = torch.stack(loss_list_val).mean().item()\n", + "            # We basically need to look up local variables by name so we can have the names, so we can pad to the proper column width.\n", + "            ## Printing stuff in the terminal can get tricky and this used to use an outside library, but some of the required stuff seemed even\n", + "            ## more heinous than this, unfortunately. So we switched to the \"more simple\" version of this!\n", +
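"            ## (added) e.g. with train_acc = 0.9123, format_for_table('train_acc', locals())\n", + "            ## yields '   0.9123': '{:0.4f}' formatting, right-justified to len('train_acc') = 9.\n", +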
So we switched to the \"more simple\" version of this!\n", + " format_for_table = lambda x, locals: (f\"{locals[x]}\".rjust(len(x))) \\\n", + " if type(locals[x]) == int else \"{:0.4f}\".format(locals[x]).rjust(len(x)) \\\n", + " if locals[x] is not None \\\n", + " else \" \"*len(x)\n", + "\n", + " # Print out our training details (sorry for the complexity, the whole logging business here is a bit of a hot mess once the columns need to be aligned and such....)\n", + " ## We also check to see if we're in our final epoch so we can print the 'bottom' of the table for each round.\n", + " fastres.print_training_details(list(map(partial(format_for_table, locals=locals()), fastres.logging_columns_list)), is_final_entry=(epoch >= math.ceil(hyp['misc']['train_epochs'] - 1)))\n", + "\n", + "\n", + " return net" + ] + }, + { + "cell_type": "markdown", + "source": [ + "# 5. Test defense" + ], + "metadata": { + "id": "THKu7ewXHYGi" + } + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "w6fo4MlW8Cdj" + }, + "outputs": [], + "source": [ + "# Get the defended model\n", + "models = test_defense(clean_model)\n", + "\n", + "# Test all attack\n", + "## Test Pubfig all2all\n", + "# print(\"----------------- Testing defense result: PubFig all2all -----------------\")\n", + "# _, test_dataset, asr_dataset, pacc_dataset, _ = PubFig_all2all()\n", + "# model = models\n", + "# print('ACC:%.3f%%' % (100 * get_results(model, test_dataset)))\n", + "# print('ASR %.3f%%' % (100 * get_results(model, asr_dataset)))\n", + "# print('PACC %.3f%%' % (100 * get_results(model, pacc_dataset)))\n", + "\n", + "# ## Test CIFAR-10 SIG\n", + "print(\"----------------- Testing defense result: CIFAR-10 SIG -----------------\")\n", + "_, test_dataset, asr_dataset, pacc_dataset, _ = CIFAR10_SIG()\n", + "model = models[0]\n", + "print('ACC:%.3f%%' % (100 * get_results(model, test_dataset)))\n", + "print('ASR %.3f%%' % (100 * get_results(model, asr_dataset)))\n", + "print('PACC %.3f%%' % (100 * get_results(model, pacc_dataset)))\n", + "#\n", + "# ## Test Tiny-Imagenet Narcissus\n", + "print(\"----------------- Testing defense result: Tiny-Imagenet Narcissus -----------------\")\n", + "_, test_dataset, asr_dataset, pacc_dataset, _ = TinyImangeNet_Narcissus()\n", + "model = models[1]\n", + "print('ACC:%.3f%%' % (100 * get_results(model, test_dataset)))\n", + "print('ASR %.3f%%' % (100 * get_results(model, asr_dataset)))\n", + "print('PACC %.3f%%' % (100 * get_results(model, pacc_dataset)))\n", + "#\n", + "# ## Test GTSRB WaNet & Smooth\n", + "print(\"----------------- Testing defense result: GTSRB WaNet & Smooth -----------------\")\n", + "_, test_dataset, asr_dataset, pacc_dataset, _ = GTSRB_WaNetFrequency()\n", + "model = models[2]\n", + "print('ACC:%.3f%%' % (100 * get_results(model, test_dataset)))\n", + "print('WaNet ASR %.3f%%' % (100 * get_results(model, asr_dataset[0])))\n", + "print('WaNet PACC %.3f%%' % (100 * get_results(model, pacc_dataset[0])))\n", + "print('Smooth ASR %.3f%%' % (100 * get_results(model, asr_dataset[1])))\n", + "print('Smooth PACC %.3f%%' % (100 * get_results(model, pacc_dataset[1])))\n", + "#\n", + "# ## Test STL-10\n", + "print(\"----------------- Testing defense result: STL-10 -----------------\")\n", + "_, test_dataset, _, _, _ = STL10_Clean()\n", + "model = models[3]\n", + "print('ACC:%.3f%%' % (100 * get_results(model, test_dataset)))" + ] + }, + { + "cell_type": "markdown", + "source": [ + "# 6. 
For colab user to release GPU memory" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "! apt-get install psmisc\n", + "! /opt/bin/nvidia-smi\n", + "! sudo fuser -v/dev/nvidia *" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "! kill -9[PID]" + ], + "metadata": { + "collapsed": false + } + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.12" + }, + "orig_nbformat": 4, + "vscode": { + "interpreter": { + "hash": "7b8be34f2a64f133f414bd034f75b72cc1c8d29070f6944ffe8bd65ff6cd5b9f" + } + }, + "colab": { + "provenance": [], + "collapsed_sections": [ + "WkI4fII__74u", + "TV_oFkq28Cdg", + "GFxsTR3E8Cdj", + "THKu7ewXHYGi" + ] + }, + "gpuClass": "standard" + }, + "nbformat": 4, + "nbformat_minor": 0 +}