-
Notifications
You must be signed in to change notification settings - Fork 0
/
generate_distilled_lcm_hpsprompt.py
83 lines (66 loc) · 3.86 KB
/
generate_distilled_lcm_hpsprompt.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
# ------------------------------------------------------------------------------------
# Copyright 2023–2024 Nota Inc. All Rights Reserved.
# ------------------------------------------------------------------------------------
import os
import argparse
import time
from utils.inference_pipeline import InferencePipeline
from utils.misc import get_file_list_from_csv, change_img_size
from diffusers import LCMScheduler
import hpsv2
def parse_args() -> argparse.Namespace:
    """Define the script's command-line options and parse ``sys.argv``.

    Options select the base model (``--model_id``), where images are written
    (``--save_dir``), sampling settings (steps, seed, image sizes, batch size,
    device), an optional separately trained UNet (``--unet_path``) and the
    scheduler type (``--scheduler``). ``--num_images`` and ``--data_list`` are
    accepted but not read by this script's ``__main__`` body.

    Returns:
        An ``argparse.Namespace`` with one attribute per option.
    """
    cli = argparse.ArgumentParser()

    # (flag, add_argument keyword arguments), registered in this order so the
    # --help listing keeps the same sequence.
    option_specs = [
        ("--model_id", dict(type=str, default="lykon/absolutereality")),
        ("--save_dir", dict(type=str, default="./inference_results/photo",
                            help="$save_dir/{im256, im512} are created for saving 256x256 and 512x512 images")),
        ("--data_list", dict(type=str, default="./data/mscoco_val2014_30k/metadata.csv")),
        ("--num_images", dict(type=int, default=1)),
        ("--num_inference_steps", dict(type=int, default=4)),
        ("--device", dict(type=str, default="cuda:0",
                          help="Device to use, cuda:gpu_number or cpu")),
        ("--seed", dict(type=int, default=1234)),
        ("--img_sz", dict(type=int, default=512)),
        ("--img_resz", dict(type=int, default=256)),
        ("--batch_sz", dict(type=int, default=1)),
        ("--unet_path", dict(type=str, default="./results/absreality-lcm",
                             required=False, help="path to the unet model")),
        ("--scheduler", dict(type=str, default="lcm",
                             required=False, help="type of scheduler")),
    ]
    for flag, options in option_specs:
        cli.add_argument(flag, **options)

    return cli.parse_args()
if __name__ == "__main__":
    args = parse_args()

    # Base text-to-image pipeline; seed and device come from the CLI.
    pipeline = InferencePipeline(weight_folder=args.model_id,
                                 seed=args.seed,
                                 device=args.device)
    pipeline.set_pipe_and_generator()

    # Optionally replace the pipeline's UNet with a separately trained one,
    # loaded from the 'unet' subfolder and cast to half precision.
    # Passing --unet_path "" keeps the pipeline's own UNet.
    if args.unet_path:
        from diffusers import UNet2DConditionModel
        unet = UNet2DConditionModel.from_pretrained(args.unet_path, subfolder='unet')
        pipeline.pipe.unet = unet.half().to(args.device)
        print(f"** load unet from {args.unet_path}")

    # Only 'lcm' is acted on; any other value keeps the pipeline's existing scheduler.
    if args.scheduler == 'lcm':
        pipeline.pipe.scheduler = LCMScheduler.from_config(pipeline.pipe.scheduler.config)

    save_dir_src = os.path.join(args.save_dir, f'im{args.img_sz}')  # model's raw output images
    os.makedirs(save_dir_src, exist_ok=True)
    save_dir_tgt = os.path.join(args.save_dir, f'im{args.img_resz}')  # resized copies of those images
    os.makedirs(save_dir_tgt, exist_ok=True)

    # Prompts come from the HPSv2 'photo' benchmark; --data_list is not used here.
    file_list = hpsv2.benchmark_prompts('photo')
    params_str = pipeline.get_sdm_params()

    t0 = time.perf_counter()
    for batch_start in range(0, len(file_list), args.batch_sz):
        val_prompts = file_list[batch_start: batch_start + args.batch_sz]
        imgs = pipeline.generate(prompt=val_prompts,
                                 n_steps=args.num_inference_steps,
                                 img_sz=args.img_sz)
        for i, (img, val_prompt) in enumerate(zip(imgs, val_prompts)):
            # Name each image by its global prompt index. Using batch_start alone
            # made every image of a batch overwrite the same file when batch_sz > 1.
            prompt_idx = batch_start + i
            img_name = f"{prompt_idx:05d}.jpg"
            img.save(os.path.join(save_dir_src, img_name))
            img.close()
            print(f"{prompt_idx}/{len(file_list)} | {img_name} {val_prompt}")
        print(f"---{params_str}")

    pipeline.clear()

    # Elapsed time covers generation plus resizing.
    change_img_size(save_dir_src, save_dir_tgt, args.img_resz)
    print(f"{(time.perf_counter()-t0):.2f} sec elapsed")

    # NOTE(review): the evaluation root is hard-coded and does not follow --save_dir,
    # while images are written under <save_dir>/im<img_sz>/. Confirm this matches the
    # directory layout hpsv2.evaluate expects.
    hpsv2.evaluate("inference_results", hps_version="v2.1")