Add support for spatial decorrelation, diversity, & more (#7)
**New Features:**

* Added support for spatial decorrelation via the fast Fourier transform (FFT), implemented in `utils/decorrelation.py` below; this greatly improves visualization quality.
* Added the `vis_diverse.py` script, which uses cosine similarity to extract distinct features from polysemantic channels/neurons (the idea is sketched below).
* The `vis_multi.py`, `vis.py`, and `vis_diverse.py` scripts now support center-neuron extraction via the `-extract_neuron` flag.
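
The diversity objective itself isn't shown on this page. As a rough sketch of the idea (assuming an approach along the lines of Lucid's diversity term; all names here are illustrative, not the actual `vis_diverse.py` code), penalizing the pairwise cosine similarity of layer activations across a batch pushes each image toward a different facet of the same channel:

```python
import torch
import torch.nn.functional as F

def diversity_penalty(activations):
    # activations: (batch, channels, height, width) from the target layer,
    # one candidate visualization per batch slot.
    flat = activations.reshape(activations.size(0), -1)
    # Pairwise cosine similarity between every pair of batch images: (B, B).
    sim = F.cosine_similarity(flat.unsqueeze(0), flat.unsqueeze(1), dim=2)
    n = sim.size(0)
    # Mean off-diagonal similarity; adding this term to the loss decorrelates
    # the batch images so they don't all converge on the same feature.
    return (sim.sum() - sim.diagonal().sum()) / (n * (n - 1))
```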

**Bug Fixes:**

* The `vis_multi.py` script's `-image_size` parameter now works correctly.
* The `-use_rgb` parameter in the `calc_ms.py` script now works independently of the `-not_caffe` parameter; the sketch below shows the corrected control flow.
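
For context, a minimal before/after sketch of the control flow fixed here (paraphrasing the `calc_ms.py` diff further down; string stand-ins replace the actual torchvision transforms):

```python
# Stand-in booleans for the parsed -not_caffe and -use_rgb flags.
not_caffe, use_rgb = True, False

# Before the fix: the BGR swap was nested inside the Caffe branch, so
# -use_rgb silently did nothing whenever -not_caffe was set.
transform_list = []
if not not_caffe:
    transform_list.append("scale_to_255")
    if not use_rgb:
        transform_list.append("rgb_to_bgr")
print(transform_list)  # [] -- no BGR swap despite use_rgb=False

# After the fix: the two options are applied independently.
transform_list = []
if not not_caffe:
    transform_list.append("scale_to_255")
if not use_rgb:
    transform_list.append("rgb_to_bgr")
print(transform_list)  # ['rgb_to_bgr']
```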
ProGamerGov committed Sep 22, 2020
1 parent 32fcd67 commit 55f4ef1
Showing 7 changed files with 299 additions and 21 deletions.
4 changes: 4 additions & 0 deletions README.md
@@ -177,8 +177,10 @@ After training a new DeepDream model, you'll need to test its visualizations. T
* `-optimizer`: The optimization algorithm to use; either `lbfgs` or `adam`; default is `adam`.
* `-num_iterations`: Default is `500`.
* `-layer`: The specific layer you wish to use. Default is set to `fc`.
+* `-extract_neuron`: If this flag is enabled, the center neuron is extracted from each channel (see the sketch after this list).
* `-image_size`: A comma separated list of `<height>,<width>` to use for the output image. Default is set to `224,224`.
* `-jitter`: The amount of image jitter to use for preprocessing. Default is `32`.
+* `-fft_decorrelation`: Enables FFT spatial decorrelation. A lower learning rate should be used when this is enabled.
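
A minimal sketch of what the new `-extract_neuron` flag does (mirroring the `extract_neuron` helper added to `utils/vis_utils.py` further down; the function name here is illustrative): instead of optimizing a statistic over a channel's whole feature map, only the spatially central unit is kept, focusing the objective on a single receptive field.

```python
import torch

def center_neuron(activations: torch.Tensor) -> torch.Tensor:
    # activations: (batch, channels, height, width) from the target layer.
    y = activations.size(2) // 2  # vertical center
    x = activations.size(3) // 2  # horizontal center
    # Keep the spatial dims so downstream loss code is unchanged.
    return activations[:, :, y:y+1, x:x+1]
```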

**Processing options:**
* `-batch_size`: How many channel visualization images to create in each batch. Default is `10`.
@@ -219,8 +221,10 @@ This script lets you create DeepDream hallucinations with trained GoogleNet models
* `-content_image`: Path to your input image. If no input image is specified, random noise is used instead.
* `-layer`: The specific layer you wish to use. Default is set to `mixed5a`.
* `-channel`: The specific layer channel you wish to use. Default is set to `-1` to disable specific channel selection.
+* `-extract_neuron`: If this flag is enabled, the center neuron is extracted from the channel selected by the `-channel` parameter.
* `-image_size`: A comma separated list of `<height>,<width>` to use for the output image. Default is set to `224,224`.
* `-jitter`: The amount of image jitter to use for preprocessing. Default is `32`.
+* `-fft_decorrelation`: Enables FFT spatial decorrelation. A lower learning rate should be used when this is enabled (see the note after this list).
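
On the learning-rate caveat for `-fft_decorrelation`: the optimized tensor lives in scaled frequency space and passes through a sigmoid before reaching the network (see `utils/decorrelation.py` below), so pixel-space step sizes overshoot. A hedged sketch of the setup; the parameter shape follows this commit's `freqs_shape` logic for a 224×224 image, but the `0.05` learning rate is illustrative, not a project default:

```python
import torch

# Frequency-space parameter for a 224x224 RGB image: (1, 3, H, W//2 + 1, 2),
# where the trailing 2 holds the real/imaginary parts used by torch.irfft.
spectrum = torch.randn(1, 3, 224, 113, 2) * 0.01
spectrum.requires_grad_()

# Much smaller steps than typical pixel-space optimization; 0.05 is illustrative.
optimizer = torch.optim.Adam([spectrum], lr=0.05)
```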

**Options required only if the model doesn't contain them:**
* `-data_mean`: Your precalculated list of mean values used to train the model, if it wasn't saved inside the model.
6 changes: 3 additions & 3 deletions data_tools/calc_ms.py
@@ -23,9 +23,9 @@ def main_calc(params):
    if not params.not_caffe:
        range_change = transforms.Compose([transforms.Lambda(lambda x: x*255)])
        transform_list += [range_change]
-        if not params.use_rgb:
-            rgb2bgr = transforms.Compose([transforms.Lambda(lambda x: x[torch.LongTensor([2,1,0])])])
-            transform_list += [rgb2bgr]
+    if not params.use_rgb:
+        rgb2bgr = transforms.Compose([transforms.Lambda(lambda x: x[torch.LongTensor([2,1,0])])])
+        transform_list += [rgb2bgr]

    dataset = torchvision.datasets.ImageFolder(params.data_path, transform=transforms.Compose(transform_list))
    loader = torch.utils.data.DataLoader(dataset, batch_size=params.batch_size, num_workers=0, shuffle=False)
60 changes: 60 additions & 0 deletions utils/decorrelation.py
@@ -0,0 +1,60 @@
import torch
import torch.nn as nn


def get_decorrelation_layers(image_size=(224,224), input_mean=[1,1,1], device='cpu', decay_power=0.75):
    spatial_mod = SpatialDecorrelationLayer(image_size, decay_power=decay_power, device=device)
    transform_mod = TransformLayer(input_mean=input_mean, device=device)
    return [spatial_mod, transform_mod]


# Spatial Decorrelation layer based on tensorflow/lucid & greentfrapp/lucent
class SpatialDecorrelationLayer(torch.nn.Module):

    def __init__(self, image_size=(224,224), decay_power=1.0, device='cpu'):
        super(SpatialDecorrelationLayer, self).__init__()
        self.h = image_size[0]
        self.w = image_size[1]
        self.scale = self.create_scale(decay_power).to(device)

    def create_scale(self, decay_power=1.0):
        freqs = self.rfft2d_freqs()
        self.freqs_shape = freqs.size() + (2,)
        scale = 1.0 / torch.max(freqs, torch.full_like(freqs, 1.0 / (max(self.w, self.h)))) ** decay_power
        return scale[None, None, ..., None]

    def rfft2d_freqs(self):
        fy = self.pytorch_fftfreq(self.h)[:, None]
        wadd = 2 if self.w % 2 == 1 else 1
        fx = self.pytorch_fftfreq(self.w)[: self.w // 2 + wadd]
        return torch.sqrt((fx * fx) + (fy * fy))

    def pytorch_fftfreq(self, v, d=1.0):
        results = torch.empty(v)
        s = (v - 1) // 2 + 1
        results[:s] = torch.arange(0, s)
        results[s:] = torch.arange(-(v // 2), 0)
        return results * (1.0 / (v * d))

    def ifft_image(self, input):
        input = input * self.scale
        input = torch.irfft(input, 2, normalized=True, signal_sizes=(self.h, self.w))
        return input / 4

    def forward(self, input):
        return self.ifft_image(input)


# Preprocess input after decorrelation
class TransformLayer(torch.nn.Module):

    def __init__(self, input_mean=[1,1,1], r=255, device='cpu'):
        super(TransformLayer, self).__init__()
        self.input_mean = torch.as_tensor(input_mean).to(device)
        self.input_sd = torch.as_tensor([1,1,1]).to(device)
        self.r = r

    def forward(self, input):
        input = torch.sigmoid(input) * self.r
        return input.sub(self.input_mean[None, :, None, None]).div(self.input_sd[None, :, None, None])
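
A hedged usage sketch of the two layers above, showing how a frequency-space tensor becomes a preprocessed network input (this mirrors the `mod_list[1](mod_list[0](...))` chain in `vis.py` below; the mean values are the common Caffe BGR ImageNet means, used purely for illustration). Note that `torch.irfft` is the pre-1.8 PyTorch API; later releases replace it with the `torch.fft` module:

```python
import torch
from utils.decorrelation import get_decorrelation_layers

spatial_mod, transform_mod = get_decorrelation_layers(
    image_size=(224, 224), input_mean=[103.939, 116.779, 123.68], device='cpu')

# Optimizable spectrum: (1, 3) + freqs_shape, i.e. (1, 3, 224, 113, 2) here.
spectrum = torch.randn(1, 3, *spatial_mod.freqs_shape) * 0.01

image = spatial_mod(spectrum)     # scaled inverse FFT -> spatial image
net_input = transform_mod(image)  # sigmoid -> [0, 255], then mean subtraction
print(net_input.shape)            # torch.Size([1, 3, 224, 224])
```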
21 changes: 14 additions & 7 deletions utils/vis_utils.py
@@ -148,10 +148,11 @@ def forward(self, input):


# Create loss module and hook for multiple channels
-def register_hook_batch(net, layer_name, loss_func=mean_loss):
-    loss_module = SimpleDreamLossHookBatch(loss_func)
+def register_hook_batch(net, layer_name, loss_func=mean_loss, neuron=False):
+    loss_module = SimpleDreamLossHookBatch(loss_func, neuron)
    return register_layer_hook(net, layer_name, loss_module)


# Create loss module and hook
def register_simple_hook(net, layer_name, channel=-1, loss_func=mean_loss, mode='loss', neuron=False):
    loss_module = SimpleDreamLossHook(channel, loss_func, mode, neuron)
@@ -171,11 +172,18 @@ def register_layer_hook(net, layer_name, loss_module):

# Define a simple forward hook to collect DeepDream loss for multiple channels
class SimpleDreamLossHookBatch(torch.nn.Module):
-    def __init__(self, loss_func=mean_loss):
+    def __init__(self, loss_func=mean_loss, neuron=False):
        super(SimpleDreamLossHookBatch, self).__init__()
        self.get_loss = loss_func
+        self.get_neuron = neuron

+    def extract_neuron(self, input):
+        y = input.size(2) // 2
+        x = input.size(3) // 2
+        return input[:, :, y:y+1, x:x+1]

    def forward(self, module, input, output):
+        output = self.extract_neuron(output) if self.get_neuron else output
        loss = 0
        for batch in range(output.size(0)):
            loss = loss + self.get_loss(output[batch, batch])
@@ -189,9 +197,9 @@ def __init__(self, channel=-1, loss_func=mean_loss, mode='loss', neuron=False):
        self.channel = channel
        self.get_loss = loss_func
        self.mode = mode
-        self.neuron = neuron
+        self.get_neuron = neuron

-    def get_neuron(self, input):
+    def extract_neuron(self, input):
        y = input.size(2) // 2
        x = input.size(3) // 2
        return input[:, :, y:y+1, x:x+1]
@@ -204,8 +212,7 @@ def forward_feature(self, input):

    def forward(self, module, input, output):
        if self.channel > -1:
-            if self.neuron:
-                output = self.get_neuron(output)
+            output = self.extract_neuron(output) if self.get_neuron else output
            output = output[:,self.channel]
        if self.mode == 'loss':
            self.forward_loss(output)
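
A sketch of how the hook above is driven (hedged: it assumes, per the diff, that `SimpleDreamLossHook` with the default `mode='loss'` stores its result on a `loss` attribute via `forward_loss`; the toy model is a placeholder). Note also the batch variant's `output[batch, batch]` indexing: image `b` in the batch optimizes channel `b` of the layer, which is how one forward pass can serve many channel visualizations:

```python
import torch
import torch.nn as nn
from utils.vis_utils import SimpleDreamLossHook

# Placeholder model whose hooked layer yields (N, C, H, W) activations.
net = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), nn.ReLU())

# Collect loss for channel 5, center neuron only.
hook_module = SimpleDreamLossHook(channel=5, neuron=True)
handle = net[1].register_forward_hook(hook_module)

img = torch.randn(1, 3, 64, 64, requires_grad=True)
net(img)                       # forward pass runs the hook
hook_module.loss.backward()    # assumed attribute; drives gradient ascent on img
print(img.grad.abs().mean())
handle.remove()
```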
29 changes: 21 additions & 8 deletions vis.py
@@ -5,9 +5,9 @@
import torch.nn as nn
import torch.optim as optim

-from utils.training_utils import save_csv_data
from utils.inceptionv1_caffe import relu_to_redirected_relu
from utils.vis_utils import preprocess, simple_deprocess, load_model, set_seed, mean_loss, ModelPlus, Jitter, register_simple_hook
+from utils.decorrelation import get_decorrelation_layers


def main():
@@ -18,7 +18,7 @@ def main():
parser.add_argument("-layer", type=str, default='mixed5a')
parser.add_argument("-model_file", type=str, default='')
parser.add_argument("-channel", type=int, default=-1)
parser.add_argument("-center_neuron", action='store_true')
parser.add_argument("-extract_neuron", action='store_true')
parser.add_argument("-image_size", type=str, default='224,224')
parser.add_argument("-content_image", type=str, default='')

@@ -37,6 +37,8 @@ def main():
parser.add_argument("-not_caffe", action='store_true')
parser.add_argument("-seed", type=int, default=-1)
parser.add_argument("-no_branches", action='store_true')

parser.add_argument("-fft_decorrelation", action='store_true')
params = parser.parse_args()
params.image_size = [int(m) for m in params.image_size.split(',')]
main_func(params)
@@ -65,29 +67,36 @@ def main_func(params):
    # Preprocessing net layers
    jit_mod = Jitter(params.jitter)
    mod_list = []
+    if params.fft_decorrelation:
+        mod_list += get_decorrelation_layers(image_size=params.image_size, input_mean=params.data_mean, device=params.use_device)
    mod_list.append(jit_mod)
    prep_net = nn.Sequential(*mod_list)

    # Full network
    net = ModelPlus(prep_net, cnn)
    loss_func = mean_loss
-    loss_modules = register_simple_hook(net.net, params.layer, params.channel, loss_func=loss_func, neuron=params.center_neuron)
+    loss_modules = register_simple_hook(net.net, params.layer, params.channel, loss_func=loss_func, neuron=params.extract_neuron)

    if params.content_image == '':
-        input_tensor = torch.randn(3,*params.image_size).unsqueeze(0).to(params.use_device) * 0.01
+        if params.fft_decorrelation:
+            input_tensor = torch.randn(*((3,) + mod_list[0].freqs_shape)).unsqueeze(0).to(params.use_device) * 0.01
+        else:
+            input_tensor = torch.randn(3, *params.image_size).unsqueeze(0).to(params.use_device) * 0.01
    else:
        input_tensor = preprocess(params.content_image, params.image_size, params.data_mean, params.not_caffe).to(params.use_device)

    print('Running optimization with ADAM')

    # Create 224x224 image
-    output_tensor = dream(net, input_tensor, params.num_iterations, params.lr, loss_modules, params.data_mean, \
-        params.save_iter, params.print_iter, params.output_image, params.not_caffe)
+    output_tensor = dream(net, input_tensor, params.num_iterations, params.lr, loss_modules, params.save_iter, \
+        params.print_iter, params.output_image, [params.data_mean, params.not_caffe], mod_list)
+    if params.fft_decorrelation:
+        output_tensor = mod_list[1](mod_list[0](output_tensor))
    simple_deprocess(output_tensor, params.output_image, params.data_mean, params.not_caffe)


# Function to maximize CNN activations
-def dream(net, img, iterations, lr, loss_modules, m, save_iter, print_iter, output_image, not_caffe):
+def dream(net, img, iterations, lr, loss_modules, save_iter, print_iter, output_image, deprocess_info, mod_list):
    filename, ext = os.path.splitext(output_image)
    img = nn.Parameter(img)
    optimizer = torch.optim.Adam([img], lr=lr)
@@ -103,7 +112,11 @@ def dream(net, img, iterations, lr, loss_modules, m, save_iter, print_iter, output_image, not_caffe):
            print('Iteration', str(i) + ',', 'Loss', str(loss.item()))

        if save_iter > 0 and i > 0 and i % save_iter == 0:
-            simple_deprocess(img.detach(), filename + '_' + str(i) + ext, m, not_caffe)
+            if len(mod_list) > 1:
+                simple_deprocess(mod_list[1](mod_list[0](img.detach())), filename + '_' + str(i) + \
+                    ext, deprocess_info[0], deprocess_info[1])
+            else:
+                simple_deprocess(img.detach(), filename + '_' + str(i) + ext, deprocess_info[0], deprocess_info[1])
        optimizer.step()
    return img.detach()
