Skip to content

Commit

Permalink
Add files via upload
Browse files Browse the repository at this point in the history
  • Loading branch information
ericlearning committed Jun 22, 2019
1 parent 7a02ae4 commit 914cf13
Show file tree
Hide file tree
Showing 7 changed files with 348 additions and 35 deletions.
123 changes: 123 additions & 0 deletions losses/losses.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
import torch
import torch.nn as nn
import numpy as np

def get_label(bs, device):
label_r = torch.full((bs, ), 1, device = device)
label_f = torch.full((bs, ), 0, device = device)
return label_r, label_f

class SGAN(nn.Module):
def __init__(self, device):
super(SGAN, self).__init__()
self.criterion = nn.BCELoss()
self.device = device

def d_loss(self, c_xr, c_xf):
bs = c_xf.shape[0]
label_r, label_f = get_label(bs, self.device)
return self.criterion(c_xr, label_r) + self.criterion(c_xf, label_f)

def g_loss(self, c_xf):
bs = c_xf.shape[0]
label_r, _ = get_label(bs, self.device)
return self.criterion(c_xf, label_r)

class LSGAN(nn.Module):
def __init__(self, device):
super(LSGAN, self).__init__()
self.device = device

def d_loss(self, c_xr, c_xf):
bs = c_xf.shape[0]
label_r, label_f = get_label(bs, self.device)
return torch.mean((c_xr - label_r) ** 2) + torch.mean((c_xf - label_f) ** 2)

def g_loss(self, c_xf):
bs = c_xf.shape[0]
label_r, _ = get_label(bs, self.device)
return torch.mean((c_xf - label_r) ** 2)

class HINGEGAN(nn.Module):
def __init__(self, device):
super(HINGEGAN, self).__init__()
self.device = device

def d_loss(self, c_xr, c_xf):
bs = c_xf.shape[0]
return torch.mean(torch.nn.ReLU()(1-c_xr)) + torch.mean(torch.nn.ReLU()(1+c_xf))

def g_loss(self, c_xf):
return -torch.mean(c_xf)

class WGAN(nn.Module):
def __init__(self, device):
super(WGAN, self).__init__()
self.device = device

def d_loss(self, c_xr, c_xf):
return -torch.mean(c_xr) + torch.mean(c_xf)

def g_loss(self, c_xf):
return -torch.mean(c_xf)

class RASGAN(nn.Module):
def __init__(self, device):
super(RASGAN, self).__init__()
self.device = device
self.criterion = nn.BCEWithLogitsLoss()

def d_loss(self, c_xr, c_xf):
bs = c_xf.shape[0]
label_r, label_f = get_label(bs, self.device)
return (self.criterion(c_xr - torch.mean(c_xf), label_r) + self.criterion(c_xf - torch.mean(c_xr), label_f)) / 2.0

def g_loss(self, c_xr, c_xf):
bs = c_xf.shape[0]
label_r, label_f = get_label(bs, self.device)
return (self.criterion(c_xr - torch.mean(c_xf), label_f) + self.criterion(c_xf - torch.mean(c_xr), label_r)) / 2.0

class RALSGAN(nn.Module):
def __init__(self, device):
super(RALSGAN, self).__init__()
self.device = device

def d_loss(self, c_xr, c_xf):
bs = c_xf.shape[0]
label_r, _ = get_label(bs, self.device)
return (torch.mean((c_xr - torch.mean(c_xf) - label_r)**2) + torch.mean((c_xf - torch.mean(c_xr) + label_r)**2)) / 2.0

def g_loss(self, c_xr, c_xf):
bs = c_xf.shape[0]
label_r, _ = get_label(bs, self.device)
return (torch.mean((c_xf - torch.mean(c_xr) - label_r)**2) + torch.mean((c_xr - torch.mean(c_xf) + label_r)**2)) / 2.0

class RAHINGEGAN(nn.Module):
def __init__(self, device):
super(RAHINGEGAN, self).__init__()
self.device = device

def d_loss(self, c_xr, c_xf):
return (torch.mean(torch.nn.ReLU()(1-(c_xr-torch.mean(c_xf)))) + torch.mean(torch.nn.ReLU()(1+(c_xf-torch.mean(c_xr))))) / 2.0

