Skip to content

Commit

Permalink
support lcm-lora
Browse files Browse the repository at this point in the history
  • Loading branch information
杨涛 committed Jan 16, 2024
1 parent 6b76ecf commit 2cd3769
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 31 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ _<sup>2</sup>[Department of Computing, The Hong Kong Polytechnic University](htt
<img src="samples/000004x2.gif" width="390px"/> <img src="samples/000080x2.gif" width="390px"/>

## News
(2024-1-16) LCM-LoRA is now supported. PASD can upscale an image in 2-4 steps. Try `python test_pasd.py --num_inference_steps 3 --use_lcm_lora`. You should download the LCM-LoRA model from [lcm-lora-sdv1-5](https://huggingface.co/latent-consistency/lcm-lora-sdv1-5) and put it in the `checkpoints` folder.

(2024-1-16) You may also want to check out our new works [SeeSR](https://github.com/cswry/seesr) and [Phantom](https://github.com/dreamoving/Phantom).

(2023-10-20) Added an additional noise level via ```--added_noise_level```; the SR result achieves a great balance between "extremely detailed" and "over-smoothed". Very interesting! You can freely control the SR's level of detail.
Expand Down
4 changes: 2 additions & 2 deletions pipelines/pipeline_pasd.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer

from diffusers.image_processor import VaeImageProcessor
from diffusers.loaders import TextualInversionLoaderMixin
from diffusers.loaders import TextualInversionLoaderMixin, LoraLoaderMixin
from diffusers.models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.utils import (
Expand Down Expand Up @@ -94,7 +94,7 @@
```
"""

class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
r"""
Pipeline for text-to-image generation using Stable Diffusion with ControlNet guidance.
Expand Down
66 changes: 37 additions & 29 deletions test_pasd.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from diffusers import AutoencoderKL, PNDMScheduler, UniPCMultistepScheduler, DPMSolverMultistepScheduler#, StableDiffusionControlNetPipeline
from diffusers import AutoencoderKL, LCMScheduler, PNDMScheduler, UniPCMultistepScheduler, DPMSolverMultistepScheduler#, StableDiffusionControlNetPipeline
from diffusers.utils import check_min_version
from diffusers.utils.import_utils import is_xformers_available
from transformers import CLIPTextModel, CLIPTokenizer, CLIPImageProcessor
Expand Down Expand Up @@ -90,6 +90,12 @@ def load_pasd_pipeline(args, accelerator, enable_xformers_memory_efficient_atten
#validation_pipeline.enable_vae_tiling()
validation_pipeline._init_tiled_vae(encoder_tile_size=args.encoder_tiled_size, decoder_tile_size=args.decoder_tiled_size)

if args.use_lcm_lora:
# load and fuse lcm lora
validation_pipeline.load_lora_weights(args.lcm_lora_path)
validation_pipeline.fuse_lora()
validation_pipeline.scheduler = LCMScheduler.from_config(validation_pipeline.scheduler.config)

return validation_pipeline

def load_high_level_net(args, device='cuda'):
Expand Down Expand Up @@ -242,33 +248,35 @@ def main(args, enable_xformers_memory_efficient_attention=True,):

if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--pretrained_model_path", type=str, default="checkpoints/stable-diffusion-v1-5")
parser.add_argument("--pasd_model_path", type=str, default="runs/pasd/checkpoint-100000")
parser.add_argument("--personalized_model_path", type=str, default="majicmixRealistic_v6.safetensors") # toonyou_beta3.safetensors, majicmixRealistic_v6.safetensors, unet_disney
parser.add_argument("--control_type", choices=['realisr', 'grayscale'], nargs='?', default="realisr")
parser.add_argument('--high_level_info', choices=['classification', 'detection', 'caption'], nargs='?', default='')
parser.add_argument("--prompt", type=str, default="")
parser.add_argument("--added_prompt", type=str, default="clean, high-resolution, 8k")
parser.add_argument("--negative_prompt", type=str, default="blurry, dotted, noise, raster lines, unclear, lowres, over-smoothed")
parser.add_argument("--image_path", type=str, default="examples/RealSRSet")
parser.add_argument("--output_dir", type=str, default="output")
parser.add_argument("--mixed_precision", type=str, default="fp16") # no/fp16/bf16
parser.add_argument("--guidance_scale", type=float, default=7.5)
parser.add_argument("--conditioning_scale", type=float, default=1.0)
parser.add_argument("--blending_alpha", type=float, default=1.0)
parser.add_argument("--multiplier", type=float, default=0.6)
parser.add_argument("--num_inference_steps", type=int, default=20)
parser.add_argument("--process_size", type=int, default=768) # 512?
parser.add_argument("--decoder_tiled_size", type=int, default=224) # for 24G
parser.add_argument("--encoder_tiled_size", type=int, default=1024) # for 24G
parser.add_argument("--latent_tiled_size", type=int, default=320) # for 24G
parser.add_argument("--latent_tiled_overlap", type=int, default=8) # for 24G
parser.add_argument("--upscale", type=int, default=4)
parser.add_argument("--use_personalized_model", action="store_true")
parser.add_argument("--use_pasd_light", action="store_true")
parser.add_argument("--init_latent_with_noise", action="store_true")
parser.add_argument("--added_noise_level", type=int, default=400)
parser.add_argument("--offset_noise_scale", type=float, default=0.0)
parser.add_argument("--seed", type=int, default=None)
parser.add_argument("--pretrained_model_path", type=str, default="checkpoints/stable-diffusion-v1-5", help="path of base SD model")
parser.add_argument("--lcm_lora_path", type=str, default="checkpoints/lcm-lora-sdv1-5", help="path of LCM lora model")
parser.add_argument("--pasd_model_path", type=str, default="runs/pasd/checkpoint-100000", help="path of PASD model")
parser.add_argument("--personalized_model_path", type=str, default="majicmixRealistic_v6.safetensors", help="name of personalized dreambooth model, path is 'checkpoints/personalized_models'") # toonyou_beta3.safetensors, majicmixRealistic_v6.safetensors, unet_disney
parser.add_argument("--control_type", choices=['realisr', 'grayscale'], nargs='?', default="realisr", help="task name")
parser.add_argument('--high_level_info', choices=['classification', 'detection', 'caption'], nargs='?', default='', help="high level information for prompt generation")
parser.add_argument("--prompt", type=str, default="", help="prompt for image generation")
parser.add_argument("--added_prompt", type=str, default="clean, high-resolution, 8k", help="additional prompt")
parser.add_argument("--negative_prompt", type=str, default="blurry, dotted, noise, raster lines, unclear, lowres, over-smoothed", help="negative prompt")
parser.add_argument("--image_path", type=str, default="examples/RealSRSet", help="test image path or folder")
parser.add_argument("--output_dir", type=str, default="output", help="output folder")
parser.add_argument("--mixed_precision", type=str, default="fp16", help="mixed precision mode") # no/fp16/bf16
parser.add_argument("--guidance_scale", type=float, default=7.5, help="classifier-free guidance scale")
parser.add_argument("--conditioning_scale", type=float, default=1.0, help="conditioning scale for controlnet")
parser.add_argument("--blending_alpha", type=float, default=1.0, help="blending alpha for personalized model")
parser.add_argument("--multiplier", type=float, default=0.6, help="multiplier for personalized lora model")
parser.add_argument("--num_inference_steps", type=int, default=20, help="denoising steps")
parser.add_argument("--process_size", type=int, default=768, help="minimal input size for processing") # 512?
parser.add_argument("--decoder_tiled_size", type=int, default=224, help="decoder tile size for save GPU memory") # for 24G
parser.add_argument("--encoder_tiled_size", type=int, default=1024, help="encoder tile size for save GPU memory") # for 24G