-
Notifications
You must be signed in to change notification settings - Fork 1
/
synthesize.py
147 lines (130 loc) · 6.13 KB
/
synthesize.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
# python3.7
"""A simple tool to synthesize images with pre-trained models."""
import os
import argparse
import subprocess
from tqdm import tqdm
import numpy as np
import torch
from models import MODEL_ZOO
from models import build_generator
from utils.misc import bool_parser
from utils.visualizer import HtmlPageVisualizer
from utils.visualizer import save_image
def postprocess(images):
    """Post-processes images from `torch.Tensor` to `numpy.ndarray`.

    Values are rescaled linearly so that -1 maps to 0 and 1 maps to 255. They
    are then rounded to the nearest integer, clipped to [0, 255] and cast to
    `uint8`. Finally axis 1 is moved to the last position (presumably
    NCHW -> NHWC; confirm against the generator's output layout).

    Args:
        images: Tensor of generated images (gradients are detached and the
            data is copied to host memory).

    Returns:
        A `uint8` numpy array with axes ordered (0, 2, 3, 1) of the input.
    """
    host_array = images.detach().cpu().numpy()
    scaled = (host_array + 1) * 255 / 2
    # Adding 0.5 before the truncating integer cast rounds to the nearest value.
    pixels = np.clip(scaled + 0.5, 0, 255).astype(np.uint8)
    return np.transpose(pixels, (0, 2, 3, 1))
def parse_args():
    """Parses command-line arguments for the synthesis script.

    Returns:
        `argparse.Namespace` holding the options declared below.

    Exits (via `parser.error`) when `--batch_size` is not positive.
    """
    parser = argparse.ArgumentParser(
        description='Synthesize images with pre-trained models.')
    parser.add_argument('model_name', type=str,
                        help='Name of the pre-trained model.')
    parser.add_argument('--save_dir', type=str, default=None,
                        help='Directory to save the results. If not specified, '
                             'the results will be saved to '
                             '`work_dirs/synthesis/` by default. '
                             '(default: %(default)s)')
    parser.add_argument('--num', type=int, default=100,
                        help='Number of samples to synthesize. '
                             '(default: %(default)s)')
    parser.add_argument('--batch_size', type=int, default=1,
                        help='Batch size. (default: %(default)s)')
    parser.add_argument('--generate_html', type=bool_parser, default=True,
                        help='Whether to use HTML page to visualize the '
                             'synthesized results. (default: %(default)s)')
    parser.add_argument('--save_raw_synthesis', type=bool_parser, default=False,
                        help='Whether to save raw synthesis. '
                             '(default: %(default)s)')
    parser.add_argument('--seed', type=int, default=0,
                        help='Seed for sampling. (default: %(default)s)')
    parser.add_argument('--trunc_psi', type=float, default=0.7,
                        help='Psi factor used for truncation. This is '
                             'particularly applicable to StyleGAN (v1/v2). '
                             '(default: %(default)s)')
    parser.add_argument('--trunc_layers', type=int, default=8,
                        help='Number of layers to perform truncation. This is '
                             'particularly applicable to StyleGAN (v1/v2). '
                             '(default: %(default)s)')
    parser.add_argument('--randomize_noise', type=bool_parser, default=False,
                        help='Whether to randomize the layer-wise noise. This '
                             'is particularly applicable to StyleGAN (v1/v2). '
                             '(default: %(default)s)')
    args = parser.parse_args()
    # `batch_size` is used as the step of `range()` when batching samples:
    # zero raises an opaque ValueError there, and a negative step yields an
    # empty range so nothing would be synthesized. Reject both up front.
    if args.batch_size <= 0:
        parser.error(f'--batch_size must be a positive integer, '
                     f'got {args.batch_size}.')
    return args
def _download_checkpoint(url, checkpoint_path):
    """Downloads `url` to `checkpoint_path` with the external `wget` tool.

    The data is written to a temporary `<checkpoint_path>.part` file and renamed
    into place only after `wget` exits successfully. A failed or interrupted
    download therefore never leaves a file at `checkpoint_path` that a later
    run would mistake for a valid cached checkpoint.

    Raises:
        SystemExit: If `wget` cannot be found or exits with a non-zero status.
    """
    print(f' Downloading checkpoint from `{url}` ...')
    partial_path = checkpoint_path + '.part'
    try:
        ret = subprocess.call(['wget', '--quiet', '-O', partial_path, url])
    except FileNotFoundError:
        raise SystemExit('`wget` is required to download checkpoints but was '
                         'not found!') from None
    if ret != 0:
        # `wget -O` may leave an empty or partial output file behind on failure.
        if os.path.exists(partial_path):
            os.remove(partial_path)
        raise SystemExit(f'Failed to download checkpoint from `{url}` '
                         f'(wget exit code: {ret})!')
    os.replace(partial_path, checkpoint_path)
    print(' Finish downloading checkpoint.')