def g_loss(self, c_xr, c_xf):
return (torch.mean(torch.nn.ReLU()(1-(c_xf-torch.mean(c_xr)))) + torch.mean(torch.nn.ReLU()(1+(c_xr-torch.mean(c_xf))))) / 2.0

class QPGAN(nn.Module):
def __init__(self, device, norm_type = 'L1'):
super(QPGAN, self).__init__()
self.device = device
self.norm_type = norm_type

def d_loss(self, c_xr, c_xf, real_images, fake_images):
if(self.norm_type == 'L1'):
denominator = (real_images - fake_images).abs().mean() * 10 * 2
if(self.norm_type == 'L2'):
denominator = (real_images - fake_images).mean().sqrt() * 10 * 2

errD_1 = torch.mean(c_xr) - torch.mean(c_xf)
errD_2 = (errD_1 ** 2) / denominator
return errD_2 - errD_1

def g_loss(self, c_xr, c_xf):
return torch.mean(c_xr) - torch.mean(c_xf)
34 changes: 34 additions & 0 deletions train_advanced.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import os
import torch
import torch.nn as nn
from dataset import Dataset
from architectures.architectures_pggan import PGGAN_D, PGGAN_G
from trainers_advanced.trainer import Trainer
from utils import save, load

dir_name = 'data/celeba'
basic_types = None

lr_D, lr_G = 0.001, 0.001
sz, nc, nz = 128, 3, 256
use_sigmoid = False

data = Dataset('data/celeba')
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
netD = PGGAN_D(sz, nc, use_sigmoid, False, True).to(device)
netG = PGGAN_G(sz, nz, nc, True, True).to(device)

trainer = Trainer('SGAN', netD, netG, device, data, lr_D = lr_D, lr_G = lr_G, resample = True, weight_clip = None, use_gradient_penalty = False, drift = 0.001, loss_interval = 150, image_interval = 300, save_img_dir = 'saved_imges')
trainer = Trainer('LSGAN', netD, netG, device, data, lr_D = lr_D, lr_G = lr_G, resample = True, weight_clip = None, use_gradient_penalty = False, drift = 0.001, loss_interval = 150, image_interval = 300, save_img_dir = 'saved_imges')
trainer = Trainer('HINGEGAN', netD, netG, device, data, lr_D = lr_D, lr_G = lr_G, resample = True, weight_clip = None, use_gradient_penalty = False, drift = 0.001, loss_interval = 150, image_interval = 300, save_img_dir = 'saved_imges')
trainer = Trainer('WGAN', netD, netG, device, data, lr_D = lr_D, lr_G = lr_G, resample = True, weight_clip = 0.01, use_gradient_penalty = False, drift = 0.001, loss_interval = 150, image_interval = 300, save_img_dir = 'saved_imges')
trainer = Trainer('WGAN', netD, netG, device, data, lr_D = lr_D, lr_G = lr_G, resample = True, weight_clip = None, use_gradient_penalty = 10, drift = 0.001, loss_interval = 150, image_interval = 300, save_img_dir = 'saved_imges')

trainer = Trainer('RASGAN', netD, netG, device, data, lr_D = lr_D, lr_G = lr_G, resample = True, weight_clip = None, use_gradient_penalty = False, drift = 0.001, loss_interval = 150, image_interval = 300, save_img_dir = 'saved_imges')
trainer = Trainer('RALSGAN', netD, netG, device, data, lr_D = lr_D, lr_G = lr_G, resample = True, weight_clip = None, use_gradient_penalty = False, drift = 0.001, loss_interval = 150, image_interval = 300, save_img_dir = 'saved_imges')
trainer = Trainer('RAHINGEGAN', netD, netG, device, data, lr_D = lr_D, lr_G = lr_G, resample = True, weight_clip = None, use_gradient_penalty = False, drift = 0.001, loss_interval = 150, image_interval = 300, save_img_dir = 'saved_imges')

trainer = Trainer('QPGAN', netD, netG, device, data, lr_D = lr_D, lr_G = lr_G, resample = True, weight_clip = None, use_gradient_penalty = False, drift = 0.001, loss_interval = 150, image_interval = 300, save_img_dir = 'saved_imges')

