Skip to content

Commit

Permalink
Add files via upload
Browse files Browse the repository at this point in the history
  • Loading branch information
ericlearning committed May 7, 2019
1 parent ff2d1be commit d235a6d
Show file tree
Hide file tree
Showing 8 changed files with 1,031 additions and 0 deletions.
276 changes: 276 additions & 0 deletions architectures/architecture_pggan.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,276 @@
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.utils.spectral_norm as SpectralNorm
from torch.nn.init import kaiming_normal_, calculate_gain

class EqualizedConv(nn.Module):
def __init__(self, ni, no, ks, stride, pad, use_bias):
super(EqualizedConv, self).__init__()
self.ni = ni
self.no = no
self.ks = ks
self.stride = stride
self.pad = pad
self.use_bias = use_bias

self.weight = nn.Parameter(nn.init.normal_(
torch.empty(self.no, self.ni, self.ks, self.ks)
))
if(self.use_bias):
self.bias = nn.Parameter(torch.FloatTensor(self.no).fill_(0))

self.scale = math.sqrt(2 / (self.ni * self.ks * self.ks))

def forward(self, x):
out = F.conv2d(input = x, weight = self.weight * self.scale, bias = self.bias,
stride = self.stride, padding = self.pad)
return out

class ScaledConvBlock(nn.Module):
def __init__(self, ni, no, ks, stride, pad, act = 'relu', use_bias = True, use_equalized_lr = True, use_pixelnorm = True, only_conv = False):
super(ScaledConvBlock, self).__init__()
self.ni = ni
self.no = no
self.ks = ks
self.stride = stride
self.pad = pad
self.act = act
self.use_bias = use_bias
self.use_equalized_lr = use_equalized_lr
self.use_pixelnorm = use_pixelnorm
self.only_conv = only_conv

self.relu = nn.LeakyReLU(0.2, inplace = True)
if(self.use_equalized_lr):
'''
self.conv = nn.Conv2d(ni, no, ks, stride, pad, bias = False)
kaiming_normal_(self.conv.weight, a = calculate_gain('conv2d'))
self.bias = torch.nn.Parameter(torch.FloatTensor(no).fill_(0))
self.scale = (torch.mean(self.conv.weight.data ** 2)) ** 0.5
self.conv.weight.data.copy_(self.conv.weight.data / self.scale)
'''
self.conv = EqualizedConv(ni, no, ks, stride, pad, use_bias = use_bias)

else:
self.conv = nn.Conv2d(ni, no, ks, stride, pad, bias = use_bias)

if(self.use_pixelnorm):
self.pixel_norm = PixelNorm()


def forward(self, x):
'''
if(self.use_equalized_lr):
out = self.conv(x * self.scale)
out = out + self.bias.view(1, -1, 1, 1).expand_as(out)
else:
out = self.conv(x)
'''
out = self.conv(x)

if(self.only_conv == False):
if(self.act == 'relu'):
out = self.relu(out)
if(self.use_pixelnorm):
out = self.pixel_norm(out)

return out

class UpSample(nn.Module):
def __init__(self):
super(UpSample, self).__init__()

def forward(self, x):
return F.interpolate(x, None, 2, 'bilinear', align_corners=True)

class DownSample(nn.Module):
def __init__(self):
super(DownSample, self).__init__()

def forward(self, x):
return F.avg_pool2d(x, 2)

# Progressive Architectures

class Minibatch_Stddev(nn.Module):
def __init__(self):
super(Minibatch_Stddev, self).__init__()

def forward(self, x):
stddev = torch.sqrt(torch.mean((x - torch.mean(x, dim = 0, keepdim = True))**2, dim = 0, keepdim = True) + 1e-8)
stddev_mean = torch.mean(stddev, dim = 1, keepdim = True)
stddev_mean = stddev_mean.expand((x.size(0), 1, x.size(2), x.size(3)))

return torch.cat([x, stddev_mean], dim = 1)

class PixelNorm(nn.Module):
def __init__(self):
super(PixelNorm, self).__init__()

def forward(self, x):
out = x / torch.sqrt(torch.mean(x**2, dim = 1, keepdim = True) + 1e-8)
return out

class PGGAN_G(nn.Module):
def __init__(self, sz, nz, nc, use_pixelnorm = False, use_equalized_lr = False, use_tanh = True):
super(PGGAN_G, self).__init__()
self.sz = sz
self.nz = nz
self.nc = nc
self.ngfs = {
'8': [32, 16],
'16': [64, 32, 16],
'32': [128, 64, 32, 16],
'64': [256, 128, 64, 32, 16],
'128': [512, 256, 128, 64, 32, 16],
'256': [512, 512, 256, 128, 64, 32, 16],
'512': [512, 512, 512, 256, 128, 64, 32, 16],
'1024': [512, 512, 512, 512, 256, 128, 64, 32, 16]
}

