Skip to content

Commit

Permalink
commit
Browse files Browse the repository at this point in the history
  • Loading branch information
ericlearning committed Aug 11, 2019
1 parent f2f0b29 commit 583086d
Show file tree
Hide file tree
Showing 8 changed files with 190 additions and 574 deletions.
20 changes: 1 addition & 19 deletions architectures/architecture_pggan.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,16 +45,7 @@ def __init__(self, ni, no, ks, stride, pad, act = 'relu', use_bias = True, use_e

self.relu = nn.LeakyReLU(0.2, inplace = True)
if(self.use_equalized_lr):
'''
self.conv = nn.Conv2d(ni, no, ks, stride, pad, bias = False)
kaiming_normal_(self.conv.weight, a = calculate_gain('conv2d'))
self.bias = torch.nn.Parameter(torch.FloatTensor(no).fill_(0))
self.scale = (torch.mean(self.conv.weight.data ** 2)) ** 0.5
self.conv.weight.data.copy_(self.conv.weight.data / self.scale)
'''
self.conv = EqualizedConv(ni, no, ks, stride, pad, use_bias = use_bias)

else:
self.conv = nn.Conv2d(ni, no, ks, stride, pad, bias = use_bias)

Expand All @@ -63,15 +54,7 @@ def __init__(self, ni, no, ks, stride, pad, act = 'relu', use_bias = True, use_e


def forward(self, x):
'''
if(self.use_equalized_lr):
out = self.conv(x * self.scale)
out = out + self.bias.view(1, -1, 1, 1).expand_as(out)
else:
out = self.conv(x)
'''
out = self.conv(x)

if(self.only_conv == False):
if(self.act == 'relu'):
out = self.relu(out)
Expand All @@ -85,7 +68,7 @@ def __init__(self):
super(UpSample, self).__init__()

def forward(self, x):
return F.interpolate(x, None, 2, 'bilinear', align_corners=True)
return F.interpolate(x, None, 2, 'nearest')

class DownSample(nn.Module):
def __init__(self):
Expand All @@ -95,7 +78,6 @@ def forward(self, x):
return F.avg_pool2d(x, 2)

# Progressive Architectures

class Minibatch_Stddev(nn.Module):
def __init__(self):
super(Minibatch_Stddev, self).__init__()
Expand Down
16 changes: 12 additions & 4 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,7 @@
import torch.nn as nn
from dataset import Dataset
from architectures.architecture_pggan import PGGAN_D, PGGAN_G
from trainers.trainer_ralsgan_progressive import Trainer_RALSGAN_Progressive
from trainers.trainer_rahingegan_progressive import Trainer_RAHINGEGAN_Progressive
from trainers.trainer_wgan_gp_progressive import Trainer_WGAN_GP_Progressive
from trainers.trainer import Trainer
from utils import save, load

dir_name = 'data/celeba'
Expand All @@ -20,7 +18,17 @@
netD = PGGAN_D(sz, nc, use_sigmoid, False, True).to(device)
netG = PGGAN_G(sz, nz, nc, True, True).to(device)

trainer = Trainer_RAHINGEGAN_Progressive(netD, netG, device, data, lr_D = lr_D, lr_G = lr_G, loss_interval = 150, image_interval = 300)
trainer = Trainer('SGAN', netD, netG, device, data, lr_D = lr_D, lr_G = lr_G, resample = True, weight_clip = None, use_gradient_penalty = False, drift = 0.001, loss_interval = 150, image_interval = 300, save_img_dir = 'saved_imges')
trainer = Trainer('LSGAN', netD, netG, device, data, lr_D = lr_D, lr_G = lr_G, resample = True, weight_clip = None, use_gradient_penalty = False, drift = 0.001, loss_interval = 150, image_interval = 300, save_img_dir = 'saved_imges')
trainer = Trainer('HINGEGAN', netD, netG, device, data, lr_D = lr_D, lr_G = lr_G, resample = True, weight_clip = None, use_gradient_penalty = False, drift = 0.001, loss_interval = 150, image_interval = 300, save_img_dir = 'saved_imges')
trainer = Trainer('WGAN', netD, netG, device, data, lr_D = lr_D, lr_G = lr_G, resample = True, weight_clip = 0.01, use_gradient_penalty = False, drift = 0.001, loss_interval = 150, image_interval = 300, save_img_dir = 'saved_imges')
trainer = Trainer('WGAN', netD, netG, device, data, lr_D = lr_D, lr_G = lr_G, resample = True, weight_clip = None, use_gradient_penalty = 10, drift = 0.001, loss_interval = 150, image_interval = 300, save_img_dir = 'saved_imges')

trainer = Trainer('RASGAN', netD, netG, device, data, lr_D = lr_D, lr_G = lr_G, resample = True, weight_clip = None, use_gradient_penalty = False, drift = 0.001, loss_interval = 150, image_interval = 300, save_img_dir = 'saved_imges')
trainer = Trainer('RALSGAN', netD, netG, device, data, lr_D = lr_D, lr_G = lr_G, resample = True, weight_clip = None, use_gradient_penalty = False, drift = 0.001, loss_interval = 150, image_interval = 300, save_img_dir = 'saved_imges')
trainer = Trainer('RAHINGEGAN', netD, netG, device, data, lr_D = lr_D, lr_G = lr_G, resample = True, weight_clip = None, use_gradient_penalty = False, drift = 0.001, loss_interval = 150, image_interval = 300, save_img_dir = 'saved_imges')

trainer = Trainer('QPGAN', netD, netG, device, data, lr_D = lr_D, lr_G = lr_G, resample = True, weight_clip = None, use_gradient_penalty = False, drift = 0.001, loss_interval = 150, image_interval = 300, save_img_dir = 'saved_imges')

trainer.train([4, 8, 8, 8, 8, 8], [0.5, 0.5, 0.5, 0.5, 0.5], [16, 16, 16, 16, 16, 16])
save('saved/cur_state.state', netD, netG, trainer.optimizerD, trainer.optimizerG)
147 changes: 147 additions & 0 deletions trainers/trainer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
import os, cv2
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.autograd as autograd
import numpy as np
import pandas as pd
import torch.optim as optim
import matplotlib.pyplot as plt
from tqdm import tqdm
from torch.optim.lr_scheduler import CosineAnnealingLR, StepLR
from utils import *
from losses.losses import *

class Trainer():
    """Progressive-growing GAN trainer supporting several adversarial losses.

    The concrete loss (SGAN / LSGAN / WGAN / hinge / relativistic variants, ...)
    is selected by `loss_type` via `get_gan_loss`; `get_require_type` reports
    which inputs that loss object needs (0/1: critic outputs only; 2: critic
    outputs plus the real/fake images themselves).
    """
    def __init__(self, loss_type, netD, netG, device, train_ds, lr_D = 0.0002, lr_G = 0.0002, resample = True, weight_clip = None, use_gradient_penalty = False, drift = 0.001, loss_interval = 50, image_interval = 50, save_img_dir = 'saved_images/'):
        # BUGFIX: self.device must be assigned before get_gan_loss(self.device, ...)
        # is called; previously it was read here while still unset, so every
        # construction raised AttributeError.
        self.device = device
        self.loss_type = loss_type
        self.require_type = get_require_type(self.loss_type)
        self.loss = get_gan_loss(self.device, self.loss_type)

        self.sz = netG.sz
        self.netD = netD
        self.netG = netG
        self.train_ds = train_ds
        self.lr_D = lr_D
        self.lr_G = lr_G
        self.weight_clip = weight_clip                      # WGAN-style clip bound, or None to disable
        self.use_gradient_penalty = use_gradient_penalty    # GP weight (e.g. 10), or False to disable
        self.drift = drift                                  # drift-loss weight (PGGAN paper), or False
        self.resample = resample                            # draw fresh noise for the generator step

        # betas = (0, 0.99) follows the PGGAN paper's Adam settings.
        self.optimizerD = optim.Adam(self.netD.parameters(), lr = self.lr_D, betas = (0, 0.99))
        self.optimizerG = optim.Adam(self.netG.parameters(), lr = self.lr_G, betas = (0, 0.99))

        self.real_label = 1
        self.fake_label = 0
        self.nz = self.netG.nz

        # Fixed noise for monitoring: 49 samples = the 7x7 grid saved below.
        self.fixed_noise = generate_noise(49, self.nz, self.device)
        self.loss_interval = loss_interval
        self.image_interval = image_interval

        self.errD_records = []
        self.errG_records = []

        self.save_cnt = 0
        self.save_img_dir = save_img_dir
        if(not os.path.exists(self.save_img_dir)):
            os.makedirs(self.save_img_dir)

    def gradient_penalty(self, real_image, fake_image, p):
        """WGAN-GP penalty: mean((||grad_D(x_hat)||_2 - 1)^2) at random
        interpolations x_hat between real and fake images, with the
        discriminator evaluated at progressive-growing stage `p`."""
        bs = real_image.size(0)
        # One alpha per sample, broadcast over the C/H/W dimensions.
        alpha = torch.FloatTensor(bs, 1, 1, 1).uniform_(0, 1).expand(real_image.size()).to(self.device)
        interpolation = alpha * real_image + (1 - alpha) * fake_image

        c_xi = self.netD(interpolation, p)
        gradients = autograd.grad(c_xi, interpolation, torch.ones(c_xi.size()).to(self.device),
                                  create_graph = True, retain_graph = True, only_inputs = True)[0]
        gradients = gradients.view(bs, -1)
        penalty = torch.mean((gradients.norm(2, dim=1) - 1) ** 2)
        return penalty

    def train(self, res_num_epochs, res_percentage, bs):
        """Run progressive training.

        res_num_epochs : epochs to spend at each resolution stage
        res_percentage : fraction of each stage (from the 2nd on) spent
                         fading in the new layers; the first stage has no
                         fade-in (prepended None)
        bs             : batch size per resolution stage

        `p` is the progressive-growing depth passed to both networks:
        integer i when stage i is fully faded in, fractional in (i-1, i)
        during the fade-in epochs.
        """
        p = 0
        res_percentage = [None] + res_percentage
        for i, (num_epoch, percentage, cur_bs) in enumerate(zip(res_num_epochs, res_percentage, bs)):
            # Resolution at stage i is 4 * 2**i (4, 8, 16, ...).
            train_dl = self.train_ds.get_loader(4 * (2**i), cur_bs)
            train_dl_len = len(train_dl)
            if(percentage is None):
                num_epoch_transition = 0
            else:
                num_epoch_transition = int(num_epoch * percentage)

            cnt = 1
            for epoch in range(num_epoch):
                p = i

                for j, data in enumerate(tqdm(train_dl)):
                    # During the transition epochs, ramp p linearly from i-1 to i.
                    if(epoch < num_epoch_transition):
                        p = i + cnt / (train_dl_len * num_epoch_transition) - 1
                        cnt+=1

                    # ---- discriminator step ----
                    self.netD.zero_grad()
                    real_images = data[0].to(self.device)
                    bs = real_images.size(0)

                    noise = generate_noise(bs, self.nz, self.device)
                    fake_images = self.netG(noise, p)

                    c_xr = self.netD(real_images, p)
                    c_xr = c_xr.view(-1)
                    c_xf = self.netD(fake_images.detach(), p)
                    c_xf = c_xf.view(-1)

                    if(self.require_type == 0 or self.require_type == 1):
                        errD = self.loss.d_loss(c_xr, c_xf)
                    elif(self.require_type == 2):
                        errD = self.loss.d_loss(c_xr, c_xf, real_images, fake_images)
                    else:
                        # BUGFIX: previously an unhandled require_type left errD
                        # unbound (NameError) on the first batch and silently
                        # reused a stale value afterwards.
                        raise ValueError('unsupported require_type: %s' % self.require_type)

                    if(self.use_gradient_penalty != False):
                        errD += self.use_gradient_penalty * self.gradient_penalty(real_images, fake_images, p)

                    # Drift term from the PGGAN paper keeps D's real scores near 0.
                    if(self.drift != False):
                        errD += self.drift * torch.mean(c_xr ** 2)

                    errD.backward()
                    self.optimizerD.step()

                    if(self.weight_clip != None):
                        for param in self.netD.parameters():
                            param.data.clamp_(-self.weight_clip, self.weight_clip)


                    # ---- generator step ----
                    self.netG.zero_grad()
                    if(self.resample):
                        noise = generate_noise(bs, self.nz, self.device)
                        fake_images = self.netG(noise, p)

                    if(self.require_type == 0):
                        c_xf = self.netD(fake_images, p)
                        c_xf = c_xf.view(-1)
                        errG = self.loss.g_loss(c_xf)
                    if(self.require_type == 1 or self.require_type == 2):
                        # Relativistic losses also need fresh real scores.
                        c_xr = self.netD(real_images, p)
                        c_xr = c_xr.view(-1)
                        c_xf = self.netD(fake_images, p)
                        c_xf = c_xf.view(-1)
                        errG = self.loss.g_loss(c_xr, c_xf)

                    errG.backward()
                    self.optimizerG.step()

                    self.errD_records.append(float(errD))
                    self.errG_records.append(float(errG))

                    if(j % self.loss_interval == 0):
                        # BUGFIX: the second pair is [batch / batches-per-epoch],
                        # so it must use the batch index j, not the resolution
                        # stage index i.
                        print('[%d/%d] [%d/%d] errD : %.4f, errG : %.4f'
                              %(epoch+1, num_epoch, j+1, train_dl_len, errD, errG))

                    if(j % self.image_interval == 0):
                        sample_images_list = get_sample_images_list('Progressive', (self.fixed_noise, self.netG, p))
                        plot_img = get_display_samples(sample_images_list, 7, 7)
                        cur_file_name = os.path.join(self.save_img_dir, str(self.save_cnt)+' : '+str(epoch)+'-'+str(i)+'.jpg')
                        self.save_cnt += 1
                        cv2.imwrite(cur_file_name, plot_img)
132 changes: 0 additions & 132 deletions trainers/trainer_rahingegan_progressive.py

This file was deleted.

Loading

0 comments on commit 583086d

Please sign in to comment.