-
Notifications
You must be signed in to change notification settings - Fork 24
/
train.py
245 lines (197 loc) · 8.59 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
import argparse
import os
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.utils.data as data
import numpy as np
from PIL import Image
from PIL import ImageFile
from torchvision import transforms
from torchvision.utils import save_image
import time
import net
from sampler import InfiniteSamplerWrapper
from math import log, sqrt, pi
# Process-wide switches, set once before any image is opened or any model runs.
ImageFile.LOAD_TRUNCATED_IMAGES = True  # Disable OSError: image file is truncated
Image.MAX_IMAGE_PIXELS = None  # Disable DecompressionBombError
cudnn.benchmark = True
def train_transform():
    """Build the image preprocessing pipeline used for training.

    Returns:
        A ``transforms.Compose`` that resizes the image to 512x512, takes a
        random 256x256 crop and converts the result to a tensor.
    """
    return transforms.Compose([
        transforms.Resize(size=(512, 512)),
        transforms.RandomCrop(256),
        transforms.ToTensor(),
    ])
class FlatFolderDataset(data.Dataset):
    """Dataset exposing every entry found directly inside one directory.

    The directory listing is taken once at construction time; each item is
    opened as an RGB image and passed through ``transform``.
    """

    def __init__(self, root, transform):
        """Remember ``root`` and ``transform`` and list the entries of ``root``."""
        super().__init__()
        self.transform = transform
        self.root = root
        self.paths = os.listdir(root)

    def __len__(self):
        """Return the number of directory entries found under ``root``."""
        return len(self.paths)

    def __getitem__(self, index):
        """Load entry ``index`` as RGB and return the transformed image."""
        full_path = os.path.join(self.root, self.paths[index])
        rgb_image = Image.open(full_path).convert('RGB')
        return self.transform(rgb_image)

    def name(self):
        """Return the dataset's display name."""
        return 'FlatFolderDataset'
def adjust_learning_rate(optimizer, iteration_count, base_lr=None, lr_decay=None):
    """Apply inverse-time learning-rate decay to every parameter group.

    The rate written is ``base_lr / (1 + lr_decay * iteration_count)``
    (imitating the original implementation).

    Args:
        optimizer: Object exposing ``param_groups``, an iterable of dicts
            whose ``'lr'`` entry is overwritten.
        iteration_count: Iteration index fed into the decay formula.
        base_lr: Undecayed learning rate. Defaults to the module-level
            ``args.lr`` when omitted.
        lr_decay: Decay coefficient. Defaults to the module-level
            ``args.lr_decay`` when omitted.
    """
    # Fall back to the parsed command-line options only when the caller did
    # not supply explicit values, so existing two-argument calls are unchanged.
    if base_lr is None:
        base_lr = args.lr
    if lr_decay is None:
        lr_decay = args.lr_decay

    lr = base_lr / (1.0 + lr_decay * iteration_count)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
