-
Notifications
You must be signed in to change notification settings - Fork 0
/
run.py
58 lines (38 loc) · 2.95 KB
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import argparse
import os
from src.app import utils
from src.config import settings
from src.modules import run_training, image_generate, video_generate
def _load_generation_params(params_path, args):
    """Load generation params, apply CLI overrides, and load the matching training params.

    Shared by the image and video branches of ``main``.

    Args:
        params_path: Path of the generation params file handed to ``utils.get_params``.
        args: Parsed CLI namespace; ``version`` and ``checkpoint_epoch`` are read.

    Returns:
        A ``(generation_params, training_params)`` tuple. ``generation_params``
        has ``'train_version'`` and ``'checkpoint_epoch'`` overridden by the CLI
        values when those are truthy. ``training_params`` is loaded from
        ``settings.PATH_DATA/<train_version>/<basename of PATH_TRAIN_PARAMS>``.
    """
    generation_params = utils.get_params(params_path)
    # A falsy CLI value (None / empty string) keeps whatever the params file says.
    generation_params['train_version'] = args.version or generation_params['train_version']
    generation_params['checkpoint_epoch'] = args.checkpoint_epoch or generation_params['checkpoint_epoch']
    training_params = utils.get_params(
        os.path.join(
            settings.PATH_DATA,
            generation_params['train_version'],
            os.path.basename(settings.PATH_TRAIN_PARAMS),
        )
    )
    return generation_params, training_params


def main(args):
    """Run the actions selected on the command line: training, then images, then videos.

    Each action runs only if its flag is set; several can be combined in one call.

    Args:
        args: Namespace produced by this script's argument parser. The fields
            read are ``training``, ``image``, ``video``, ``upscale``,
            ``version`` and ``checkpoint_epoch``.
    """
    if args.training:
        if args.version:
            # Resume/inspect a specific version: read the params stored with that version.
            # NOTE(review): this uses settings.JSON_TRAIN_PARAMS_FILENAME, whereas the
            # image/video paths use os.path.basename(settings.PATH_TRAIN_PARAMS) —
            # confirm both name the same file.
            training_params = utils.get_params(
                os.path.join(settings.PATH_DATA, args.version, settings.JSON_TRAIN_PARAMS_FILENAME)
            )
            print(f'--Use {args.version} training params')
        else:
            training_params = utils.get_params(settings.PATH_TRAIN_PARAMS)
            print('--Use main training params')
        training_params['train_version'] = args.version or training_params['train_version']
        path_dataset = os.path.join(settings.PATH_DATASET, training_params['path_dataset'])
        run_training.main(training_params, settings.PATH_DATA, path_dataset, settings.PATH_TRAIN_PARAMS)

    if args.image:
        image_params, training_params = _load_generation_params(settings.PATH_IMAGE_PARAMS, args)
        image_generate.main(
            training_params,
            image_params,
            settings.PATH_DATA,
            settings.PATH_IMAGES_GENERATED,
            upscale_width=args.upscale,
        )

    if args.video:
        video_params, training_params = _load_generation_params(settings.PATH_VIDEO_PARAMS, args)
        video_generate.main(
            training_params,
            video_params,
            settings.PATH_DATA,
            settings.PATH_VIDEOS_GENERATED,
            upscale_width=args.upscale,
        )
def build_parser():
    """Build the command-line parser for this script.

    Returns:
        An ``argparse.ArgumentParser``. It has boolean action flags
        (``--training``, ``--image``, ``--video``) and optional overrides
        (``--upscale`` as int, ``--version`` and ``--checkpoint-epoch`` as str).
        All overrides default to ``None``.
    """
    parser = argparse.ArgumentParser(
        description="Script to train the generator, validate the training, generate images and/or videos from a trained generator"
    )
    parser.add_argument('--training', action='store_true', help='If true, executes the training')
    parser.add_argument('--image', action='store_true', help='If true, generates images')
    parser.add_argument('--video', action='store_true', help='If true, generates videos')
    parser.add_argument('--upscale', type=int, default=None, help='Sets the upscale width. Can be None or an integer value.')
    parser.add_argument('--version', type=str, default=None, help='Sets the version of training. Can be None or string value.')
    # argparse exposes '--checkpoint-epoch' as the attribute 'checkpoint_epoch'.
    parser.add_argument('--checkpoint-epoch', type=str, default=None, help='Sets the checkpoint epoch file of training. Can be None or string value.')
    return parser


if __name__ == '__main__':
    main(build_parser().parse_args())