# NOTE(review): this file appears to be corrupted and is not valid Python as-is.
#   1. The whole script has been collapsed onto a single physical line, so no
#      statement, block or indentation boundaries survive.
#   2. Text is missing between `if h` (inside test_transform) and
#      `loaded checkpoint '{}'`. Presumably a span delimited by `<` ... `>`
#      (e.g. `if h < w:` ... `print("=> loaded checkpoint ...`) was stripped
#      as if it were an HTML tag -- TODO confirm against the original source
#      and restore it. The remainder of test_transform, the argument parsing,
#      and the definitions of device, glow, content_paths, style_paths and
#      output_dir are not visible and presumably lived in that span.
#   3. Because of the gap, the double-quote characters that follow no longer
#      pair up, leaving an unterminated string literal.
#
# What the visible fragments show (image stylization inference script):
#   - test_transform(img, size): reads (h, w, _) from np.shape(img) and starts
#     building transform_list; the rest of its body is lost in the gap. The
#     loop below calls its return value on a PIL image, so it presumably
#     returns a torchvision transform (e.g. transforms.Compose) -- TODO confirm.
#   - A checkpoint named by args.decoder is reported as loaded, or
#     "--------no checkpoint found---------" is printed; then the model `glow`
#     is moved to `device` and switched to eval mode.
#   - For every (content_path, style_path) pair, under torch.no_grad():
#       * each image is opened, converted to RGB, passed through
#         test_transform(..., args.size), moved to `device`, and given a
#         leading dimension with unsqueeze(0);
#       * glow(x, forward=True) is applied to content and style (z_c, z_s);
#       * glow(z_c, forward=False, style=z_s) yields the stylized output;
#       * the output is moved to CPU, its path is printed, and it is written
#         with save_image to
#         output_dir / '<content stem>_stylized_<style stem><args.save_ext>'.
#
# NOTE(review): once the file is restored, the content image is re-opened,
#   re-transformed and re-encoded (z_c) once per style image; hoisting that
#   work out of the inner style loop would avoid redundant forward passes.
import argparse import os import torch import torch.nn as nn import torch.nn.functional as F from PIL import Image from os.path import basename from os.path import splitext from torchvision import transforms from torchvision.utils import save_image from pathlib import Path import time import numpy as np import random def test_transform(img, size): transform_list = [] h, w, _ = np.shape(img) if h loaded checkpoint '{}'".format(args.decoder)) else: print("--------no checkpoint found---------") glow = glow.to(device) glow.eval() # -----------------------start------------------------ for content_path in content_paths: for style_path in style_paths: with torch.no_grad(): content = Image.open(str(content_path)).convert('RGB') img_transform = test_transform(content, args.size) content = img_transform(content) content = content.to(device).unsqueeze(0) style = Image.open(str(style_path)).convert('RGB') img_transform = test_transform(style, args.size) style = img_transform(style) style = style.to(device).unsqueeze(0) # content/style ---> z ---> stylized z_c = glow(content, forward=True) z_s = glow(style, forward=True) output = glow(z_c, forward=False, style=z_s) output = output.cpu() output_name = output_dir / '{:s}_stylized_{:s}{:s}'.format( content_path.stem, style_path.stem, args.save_ext) print(output_name) save_image(output, str(output_name))