Skip to content

Commit

Permalink
initialize latent with lr image
Browse files Browse the repository at this point in the history
  • Loading branch information
柏灌 committed Oct 18, 2023
1 parent fec4f26 commit 630e99a
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 21 deletions.
4 changes: 3 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,13 @@ _<sup>2</sup>[Department of Computing, The Hong Kong Polytechnic University](htt
<img src="samples/000004x2.gif" width="390px"/> <img src="samples/000080x2.gif" width="390px"/>

## News
(2023-10-18) Completely resolved [issue #16](https://github.com/yangxy/PASD/issues/16) by initializing the latents from the input LR image. Interestingly, the SR results have also become much more stable.

(2023-10-11) [Colab demo](https://colab.research.google.com/drive/1lZ_-rSGcmreLCiRniVT973x6JLjFiC-b?usp=sharing) is now available. Credits to [Masahide Okada](https://github.com/MasahideOkada).

(2023-10-09) Add training dataset.

(2023-09-28) Add tiled latent to allow upscaling ultra high-resolution images. Please carefully set ```tiled_size``` in ```pipelines/pipeline_pasd.py``` as well as ```--vae_tiled_size``` when upscaling large images.
(2023-09-28) Add tiled latent to allow upscaling ultra high-resolution images. Please carefully set ```--latent_tiled_size``` as well as ```--decoder_tiled_size``` when upscaling large images.

(2023-09-12) Add Gradio demo.

Expand Down
30 changes: 22 additions & 8 deletions pipelines/pipeline_pasd.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ def __init__(
self.register_to_config(requires_safety_checker=requires_safety_checker)

def _init_tiled_vae(self,
encoder_tile_size = 256,
encoder_tile_size = 1024,
decoder_tile_size = 256,
fast_decoder = False,
fast_encoder = False,
Expand Down Expand Up @@ -698,21 +698,34 @@ def prepare_image(
return image

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(self, args, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
def prepare_latents(self, args, image, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)

if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
offset_noise = torch.randn(batch_size, num_channels_latents, 1, 1, device=device).to(dtype)
offset_noise_scale = args.offset_noise_scale if args is not None else 0.05
latents = latents + offset_noise_scale * offset_noise
if args is None or args.init_latent_with_noise:
if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
offset_noise = torch.randn(batch_size, num_channels_latents, 1, 1, device=device).to(dtype)
offset_noise_scale = args.offset_noise_scale if args is not None else 0.0
latents = latents + offset_noise_scale * offset_noise
else:
latents = latents.to(device)
else:
latents = latents.to(device)
#print(image.shape, image.min(), image.max())
if dtype==torch.float16: self.vae.quant_conv.half()
init_latents = self.vae.encode(image*2.0-1.0).latent_dist.sample(generator)
init_latents = self.vae.config.scaling_factor * init_latents
self.scheduler.set_timesteps(args.num_inference_steps, device=device)
timesteps = self.scheduler.timesteps[0:]
latent_timestep = timesteps[:1].repeat(batch_size * 1)
shape = init_latents.shape
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
init_latents = self.scheduler.add_noise(init_latents, noise, latent_timestep)
latents = init_latents

# scale the initial noise by the standard deviation required by the scheduler
latents = latents * self.scheduler.init_noise_sigma
Expand Down Expand Up @@ -962,6 +975,7 @@ def __call__(
num_channels_latents = self.unet.config.in_channels
latents = self.prepare_latents(
args,
image[:1],
batch_size * num_images_per_prompt,
num_channels_latents,
height,
Expand Down
25 changes: 13 additions & 12 deletions test_pasd.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from diffusers import AutoencoderKL, PNDMScheduler, UniPCMultistepScheduler#, StableDiffusionControlNetPipeline
from diffusers import AutoencoderKL, PNDMScheduler, UniPCMultistepScheduler, DPMSolverMultistepScheduler#, StableDiffusionControlNetPipeline
from diffusers.utils import check_min_version
from diffusers.utils.import_utils import is_xformers_available
from transformers import CLIPTextModel, CLIPTokenizer, CLIPImageProcessor
Expand Down Expand Up @@ -88,7 +88,7 @@ def load_pasd_pipeline(args, accelerator, enable_xformers_memory_efficient_atten
unet=unet, controlnet=controlnet, scheduler=scheduler, safety_checker=None, requires_safety_checker=False,
)
#validation_pipeline.enable_vae_tiling()
validation_pipeline._init_tiled_vae(decoder_tile_size=args.vae_tiled_size)
validation_pipeline._init_tiled_vae(encoder_tile_size=args.encoder_tiled_size, decoder_tile_size=args.decoder_tiled_size)

return validation_pipeline

Expand Down Expand Up @@ -211,15 +211,14 @@ def main(args, enable_xformers_memory_efficient_attention=True,):
#width, height = validation_image.size
resize_flag = True #

#try:
if True:
try:
image = pipeline(
args, validation_prompt, validation_image, num_inference_steps=args.num_inference_steps, generator=generator, #height=height, width=width,
guidance_scale=args.guidance_scale, negative_prompt=negative_prompt, conditioning_scale=args.conditioning_scale,
).images[0]
#except Exception as e:
# print(e)
# continue
except Exception as e:
print(e)
continue

if True: #args.conditioning_scale < 1.0:
image = wavelet_color_fix(image, validation_image)
Expand Down Expand Up @@ -250,23 +249,25 @@ def main(args, enable_xformers_memory_efficient_attention=True,):
parser.add_argument('--high_level_info', choices=['classification', 'detection', 'caption'], nargs='?', default='')
parser.add_argument("--prompt", type=str, default="")
parser.add_argument("--added_prompt", type=str, default="clean, high-resolution, 8k")
parser.add_argument("--negative_prompt", type=str, default="raster lines, dotted, noise, blurry, unclear, lowres, over-smoothed")
parser.add_argument("--negative_prompt", type=str, default="blurry, dotted, noise, raster lines, unclear, lowres, over-smoothed")
parser.add_argument("--image_path", type=str, default="examples/RealSRSet")
parser.add_argument("--output_dir", type=str, default="output")
parser.add_argument("--mixed_precision", type=str, default="fp16") # no/fp16/bf16
parser.add_argument("--guidance_scale", type=float, default=7.5)
parser.add_argument("--conditioning_scale", type=float, default=1.0)
parser.add_argument("--blending_alpha", type=float, default=1.0)
parser.add_argument("--multiplier", type=float, default=0.6)
parser.add_argument("--num_inference_steps", type=int, default=16)
parser.add_argument("--process_size", type=int, default=768)
parser.add_argument("--vae_tiled_size", type=int, default=224) # for 24G
parser.add_argument("--num_inference_steps", type=int, default=20)
parser.add_argument("--process_size", type=int, default=768) # 512?
parser.add_argument("--decoder_tiled_size", type=int, default=224) # for 24G
parser.add_argument("--encoder_tiled_size", type=int, default=1024) # for 24G
parser.add_argument("--latent_tiled_size", type=int, default=320) # for 24G
parser.add_argument("--latent_tiled_overlap", type=int, default=8) # for 24G
parser.add_argument("--offset_noise_scale", type=float, default=0.05)
parser.add_argument("--upscale", type=int, default=4)
parser.add_argument("--use_personalized_model", action="store_true")
parser.add_argument("--use_pasd_light", action="store_true")
parser.add_argument("--init_latent_with_noise", action="store_true")
parser.add_argument("--offset_noise_scale", type=float, default=0.0)
parser.add_argument("--seed", type=int, default=None)
args = parser.parse_args()
main(args)

0 comments on commit 630e99a

Please sign in to comment.