
Commit

fix bug & add offset noise
柏灌 committed Oct 18, 2023
1 parent f108d3c commit f1583f0
Showing 3 changed files with 20 additions and 10 deletions.
8 changes: 7 additions & 1 deletion myutils/vaehook.py
@@ -66,6 +66,7 @@
 import torch.version
 import torch.nn.functional as F
 from einops import rearrange
+from diffusers.utils.import_utils import is_xformers_available
 
 import myutils.devices as devices
 #from modules.shared import state
@@ -362,7 +363,12 @@ def attn2task(task_queue, net):
     else:
         task_queue.append(('store_res', lambda x: x))
         task_queue.append(('pre_norm', net.group_norm))
-        task_queue.append(('attn', lambda x, net=net: attn_forward_new_xformers(net, x)))
+        if is_xformers_available():
+            task_queue.append(('attn', lambda x, net=net: attn_forward_new_xformers(net, x)))
+        elif hasattr(F, "scaled_dot_product_attention"):
+            task_queue.append(('attn', lambda x, net=net: attn_forward_new_pt2_0(net, x)))
+        else:
+            task_queue.append(('attn', lambda x, net=net: attn_forward_new(net, x)))
         task_queue.append(['add_res', None])
 
 def resblock2task(queue, block):
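
For reference, here is a self-contained sketch of the fallback order the hunk above encodes for the VAE attention task: xformers first, then PyTorch 2.0's scaled_dot_product_attention, then plain softmax attention. The helper name sdp_attention and the (batch, heads, seq_len, head_dim) layout are illustrative assumptions, not code from vaehook.py; only the selection logic mirrors the diff.

import torch
import torch.nn.functional as F
from diffusers.utils.import_utils import is_xformers_available

def sdp_attention(q, k, v):
    """q, k, v: (batch, heads, seq_len, head_dim). Illustrative helper, not part of the repo."""
    if is_xformers_available():
        import xformers.ops as xops
        # xformers expects (batch, seq_len, heads, head_dim), so transpose in and out.
        out = xops.memory_efficient_attention(
            q.transpose(1, 2).contiguous(),
            k.transpose(1, 2).contiguous(),
            v.transpose(1, 2).contiguous(),
        )
        return out.transpose(1, 2)
    if hasattr(F, "scaled_dot_product_attention"):  # PyTorch >= 2.0
        return F.scaled_dot_product_attention(q, k, v)
    # Plain softmax attention as the last resort.
    scale = q.shape[-1] ** -0.5
    attn = torch.softmax(q @ k.transpose(-2, -1) * scale, dim=-1)
    return attn @ v
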
11 changes: 6 additions & 5 deletions pipelines/pipeline_pasd.py
@@ -699,7 +699,7 @@ def prepare_image(
         return image
 
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
-    def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
+    def prepare_latents(self, args, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
         shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
         if isinstance(generator, list) and len(generator) != batch_size:
             raise ValueError(
@@ -709,9 +709,9 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype
 
         if latents is None:
             latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
-            #latents = randn_tensor(shape, generator=None, device=device, dtype=dtype)
-            #offset_noise = torch.randn(batch_size, num_channels_latents, 1, 1, device=device)
-            #latents = latents + 0.1 * offset_noise
+            offset_noise = torch.randn(batch_size, num_channels_latents, 1, 1, device=device).to(dtype)
+            offset_noise_scale = args.offset_noise_scale if args is not None else 0.1
+            latents = latents + offset_noise_scale * offset_noise
         else:
             latents = latents.to(device)
 
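
The three uncommented lines in the hunk above are the "add offset noise" half of this commit: on top of the usual i.i.d. Gaussian latents, one extra noise value is drawn per (sample, channel) and broadcast over height and width, scaled by args.offset_noise_scale (default 0.1). A minimal standalone sketch with an illustrative function name:

import torch

def latents_with_offset_noise(batch, channels, h, w, scale=0.1,
                              device="cpu", dtype=torch.float32, generator=None):
    # Standard i.i.d. Gaussian latents, as randn_tensor() produces in prepare_latents().
    latents = torch.randn(batch, channels, h, w, generator=generator,
                          device=device, dtype=dtype)
    # One noise value per (sample, channel), broadcast over H and W;
    # this shifts each channel's mean so sampling is not anchored to zero-mean noise.
    offset = torch.randn(batch, channels, 1, 1, device=device).to(dtype)
    return latents + scale * offset
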
@@ -961,6 +961,7 @@ def __call__(
         # 6. Prepare latent variables
         num_channels_latents = self.unet.config.in_channels
         latents = self.prepare_latents(
+            args,
             batch_size * num_images_per_prompt,
             num_channels_latents,
             height,
@@ -992,7 +993,7 @@ def __call__(
                     controlnet_prompt_embeds = prompt_embeds
 
                 _, _, h, w = latent_model_input.size()
-                tile_size, tile_overlap = args.latent_tiled_size, args.latent_tiled_overlap if args is not None else 256, 8
+                tile_size, tile_overlap = (args.latent_tiled_size, args.latent_tiled_overlap) if args is not None else (256, 8)
                 if h*w<=tile_size*tile_size: #h<tile_size and w<tile_size: # tiled latent input
                     down_block_res_samples, mid_block_res_sample = [None]*10, None
                     rgbs, down_block_res_samples, mid_block_res_sample = self.controlnet(
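
The one-line change in the last hunk above is the "fix bug" half of the commit: without parentheses, the conditional expression binds only to the middle element, so the right-hand side is a 3-tuple and unpacking it into two names raises ValueError. A tiny illustration with a placeholder args:

# Buggy shape, schematically:
#   tile_size, tile_overlap = a, (b if cond else 256), 8   -> 3 values, 2 targets -> ValueError
args = None  # pretend no argparse namespace was passed in

# Fixed form from the diff: parenthesize both alternatives so a single 2-tuple is chosen.
tile_size, tile_overlap = (args.latent_tiled_size, args.latent_tiled_overlap) if args is not None else (256, 8)
print(tile_size, tile_overlap)  # prints: 256 8
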
11 changes: 7 additions & 4 deletions test_pasd.py
@@ -183,6 +183,7 @@ def main(args, enable_xformers_memory_efficient_attention=True,):
 
     for image_name in image_names[:]:
        validation_image = Image.open(image_name).convert("RGB")
+        #validation_image = Image.new(mode='RGB', size=validation_image.size, color=(0,0,0))
         if args.control_type == "realisr":
             validation_prompt = get_validation_prompt(args, validation_image, model, preprocess, category)
             validation_prompt += args.added_prompt # clean, extremely detailed, best quality, sharp, clean
@@ -210,14 +211,15 @@ def main(args, enable_xformers_memory_efficient_attention=True,):
         #width, height = validation_image.size
         resize_flag = True #
 
-        try:
+        #try:
+        if True:
             image = pipeline(
                         args, validation_prompt, validation_image, num_inference_steps=args.num_inference_steps, generator=generator, #height=height, width=width,
                         guidance_scale=args.guidance_scale, negative_prompt=negative_prompt, conditioning_scale=args.conditioning_scale,
                     ).images[0]
-        except Exception as e:
-            print(e)
-            continue
+        #except Exception as e:
+        #    print(e)
+        #    continue
 
         if True: #args.conditioning_scale < 1.0:
             image = wavelet_color_fix(image, validation_image)
@@ -261,6 +263,7 @@ def main(args, enable_xformers_memory_efficient_attention=True,):
     parser.add_argument("--vae_tiled_size", type=int, default=224) # for 24G
     parser.add_argument("--latent_tiled_size", type=int, default=320) # for 24G
    parser.add_argument("--latent_tiled_overlap", type=int, default=8) # for 24G
+    parser.add_argument("--offset_noise_scale", type=float, default=0.1)
     parser.add_argument("--upscale", type=int, default=4)
     parser.add_argument("--use_personalized_model", action="store_true")
     parser.add_argument("--use_pasd_light", action="store_true")