self.cur_ngf = self.ngfs[str(sz)]

# create blocks list
prev_dim = self.cur_ngf[0]
cur_block = nn.Sequential(
ScaledConvBlock(nz, prev_dim, 4, 1, 3, 'relu', True, use_equalized_lr, use_pixelnorm),
ScaledConvBlock(prev_dim, prev_dim, 3, 1, 1, 'relu', True, use_equalized_lr, use_pixelnorm)
)
self.blocks = nn.ModuleList([cur_block])
for dim in self.cur_ngf[1:]:
cur_block = nn.Sequential(
ScaledConvBlock(prev_dim, dim, 3, 1, 1, 'relu', True, use_equalized_lr, use_pixelnorm),
ScaledConvBlock(dim, dim, 3, 1, 1, 'relu', True, use_equalized_lr, use_pixelnorm)
)
prev_dim = dim
self.blocks.append(cur_block)

# create to_blocks list
self.to_blocks = nn.ModuleList([])
for dim in self.cur_ngf:
self.to_blocks.append(ScaledConvBlock(dim, nc, 1, 1, 0, None, True, use_equalized_lr, use_pixelnorm, only_conv = True))

self.use_tanh = use_tanh
self.tanh = nn.Tanh()
self.upsample = UpSample()

def forward(self, x, stage):
stage_int = int(stage)
stage_type = (stage == stage_int)
out = x

# Stablization Steps
if(stage_type):
for i in range(stage_int):
out = self.blocks[i](out)
out = self.upsample(out)
out = self.blocks[stage_int](out)
out = self.to_blocks[stage_int](out)

# Growing Steps
else:
p = stage - stage_int
for i in range(stage_int+1):
out = self.blocks[i](out)
out = self.upsample(out)

out_1 = self.to_blocks[stage_int](out)
out_2 = self.blocks[stage_int+1](out)
out_2 = self.to_blocks[stage_int+1](out_2)
out = out_1 * (1 - p) + out_2 * p

if(self.use_tanh):
out = self.tanh(out)

return out

class PGGAN_D(nn.Module):
def __init__(self, sz, nc, use_sigmoid = True, use_pixelnorm = False, use_equalized_lr = False):
super(PGGAN_D, self).__init__()
self.sz = sz
self.nc = nc
self.sigmoid = nn.Sigmoid()
self.use_sigmoid = use_sigmoid
self.ndfs = {
'8': [32, 16],
'16': [64, 32, 16],
'32': [128, 64, 32, 16],
'64': [256, 128, 64, 32, 16],
'128': [512, 256, 128, 64, 32, 16],
'256': [512, 512, 256, 128, 64, 32, 16],
'512': [512, 512, 512, 256, 128, 64, 32, 16],
'1024': [512, 512, 512, 512, 256, 128, 64, 32, 16]
}

self.cur_ndf = self.ndfs[str(sz)]

# create blocks list
prev_dim = self.cur_ndf[0]
cur_block = nn.Sequential(
Minibatch_Stddev(),
ScaledConvBlock(prev_dim+1, prev_dim, 3, 1, 1, 'relu', True, use_equalized_lr, use_pixelnorm),
ScaledConvBlock(prev_dim, prev_dim, 4, 1, 0, 'relu', True, use_equalized_lr, use_pixelnorm)
)
self.blocks = nn.ModuleList([cur_block])
for dim in self.cur_ndf[1:]:
cur_block = nn.Sequential(
ScaledConvBlock(dim, dim, 3, 1, 1, 'relu', True, use_equalized_lr, use_pixelnorm),
ScaledConvBlock(dim, prev_dim, 3, 1, 1, 'relu', True, use_equalized_lr, use_pixelnorm)
)
prev_dim = dim
self.blocks.append(cur_block)

# create from_blocks list
self.from_blocks = nn.ModuleList([])
for dim in self.cur_ndf:
self.from_blocks.append(ScaledConvBlock(nc, dim, 1, 1, 0, None, True, use_equalized_lr, use_pixelnorm, only_conv = True))

self.linear = nn.Linear(self.cur_ndf[0], 1)
self.downsample = DownSample()

