updated the code for saving the model
rubelchowdhury20 committed Nov 3, 2020
1 parent 701f1b9 commit 3a9d016
Showing 9 changed files with 47 additions and 18 deletions.
Binary file modified __pycache__/train.cpython-38.pyc
15 changes: 8 additions & 7 deletions main.py
@@ -62,20 +62,21 @@ def main(args):
parser.add_argument("--data_directory", type=str, default="/media/tensor/EXTDRIVE/projects/virtual-try-on/dataset/zalando_final/", help="path to the directory having images for training.")
parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model')
parser.add_argument('--load_pretrain', type=str, default='', help='load the pretrained model from the specified location')
parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')


# for displays
parser.add_argument('--display_freq', type=int, default=100, help='frequency of showing training results on screen')
parser.add_argument('--print_freq', type=int, default=10, help='frequency of showing training results on console')
parser.add_argument('--save_latest_freq', type=int, default=1000, help='frequency of saving the latest results')
parser.add_argument('--save_epoch_freq', type=int, default=100, help='frequency of saving checkpoints at the end of epochs')
parser.add_argument('--display_freq', type=int, default=10, help='frequency of showing training results on screen')
parser.add_argument('--print_freq', type=int, default=50, help='frequency of showing training results on console')
parser.add_argument('--save_latest_freq', type=int, default=100, help='frequency of saving the latest results')
parser.add_argument('--save_epoch_freq', type=int, default=2, help='frequency of saving checkpoints at the end of epochs')


# for generator
parser.add_argument('--netG_input_nc', type=int, default=16, help="# of input channels to the generator")
parser.add_argument('--ngf', type=int, default=16, help='# of gen filters in first conv layer')
parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in first conv layer')
parser.add_argument('--netG', type=str, default='global', help='selects model to use for netG')
parser.add_argument('--n_downsample_global', type=int, default=1, help='number of downsampling layers in netG')
parser.add_argument('--n_downsample_global', type=int, default=3, help='number of downsampling layers in netG')
parser.add_argument('--n_blocks_global', type=int, default=1, help='number of residual blocks in the global generator network')
parser.add_argument('--n_blocks_local', type=int, default=3, help='number of residual blocks in the local enhancer network')
parser.add_argument('--n_local_enhancers', type=int, default=1, help='number of local enhancers to use')
@@ -85,7 +86,7 @@ def main(args):
# for discriminators
parser.add_argument('--num_D', type=int, default=1, help='number of discriminators to use')
parser.add_argument('--n_layers_D', type=int, default=1, help='only used if which_model_netD==n_layers')
parser.add_argument('--ndf', type=int, default=16, help='# of discrim filters in first conv layer')
parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in first conv layer')
parser.add_argument('--lambda_feat', type=float, default=10.0, help='weight for feature matching loss')
parser.add_argument('--no_ganFeat_loss', action='store_true', help='if specified, do *not* use discriminator feature matching loss')
parser.add_argument('--no_vgg_loss', action='store_true', help='if specified, do *not* use VGG feature matching loss')
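For orientation, the new defaults in this file both tighten the checkpoint cadence (the 'latest' checkpoint roughly every 100 steps, a numbered checkpoint every 2 epochs) and enlarge the networks. A minimal sketch of what the larger generator defaults imply, assuming the usual pix2pixHD channel schedule of ngf * 2**i per downsampling layer (illustrative only, not part of the commit):

# Illustrative only: encoder widths implied by the old vs. new generator defaults,
# assuming the standard pix2pixHD GlobalGenerator schedule (channels double per
# downsampling layer, starting at ngf).
def encoder_widths(ngf, n_downsample):
    return [ngf * 2 ** i for i in range(n_downsample + 1)]

print(encoder_widths(16, 1))  # old defaults -> [16, 32]
print(encoder_widths(64, 3))  # new defaults -> [64, 128, 256, 512]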
Binary file modified modules/models/__pycache__/create_model.cpython-38.pyc
Binary file modified modules/models/__pycache__/feature_net.cpython-38.pyc
Binary file modified modules/models/__pycache__/pix2pixHD_model.cpython-38.pyc
24 changes: 20 additions & 4 deletions modules/models/create_model.py
@@ -6,19 +6,31 @@
from . import feature_net
from . import feature_render
from . import pix2pixHD_model
from .base_model import BaseModel

class CreateModel(nn.Module):
class CreateModel(BaseModel):
def __init__(self, config):
super(CreateModel, self).__init__()
self.config = config

BaseModel.initialize(self, self.config.args)

self.feature_net = feature_net.FeatureNet(num_classes=self.config.args.netG_input_nc, up_mode="upsample").to(self.config.DEVICE)
self.feature_render = feature_render.FeatureRender(self.config).to(self.config.DEVICE)
self.render_net = pix2pixHD_model.Pix2PixHDModel(config.args).to(self.config.DEVICE)
self.render_net = pix2pixHD_model.Pix2PixHDModel(self.config.args).to(self.config.DEVICE)
if config.args.is_train and len(config.args.gpu_ids):
self.feature_net = torch.nn.DataParallel(self.feature_net, device_ids=config.args.gpu_ids)
self.feature_render = torch.nn.DataParallel(self.feature_render, device_ids=config.args.gpu_ids)
self.render_net = torch.nn.DataParallel(self.render_net, device_ids=config.args.gpu_ids)

# load networks
if not self.config.args.is_train or self.config.args.continue_train or self.config.args.load_pretrain:
pretrained_path = '' if not self.config.args.is_train else self.config.args.load_pretrain
self.load_network(self.feature_net, 'Feature', self.config.args.which_epoch, pretrained_path)


self.optimizer_G = self.render_net.module.optimizer_G
self.optimizer_D = self.render_net.module.optimizer_D
# self.save = self.render_net.module

def forward(self, batch):
source_image = batch[0].to(self.config.DEVICE)
@@ -38,4 +50,8 @@ def forward(self, batch):

loss_D = loss_D_fake + loss_D_real

return feature_loss, loss_D, loss_G_GAN, loss_G_VGG
return feature_loss, loss_D, loss_G_GAN, loss_G_VGG

def save_feature_net(self, which_epoch):
self.save_network(self.feature_net, 'Feature', which_epoch, self.config.args.gpu_ids)

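The new save_feature_net() and the load_network() call above both go through helpers on BaseModel, which is not part of this diff. A rough sketch of what those helpers usually look like in pix2pixHD-derived code (self.save_dir and the cfg attribute names are assumptions, not taken from this repository):

import os
import torch
import torch.nn as nn

class BaseModel(nn.Module):
    # Sketch only: base_model.py is not shown in this commit, so details may differ.
    def initialize(self, cfg):
        self.cfg = cfg
        # assumed attributes; pix2pixHD keeps checkpoints under <checkpoints_dir>/<name>
        self.save_dir = os.path.join(cfg.checkpoints_dir, cfg.name)

    def save_network(self, network, network_label, epoch_label, gpu_ids):
        # produces files such as latest_net_Feature.pth or 200_net_Feature.pth
        save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
        save_path = os.path.join(self.save_dir, save_filename)
        torch.save(network.cpu().state_dict(), save_path)
        if len(gpu_ids) and torch.cuda.is_available():
            network.cuda()

    def load_network(self, network, network_label, epoch_label, save_dir=''):
        save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
        save_path = os.path.join(save_dir or self.save_dir, save_filename)
        network.load_state_dict(torch.load(save_path))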
12 changes: 9 additions & 3 deletions modules/models/feature_net.py
@@ -4,6 +4,9 @@
import torch.nn as nn
import torch.nn.functional as F

# local imports
from .base_model import BaseModel


#---------------------- basic convolution blocks ------------------------#
def conv1x1(in_channels, out_channels, groups=1):
@@ -106,7 +109,7 @@ def forward(self, from_down, from_up):


class FeatureNet(nn.Module):
def __init__(self, num_classes, in_channels=3, depth=1,
def __init__(self, num_classes, in_channels=3, depth=5,
start_filts=64, up_mode='transpose',
merge_mode='concat'):
"""
@@ -189,7 +192,6 @@ def weight_init(m):
nn.init.xavier_normal(m.weight)
nn.init.constant(m.bias, 0)


def reset_params(self):
for i, m in enumerate(self.modules()):
self.weight_init(m)
@@ -212,4 +214,8 @@ def forward(self, x):
# as this module includes a softmax already.
x = self.conv_final(x)
loss = self.abs_loss(input_,x[:,:3,:,:]) + self.mse_loss(input_, x[:,:3,:,:])
return x, loss
return x, loss

def save(self, which_epoch):
self.save_network(self.netG, 'G', which_epoch, self.cfg.gpu_ids)
self.save_network(self.netD, 'D', which_epoch, self.cfg.gpu_ids)
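For context, a minimal usage sketch of FeatureNet as it is driven from create_model.py above (the input size and printed shape are assumptions chosen for illustration; run from the repository root):

import torch
from modules.models.feature_net import FeatureNet

# Sketch: mirrors how create_model.py builds the network
# (num_classes=netG_input_nc, up_mode="upsample"); 256x256 is a hypothetical input size.
net = FeatureNet(num_classes=16, in_channels=3, depth=5, up_mode="upsample")
image = torch.randn(1, 3, 256, 256)
features, recon_loss = net(image)  # forward returns the feature map and an L1+MSE reconstruction loss
print(features.shape)  # expected (1, 16, 256, 256) if the U-Net restores full resolution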
2 changes: 1 addition & 1 deletion modules/models/pix2pixHD_model.py
@@ -37,7 +37,7 @@ def __init__(self, cfg):
if not self.cfg.is_train or self.cfg.continue_train or self.cfg.load_pretrain:
pretrained_path = '' if not self.cfg.is_train else self.cfg.load_pretrain
self.load_network(self.netG, 'G', self.cfg.which_epoch, pretrained_path)
if self.is_train:
if self.cfg.is_train:
self.load_network(self.netD, 'D', self.cfg.which_epoch, pretrained_path)


12 changes: 9 additions & 3 deletions train.py
@@ -120,7 +120,8 @@ def train(config):
### save latest model
if total_steps % config.args.save_latest_freq == save_delta:
print('saving the latest model (epoch %d, total_steps %d)' % (epoch, total_steps))
model.module.render_net.module.save('latest')
model.module.render_net.module.save('latest')
model.module.save_feature_net("latest")
np.savetxt(iter_path, (epoch, epoch_iter), delimiter=',', fmt='%d')

##################################################################################
@@ -156,8 +157,13 @@
### save model for this epoch
if epoch % config.args.save_epoch_freq == 0:
print('saving the model at the end of epoch %d, iters %d' % (epoch, total_steps))
model.module.save('latest')
model.module.save(epoch)

model.module.render_net.module.save('latest')
model.module.save_feature_net("latest")

model.module.render_net.module.save(epoch)
model.module.save_feature_net(epoch)

np.savetxt(iter_path, (epoch+1, 0), delimiter=',', fmt='%d')

### linearly decay learning rate after certain iterations
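The np.savetxt calls above only write the bookkeeping file; the matching read happens when training is resumed with --continue_train. That path is not part of this diff, but in pix2pixHD-style training loops it typically looks roughly like the following sketch (iter_path handling and the fallback defaults are assumptions):

import numpy as np

# Sketch of the resume side of the iter_path bookkeeping (not part of this commit).
def load_start_point(iter_path, continue_train):
    start_epoch, epoch_iter = 1, 0
    if continue_train:
        try:
            start_epoch, epoch_iter = np.loadtxt(iter_path, delimiter=',', dtype=int)
        except OSError:
            pass  # no bookkeeping file yet; start from scratch
    return int(start_epoch), int(epoch_iter)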