def _load_generator_weights(generator, checkpoint_path):
    """Loads weights from `checkpoint_path` into `generator`.

    Weights stored under `generator_smooth` are preferred; otherwise those under
    `generator` are used.

    Raises:
        SystemExit: If the checkpoint contains neither key.
    """
    # NOTE(review): `torch.load` unpickles the file, which can execute arbitrary
    # code. Only load checkpoints from trusted sources (here: MODEL_ZOO URLs).
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    for key in ('generator_smooth', 'generator'):
        if key in checkpoint:
            generator.load_state_dict(checkpoint[key])
            return
    raise SystemExit(f'Checkpoint `{checkpoint_path}` contains neither '
                     f'`generator_smooth` nor `generator` weights!')


def main():
    """Synthesizes images with a pre-trained generator registered in MODEL_ZOO.

    Results go to `--save_dir`, or `work_dirs/synthesis` when that is unset:
    raw images as `<job_name>/<index:06d>.jpg` when `--save_raw_synthesis` is
    enabled, and an HTML grid `<job_name>.html` when `--generate_html` is
    enabled. Returns immediately if `--num` is non-positive or both outputs are
    disabled.

    Raises:
        SystemExit: If the model is not registered in MODEL_ZOO, or the
            checkpoint cannot be downloaded or lacks generator weights.
    """
    args = parse_args()
    if args.num <= 0:
        return
    if not args.save_raw_synthesis and not args.generate_html:
        return

    # Parse model configuration.
    if args.model_name not in MODEL_ZOO:
        raise SystemExit(f'Model `{args.model_name}` is not registered in '
                         f'`models/model_zoo.py`!')
    model_config = MODEL_ZOO[args.model_name].copy()
    url = model_config.pop('url')  # URL to download model if needed.

    # Get work directory and job name.
    work_dir = args.save_dir or os.path.join('work_dirs', 'synthesis')
    os.makedirs(work_dir, exist_ok=True)
    job_name = f'{args.model_name}_{args.num}'
    if args.save_raw_synthesis:
        os.makedirs(os.path.join(work_dir, job_name), exist_ok=True)

    # Build generator and get synthesis kwargs.
    print(f'Building generator for model `{args.model_name}` ...')
    generator = build_generator(**model_config)
    synthesis_kwargs = dict(trunc_psi=args.trunc_psi,
                            trunc_layers=args.trunc_layers,
                            randomize_noise=args.randomize_noise)
    print('Finish building generator.')

    # Load pre-trained weights, downloading them first if not cached locally.
    os.makedirs('checkpoints', exist_ok=True)
    checkpoint_path = os.path.join('checkpoints', args.model_name + '.pth')
    print(f'Loading checkpoint from `{checkpoint_path}` ...')
    if not os.path.exists(checkpoint_path):
        _download_checkpoint(url, checkpoint_path)
    _load_generator_weights(generator, checkpoint_path)
    generator = generator.cuda()
    generator.eval()
    print('Finish loading checkpoint.')

    # Set random seed.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # Sample and synthesize.
    print(f'Synthesizing {args.num} samples ...')
    indices = list(range(args.num))
    if args.generate_html:
        html = HtmlPageVisualizer(grid_size=args.num)
    for batch_idx in tqdm(range(0, args.num, args.batch_size)):
        sub_indices = indices[batch_idx:batch_idx + args.batch_size]
        # Codes are sampled with the (seeded) default RNG, then moved to the GPU.
        code = torch.randn(len(sub_indices), generator.z_space_dim).cuda()
        with torch.no_grad():
            images = generator(code, **synthesis_kwargs)['image']
        images = postprocess(images)
        for sub_idx, image in zip(sub_indices, images):
            if args.save_raw_synthesis:
                save_path = os.path.join(
                    work_dir, job_name, f'{sub_idx:06d}.jpg')
                save_image(save_path, image)
            if args.generate_html:
                row_idx, col_idx = divmod(sub_idx, html.num_cols)
                html.set_cell(row_idx, col_idx, image=image,
                              text=f'Sample {sub_idx:06d}')
    if args.generate_html:
        html.save(os.path.join(work_dir, f'{job_name}.html'))
    print(f'Finish synthesizing {args.num} samples.')
if __name__ == '__main__':
    # Run the synthesis pipeline only when executed as a script, not on import.
    main()