def forward(self, x, stage):
stage_int = int(stage)
stage_type = (stage == stage_int)
sz = 2 ** (2+stage_int)
if(stage_type == False):
sz *= 2
out = F.adaptive_avg_pool2d(x, sz)

# Stablization Steps
if(stage_type):
out = self.from_blocks[stage_int](out)
for i in range(stage_int):
out = self.blocks[stage_int - i](out)
out = self.downsample(out)
out = self.blocks[0](out)
out = self.linear(out.view(out.shape[0], -1))
out = out.view(out.shape[0], 1, 1, 1)

# Growing Steps
else:
p = stage - stage_int
out_1 = self.downsample(out)
out_1 = self.from_blocks[stage_int](out_1)

out_2 = self.from_blocks[stage_int+1](out)
out_2 = self.blocks[stage_int+1](out_2)
out_2 = self.downsample(out_2)

out = out_1 * (1 - p) + out_2 * p

for i in range(stage_int):
out = self.blocks[stage_int - i](out)
out = self.downsample(out)
out = self.blocks[0](out)
out = self.linear(out.view(out.shape[0], -1))
out = out.view(out.shape[0], 1, 1, 1)

if(self.use_sigmoid):
out = self.sigmoid(out)

return out

71 changes: 71 additions & 0 deletions dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import os
import torch
import random
import numpy as np
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from PIL import Image

class Dataset():
def __init__(self, train_dir, basic_types = None, shuffle = True):
self.train_dir = train_dir
self.basic_types = basic_types
self.shuffle = shuffle

def get_loader(self, sz, bs, get_size = False, data_transform = None, num_workers = 1, audio_sample_num = None):
if(self.basic_types is None):
if(data_transform == None):
data_transform = transforms.Compose([
transforms.Resize(sz),
transforms.CenterCrop(sz),
transforms.ToTensor(),
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])

train_dataset = datasets.ImageFolder(self.train_dir, data_transform)
train_loader = DataLoader(train_dataset, batch_size = bs, shuffle = self.shuffle, num_workers = num_workers)

train_dataset_size = len(train_dataset)
size = train_dataset_size

returns = (train_loader)
if(get_size):
returns = returns + (size,)

elif(self.basic_types == 'MNIST'):
data_transform = transforms.Compose([
transforms.Resize(sz),
transforms.CenterCrop(sz),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5])
])

train_dataset = datasets.MNIST(self.train_dir, train = True, download = True, transform = data_transform)
train_loader = DataLoader(train_dataset, batch_size = bs, shuffle = self.shuffle, num_workers = num_workers)

train_dataset_size = len(train_dataset)
size = train_dataset_size

returns = (train_loader)
if(get_size):
returns = returns + (size,)

elif(self.basic_types == 'CIFAR10'):
data_transform = transforms.Compose([
transforms.Resize(sz),
transforms.CenterCrop(sz),
transforms.ToTensor(),
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])

train_dataset = datasets.CIFAR10(self.train_dir, train = True, download = True, transform = data_transform)
train_loader = DataLoader(train_dataset, batch_size = bs, shuffle = self.shuffle, num_workers = num_workers)

train_dataset_size = len(train_dataset)
size = train_dataset_size

returns = (train_loader)
if(get_size):
returns = returns + (size,)

return returns
26 changes: 26 additions & 0 deletions train.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import os
import torch
import torch.nn as nn
from dataset import Dataset
from architectures.architectures_pggan import PGGAN_D, PGGAN_G
from trainers.trainer_ralsgan_progressive import Trainer_RALSGAN_Progressive
from trainers.trainer_rahingegan_progressive import Trainer_RAHINGEGAN_Progressive
from trainers.trainer_wgan_gp_progressive import Trainer_WGAN_GP_Progressive
from utils import save, load

dir_name = 'data/celeba'
basic_types = None

lr_D, lr_G = 0.001, 0.001
sz, nc, nz = 128, 3, 256
use_sigmoid = False

data = Dataset('data/celeba')
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
netD = PGGAN_D(sz, nc, use_sigmoid, False, True).to(device)
netG = PGGAN_G(sz, nz, nc, True, True).to(device)

trainer = Trainer_RAHINGEGAN_Progressive(netD, netG, device, data, lr_D = lr_D, lr_G = lr_G, loss_interval = 150, image_interval = 300)

trainer.train([4, 8, 8, 8, 8, 8], [0.5, 0.5, 0.5, 0.5, 0.5], [16, 16, 16, 16, 16, 16])
save('saved/cur_state.state', netD, netG, trainer.optimizerD, trainer.optimizerG)
Loading

0 comments on commit d235a6d

Please sign in to comment.