trainer.train([4, 8, 8, 8, 8, 8], [0.5, 0.5, 0.5, 0.5, 0.5], [16, 16, 16, 16, 16, 16])
save('saved/cur_state.state', netD, netG, trainer.optimizerD, trainer.optimizerG)
17 changes: 8 additions & 9 deletions trainers/trainer_rahingegan_progressive.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import os
import os, cv2
import copy
import torch
import torch.nn as nn
Expand All @@ -8,7 +8,7 @@
import matplotlib.pyplot as plt
from tqdm import tqdm
from torch.optim.lr_scheduler import CosineAnnealingLR, StepLR
from utils import set_lr, get_lr, generate_noise, plot_multiple_images, save_fig, save, get_sample_images_list
from utils import set_lr, get_lr, generate_noise, plot_multiple_images, save_fig, save, get_sample_images_list, get_display_samples

class Trainer_RAHINGEGAN_Progressive():
def __init__(self, netD, netG, device, train_ds, lr_D = 0.0004, lr_G = 0.0001, drift = 0.001, loss_interval = 50, image_interval = 50, snapshot_interval = None, save_img_dir = 'saved_images/', save_snapshot_dir = 'saved_snapshots', resample = False):
Expand All @@ -29,7 +29,7 @@ def __init__(self, netD, netG, device, train_ds, lr_D = 0.0004, lr_G = 0.0001, d
self.fake_label = 0
self.nz = self.netG.nz

self.fixed_noise = generate_noise(16, self.nz, self.device)
self.fixed_noise = generate_noise(49, self.nz, self.device)
self.loss_interval = loss_interval
self.image_interval = image_interval
self.snapshot_interval = snapshot_interval
Expand All @@ -49,14 +49,14 @@ def train(self, res_num_epochs, res_percentage, bs):
p = 0
res_percentage = [None] + res_percentage
for i, (num_epoch, percentage, cur_bs) in enumerate(zip(res_num_epochs, res_percentage, bs)):
train_dl = self.train_ds.get_loader(self.sz, cur_bs)
train_dl = self.train_ds.get_loader(4 * (2**i), cur_bs)
train_dl_len = len(train_dl)
if(percentage is None):
num_epoch_transition = 0
else:
num_epoch_transition = int(num_epoch * percentage)

cnt = 0
cnt = 1
for epoch in range(num_epoch):
p = i
if(self.resample):
Expand Down Expand Up @@ -114,12 +114,11 @@ def train(self, res_num_epochs, res_percentage, bs):

if(j % self.image_interval == 0):
sample_images_list = get_sample_images_list('Progressive', (self.fixed_noise, self.netG, p))
plot_fig = plot_multiple_images(sample_images_list, 4, 4)
plot_img = get_display_samples(sample_images_list, 7, 7)
cur_file_name = os.path.join(self.save_img_dir, str(self.save_cnt)+' : '+str(epoch)+'-'+str(i)+'.jpg')
self.save_cnt += 1
save_fig(cur_file_name, plot_fig)
plot_fig.clf()

cv2.imwrite(cur_file_name, plot_img)

if(self.snapshot_interval is not None):
if(j % self.snapshot_interval == 0):
stage_int = int(p)
Expand Down
17 changes: 8 additions & 9 deletions trainers/trainer_ralsgan_progressive.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import os
import os, cv2
import copy
import torch
import torch.nn as nn
Expand All @@ -8,7 +8,7 @@
import matplotlib.pyplot as plt
from tqdm import tqdm
from torch.optim.lr_scheduler import CosineAnnealingLR, StepLR
from utils import set_lr, get_lr, generate_noise, plot_multiple_images, save_fig, save, get_sample_images_list
from utils import set_lr, get_lr, generate_noise, plot_multiple_images, save_fig, save, get_sample_images_list, get_display_samples

class Trainer_RALSGAN_Progressive():
def __init__(self, netD, netG, device, train_ds, lr_D = 0.0002, lr_G = 0.0002, drift = 0.001, loss_interval = 50, image_interval = 50, snapshot_interval = None, save_img_dir = 'saved_images/', save_snapshot_dir = 'saved_snapshots', resample = False):
Expand All @@ -24,12 +24,12 @@ def __init__(self, netD, netG, device, train_ds, lr_D = 0.0002, lr_G = 0.0002, d

self.optimizerD = optim.Adam(self.netD.parameters(), lr = self.lr_D, betas = (0, 0.99))
self.optimizerG = optim.Adam(self.netG.parameters(), lr = self.lr_G, betas = (0, 0.99))

self.real_label = 1
self.fake_label = 0
self.nz = self.netG.nz

self.fixed_noise = generate_noise(16, self.nz, self.device)
self.fixed_noise = generate_noise(49, self.nz, self.device)
self.loss_interval = loss_interval
self.image_interval = image_interval
self.snapshot_interval = snapshot_interval
Expand All @@ -49,14 +49,14 @@ def train(self, res_num_epochs, res_percentage, bs):
p = 0
res_percentage = [None] + res_percentage
for i, (num_epoch, percentage, cur_bs) in enumerate(zip(res_num_epochs, res_percentage, bs)):
train_dl = self.train_ds.get_loader(self.sz, cur_bs)
train_dl = self.train_ds.get_loader(4 * (2**i), cur_bs)
train_dl_len = len(train_dl)
if(percentage is None):
num_epoch_transition = 0
else:
num_epoch_transition = int(num_epoch * percentage)

cnt = 0
cnt = 1
for epoch in range(num_epoch):
p = i
if(self.resample):
Expand Down Expand Up @@ -112,11 +112,10 @@ def train(self, res_num_epochs, res_percentage, bs):

if(j % self.image_interval == 0):
sample_images_list = get_sample_images_list('Progressive', (self.fixed_noise, self.netG, p))
plot_fig = plot_multiple_images(sample_images_list, 4, 4)
plot_img = get_display_samples(sample_images_list, 7, 7)
cur_file_name = os.path.join(self.save_img_dir, str(self.save_cnt)+' : '+str(epoch)+'-'+str(i)+'.jpg')
self.save_cnt += 1
save_fig(cur_file_name, plot_fig)
plot_fig.clf()
cv2.imwrite(cur_file_name, plot_img)

if(self.snapshot_interval is not None):
if(j % self.snapshot_interval == 0):
Expand Down
17 changes: 8 additions & 9 deletions trainers/trainer_rasgan_progressive.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import os
import os, cv2
import copy
import torch
import torch.nn as nn
Expand All @@ -8,7 +8,7 @@
import matplotlib.pyplot as plt
from tqdm import tqdm
from torch.optim.lr_scheduler import CosineAnnealingLR, StepLR
from utils import set_lr, get_lr, generate_noise, plot_multiple_images, save_fig, save, get_sample_images_list
from utils import set_lr, get_lr, generate_noise, plot_multiple_images, save_fig, save, get_sample_images_list, get_display_samples

class Trainer_RASGAN_Progressive():
def __init__(self, netD, netG, device, train_ds, lr_D = 0.0002, lr_G = 0.0002, drift = 0.001, loss_interval = 50, image_interval = 50, snapshot_interval = None, save_img_dir = 'saved_images/', save_snapshot_dir = 'saved_snapshots', resample = False):
Expand All @@ -29,7 +29,7 @@ def __init__(self, netD, netG, device, train_ds, lr_D = 0.0002, lr_G = 0.0002, d
self.fake_label = 0
self.nz = self.netG.nz

self.fixed_noise = generate_noise(16, self.nz, self.device)
self.fixed_noise = generate_noise(49, self.nz, self.device)
self.loss_interval = loss_interval
self.image_interval = image_interval
self.snapshot_interval = snapshot_interval
Expand All @@ -50,14 +50,14 @@ def train(self, res_num_epochs, res_percentage, bs):
criterion = nn.BCEWithLogitsLoss()
res_percentage = [None] + res_percentage
for i, (num_epoch, percentage, cur_bs) in enumerate(zip(res_num_epochs, res_percentage, bs)):
train_dl = self.train_ds.get_loader(self.sz, cur_bs)
train_dl = self.train_ds.get_loader(4 * (2**i), cur_bs)
train_dl_len = len(train_dl)
if(percentage is None):
num_epoch_transition = 0
else:
num_epoch_transition = int(num_epoch * percentage)

cnt = 0
cnt = 1
for epoch in range(num_epoch):
p = i
if(self.resample):
Expand Down Expand Up @@ -114,12 +114,11 @@ def train(self, res_num_epochs, res_percentage, bs):

if(j % self.image_interval == 0):
sample_images_list = get_sample_images_list('Progressive', (self.fixed_noise, self.netG, p))
plot_fig = plot_multiple_images(sample_images_list, 4, 4)
plot_img = get_display_samples(sample_images_list, 7, 7)
cur_file_name = os.path.join(self.save_img_dir, str(self.save_cnt)+' : '+str(epoch)+'-'+str(i)+'.jpg')
self.save_cnt += 1
save_fig(cur_file_name, plot_fig)
plot_fig.clf()

cv2.imwrite(cur_file_name, plot_img)

if(self.snapshot_interval is not None):
if(j % self.snapshot_interval == 0):
stage_int = int(p)
Expand Down
15 changes: 7 additions & 8 deletions trainers/trainer_wgan_gp_progressive.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import os
import os, cv2
import copy
import torch
import torch.nn as nn
Expand All @@ -10,7 +10,7 @@
import matplotlib.pyplot as plt
from tqdm import tqdm
from torch.optim.lr_scheduler import CosineAnnealingLR, StepLR
from utils import set_lr, get_lr, generate_noise, plot_multiple_images, save_fig, save, get_sample_images_list
from utils import set_lr, get_lr, generate_noise, plot_multiple_images, save_fig, save, get_sample_images_list, get_display_samples

class Trainer_WGAN_GP_Progressive():
def __init__(self, netD, netG, device, train_ds, lr_D = 0.0002, lr_G = 0.0002, n_critic = 5, lambd = 10, drift = 0.001, loss_interval = 50, image_interval = 50, snapshot_interval = None, save_img_dir = 'saved_images/', save_snapshot_dir = 'saved_snapshots', resample = None):
Expand All @@ -33,7 +33,7 @@ def __init__(self, netD, netG, device, train_ds, lr_D = 0.0002, lr_G = 0.0002, n
self.fake_label = 0
self.nz = self.netG.nz

self.fixed_noise = generate_noise(16, self.nz, self.device)
self.fixed_noise = generate_noise(49, self.nz, self.device)
self.loss_interval = loss_interval
self.image_interval = image_interval
self.snapshot_interval = snapshot_interval
Expand All @@ -56,14 +56,14 @@ def train(self, res_num_epochs, res_percentage, bs):
p = 0
res_percentage = [None] + res_percentage
for i, (num_epoch, percentage, cur_bs) in enumerate(zip(res_num_epochs, res_percentage, bs)):
train_dl = self.train_ds.get_loader(self.sz, cur_bs)
train_dl = self.train_ds.get_loader(4 * (2**i), cur_bs)
train_dl_len = len(train_dl)
if(percentage is None):
num_epoch_transition = 0
else:
num_epoch_transition = int(num_epoch * percentage)

cnt = 0
cnt = 1
for epoch in range(num_epoch):
p = i

Expand Down Expand Up @@ -122,11 +122,10 @@ def train(self, res_num_epochs, res_percentage, bs):

if(j % self.image_interval == 0):
sample_images_list = get_sample_images_list('Progressive', (self.fixed_noise, self.netG, p))
plot_fig = plot_multiple_images(sample_images_list, 4, 4)
plot_img = get_display_samples(sample_images_list, 7, 7)
cur_file_name = os.path.join(self.save_img_dir, str(self.save_cnt)+' : '+str(epoch)+'-'+str(i)+'.jpg')
self.save_cnt += 1
save_fig(cur_file_name, plot_fig)
plot_fig.clf()
cv2.imwrite(cur_file_name, plot_img)

if(self.snapshot_interval is not None):
if(j % self.snapshot_interval == 0):
Expand Down
Loading

0 comments on commit 914cf13

Please sign in to comment.