def _str2bool(value):
    """Convert a command-line token such as 'true', 'False', '1' or 'no' to bool.

    ``type=bool`` must not be used with argparse: ``bool('False')`` is True
    because any non-empty string is truthy.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    # The empty string is kept falsy, matching the old ``bool('')`` result.
    if lowered in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got %r' % value)


parser = argparse.ArgumentParser()
# Basic options
parser.add_argument('--content_dir', type=str, required=True,
                    help='Directory path to a batch of content images')
parser.add_argument('--style_dir', type=str, required=True,
                    help='Directory path to a batch of style images')
parser.add_argument('--vgg', type=str, default='models/vgg_normalised.pth')

# Training options
parser.add_argument('--save_dir', default='experiments',
                    help='Directory to save the model')
parser.add_argument('--log_dir', default='./logs',
                    help='Directory to save the log')
parser.add_argument('--lr', type=float, default=1e-4)
parser.add_argument('--lr_decay', type=float, default=5e-5)
parser.add_argument('--max_iter', type=int, default=160000)
parser.add_argument('--batch_size', type=int, default=4)
parser.add_argument('--mse_weight', type=float, default=0)
parser.add_argument('--style_weight', type=float, default=1)
parser.add_argument('--content_weight', type=float, default=0.1)

# Data loading, logging and checkpoint options
parser.add_argument('--n_threads', type=int, default=8)
parser.add_argument('--print_interval', type=int, default=100)
parser.add_argument('--save_model_interval', type=int, default=5000)
parser.add_argument('--start_iter', type=int, default=0, help='starting iteration')
# The value is later joined onto --save_dir, and the default is not "none".
parser.add_argument('--resume', default='glow.pth', type=str, metavar='PATH',
                    help='checkpoint file, joined onto --save_dir (default: glow.pth)')

# Glow parameters
parser.add_argument('--n_flow', default=8, type=int,
                    help='number of flows in each block')  # original note: 32
parser.add_argument('--n_block', default=2, type=int,
                    help='number of blocks')  # original note: 4
parser.add_argument('--no_lu', action='store_true',
                    help='use plain convolution instead of LU decomposed version')
parser.add_argument('--affine', default=False, type=_str2bool,
                    help='use affine coupling instead of additive')
parser.add_argument('--operator', type=str, default='adain',
                    help='style feature transfer operator')
args = parser.parse_args()

# Import the Glow class from the module matching the requested operator.
if args.operator == 'wct':
    from glow_wct import Glow
elif args.operator == 'adain':
    from glow_adain import Glow
elif args.operator == 'decorator':
    from glow_decorator import Glow
else:
    # Raising a bare tuple is itself a TypeError in Python 3, which hid the
    # intended message; raise a real exception that names the bad value.
    raise NotImplementedError(
        'Not implemented operator: {}'.format(args.operator))
device = torch.device('cuda')

# makedirs(exist_ok=True) creates missing parent directories (os.mkdir fails
# on nested paths such as "runs/exp1") and does not fail when the directory
# already exists, which also removes the check-then-create race.
os.makedirs(args.save_dir, exist_ok=True)

# The checkpoint path is the --resume value joined onto the save directory.
args.resume = os.path.join(args.save_dir, args.resume)

os.makedirs(args.log_dir, exist_ok=True)
# VGG: load the weights from --vgg, wrap the network in net.Net and
# replicate it with DataParallel on the training device.
vgg = net.vgg
vgg.load_state_dict(torch.load(args.vgg))
encoder = nn.DataParallel(net.Net(vgg))
encoder.to(device)

# Glow: a single (unwrapped) instance operating on 3-channel input.
glow_single = Glow(
    3,
    args.n_flow,
    args.n_block,
    affine=args.affine,
    conv_lu=not args.no_lu,
)

# Mean-squared-error criterion (this is an MSE loss, not an L1 loss).
mseloss = nn.MSELoss()
# -----------------------resume training------------------------
if args.resume:
    if not os.path.isfile(args.resume):
        print("--------no checkpoint found---------")
    else:
        print("--------loading checkpoint----------")
        print("=> loading checkpoint '{}'".format(args.resume))
        checkpoint = torch.load(args.resume)
        # Continue from the stored iteration with the stored Glow weights.
        # The optimizer state is restored further down, after the optimizer
        # has been constructed.
        args.start_iter = checkpoint['iter']
        glow_single.load_state_dict(checkpoint['state_dict'])

# Move the model to the device, then wrap it for multi-GPU execution.
glow_single = glow_single.to(device)
glow = nn.DataParallel(glow_single)
glow.train()
# -------------------------------------------------------------
def _batch_iterator(dataset):
    """Return an iterator over DataLoader batches of ``dataset``.

    Sampling goes through ``InfiniteSamplerWrapper``; batch size and worker
    count come from the parsed command-line options.
    """
    loader = data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        sampler=InfiniteSamplerWrapper(dataset),
        num_workers=args.n_threads,
    )
    return iter(loader)


# Content and style images go through the same kind of preprocessing.
content_tf = train_transform()
style_tf = train_transform()

content_dataset = FlatFolderDataset(args.content_dir, content_tf)
style_dataset = FlatFolderDataset(args.style_dir, style_tf)

content_iter = _batch_iterator(content_dataset)
style_iter = _batch_iterator(style_dataset)

# Optimise the parameters of the module wrapped by DataParallel.
optimizer = torch.optim.Adam(glow.module.parameters(), lr=args.lr)

# Restore the optimizer state when a checkpoint file is present.
if args.resume and os.path.isfile(args.resume):
    optimizer.load_state_dict(checkpoint['optimizer'])
# Per-interval loss logs; all three are emptied after each progress print.
log_c = []
log_s = []
log_mse = []
Time = time.time()

# -----------------------training------------------------
for i in range(args.start_iter, args.max_iter):
    adjust_learning_rate(optimizer, iteration_count=i)
    content_images = next(content_iter).to(device)
    style_images = next(style_iter).to(device)

    # The very first iteration only runs one gradient-free forward pass on
    # the unwrapped module and skips the update.
    # NOTE(review): presumably this triggers data-dependent initialisation
    # inside Glow -- confirm against the glow_* modules.
    if i == args.start_iter:
        with torch.no_grad():
            _ = glow.module(content_images, forward=True)
        continue

    # glow forward: content -> z_c, style -> z_s
    z_c = glow(content_images, forward=True)
    z_s = glow(style_images, forward=True)

    # reverse pass conditioned on the style code
    stylized = glow(z_c, forward=False, style=z_s)

    loss_c, loss_s = encoder(content_images, style_images, stylized)
    # Reduce the encoder's loss outputs to scalars before weighting.
    loss_c = loss_c.mean()
    loss_s = loss_s.mean()
    loss_mse = mseloss(content_images, stylized)
    loss_style = (args.content_weight * loss_c
                  + args.style_weight * loss_s
                  + args.mse_weight * loss_mse)

    # optimizer update
    optimizer.zero_grad()
    loss_style.backward()
    # clip_grad_norm_ is the in-place, non-deprecated spelling of the
    # original clip_grad_norm call; the clipping threshold is unchanged.
    nn.utils.clip_grad_norm_(glow.module.parameters(), 5)
    optimizer.step()

    # update loss log
    log_c.append(loss_c.item())
    log_s.append(loss_s.item())
    log_mse.append(loss_mse.item())

    # save a preview image and print averaged losses
    if i % args.print_interval == 0:
        with torch.no_grad():
            # stylized ---> z ---> content
            z_stylized = glow(stylized, forward=True)
            real = glow(z_stylized, forward=False, style=z_c)

            # pick another content
            another_content = next(content_iter).to(device)

            # stylized ---> z ---> another real
            z_ac = glow(another_content, forward=True)
            another_real = glow(z_stylized, forward=False, style=z_ac)

        output_name = os.path.join(args.save_dir, "%06d.jpg" % i)
        output_images = torch.cat(
            (content_images.cpu(), style_images.cpu(), stylized.cpu(),
             real.cpu(), another_content.cpu(), another_real.cpu()),
            0)
        save_image(output_images, output_name, nrow=args.batch_size)

        print("iter %d time/iter: %.2f loss_c: %.3f loss_s: %.3f loss_mse: %.3f" % (
            i,
            (time.time() - Time) / args.print_interval,
            np.mean(np.array(log_c)), np.mean(np.array(log_s)),
            np.mean(np.array(log_mse)),
        ))

        # Reset every log, including log_mse. The original left log_mse
        # untouched, so its printed mean covered the whole run and the list
        # grew without bound.
        log_c = []
        log_s = []
        log_mse = []
        Time = time.time()

    # Checkpoint periodically and at the final iteration.
    if (i + 1) % args.save_model_interval == 0 or (i + 1) == args.max_iter:
        state_dict = glow.module.state_dict()
        # Store CPU copies of the weights in the checkpoint.
        for key in state_dict.keys():
            state_dict[key] = state_dict[key].to(torch.device('cpu'))

        state = {'iter': i, 'state_dict': state_dict, 'optimizer': optimizer.state_dict()}
        torch.save(state, args.resume)