Skip to content

Commit

Permalink
add examples text to image (#304)
Browse files Browse the repository at this point in the history
  • Loading branch information
slin000111 committed Jan 13, 2024
1 parent f2c9238 commit 163bbdd
Show file tree
Hide file tree
Showing 27 changed files with 5,790 additions and 1 deletion.
5 changes: 5 additions & 0 deletions examples/pytorch/sdxl/infer_text_image_lora.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from swift.aigc import infer_text_to_image_lora

if __name__ == '__main__':
infer_text_to_image_lora()
5 changes: 5 additions & 0 deletions examples/pytorch/sdxl/infer_text_to_image.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from swift.aigc import infer_text_to_image

if __name__ == '__main__':
infer_text_to_image()
5 changes: 5 additions & 0 deletions examples/pytorch/sdxl/infer_text_to_image_lora_sdxl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from swift.aigc import infer_text_to_image_lora_sdxl

if __name__ == '__main__':
infer_text_to_image_lora_sdxl()
5 changes: 5 additions & 0 deletions examples/pytorch/sdxl/infer_text_to_image_sdxl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from swift.aigc import infer_text_to_image_sdxl

if __name__ == '__main__':
infer_text_to_image_sdxl()
8 changes: 8 additions & 0 deletions examples/pytorch/sdxl/scripts/run_infer_text_to_image.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
PYTHONPATH=../../.. \
CUDA_VISIBLE_DEVICES=0 \
python infer_text_to_image.py \
--pretrained_model_name_or_path "AI-ModelScope/stable-diffusion-v1-5" \
--unet_model_path "train_text_to_image/checkpoint-15000/unet" \
--prompt "yoda" \
--image_save_path "yoda-pokemon.png" \
--torch_dtype "fp16" \
8 changes: 8 additions & 0 deletions examples/pytorch/sdxl/scripts/run_infer_text_to_image_lora.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
PYTHONPATH=../../.. \
CUDA_VISIBLE_DEVICES=0 \
python infer_text_to_image_lora.py \
--pretrained_model_name_or_path "AI-ModelScope/stable-diffusion-v1-5" \
--lora_model_path "train_text_to_image_lora/checkpoint-80000" \
--prompt "A pokemon with green eyes and red legs." \
--image_save_path "lora_pokemon.png" \
--torch_dtype "fp16" \
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
PYTHONPATH=../../.. \
CUDA_VISIBLE_DEVICES=0 \
python infer_text_to_image_lora_sdxl.py \
--pretrained_model_name_or_path "AI-ModelScope/stable-diffusion-xl-base-1.0" \
--lora_model_path "train_text_to_image_lora_sdxl/unet" \
--prompt "A pokemon with green eyes and red legs." \
--image_save_path "sdxl_lora_pokemon.png" \
--torch_dtype "fp16" \
8 changes: 8 additions & 0 deletions examples/pytorch/sdxl/scripts/run_infer_text_to_image_sdxl.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
PYTHONPATH=../../.. \
CUDA_VISIBLE_DEVICES=0 \
python infer_text_to_image_sdxl.py \
--pretrained_model_name_or_path "AI-ModelScope/stable-diffusion-xl-base-1.0" \
--unet_model_path "train_text_to_image_sdxl/checkpoint-10000/unet" \
--prompt "A pokemon with green eyes and red legs." \
--image_save_path "sdxl_pokemon.png" \
--torch_dtype "fp16" \
17 changes: 17 additions & 0 deletions examples/pytorch/sdxl/scripts/run_train_text_to_image.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
PYTHONPATH=../../../ \
accelerate launch --mixed_precision="fp16" train_text_to_image.py \
--pretrained_model_name_or_path="AI-ModelScope/stable-diffusion-v1-5" \
--dataset_name="AI-ModelScope/pokemon-blip-captions" \
--use_ema \
--resolution=512 \
--center_crop \
--random_flip \
--train_batch_size=1 \
--gradient_accumulation_steps=4 \
--gradient_checkpointing \
--max_train_steps=15000 \
--learning_rate=1e-05 \
--max_grad_norm=1 \
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
--output_dir="train_text_to_image" \
18 changes: 18 additions & 0 deletions examples/pytorch/sdxl/scripts/run_train_text_to_image_lora.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
PYTHONPATH=../../../ \
accelerate launch train_text_to_image_lora.py \
--pretrained_model_name_or_path="AI-ModelScope/stable-diffusion-v1-5" \
--dataset_name="AI-ModelScope/pokemon-blip-captions" \
--caption_column="text" \
--resolution=512 \
--random_flip \
--train_batch_size=1 \
--num_train_epochs=100 \
--checkpointing_steps=5000 \
--learning_rate=1e-04 \
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
--mixed_precision="fp16" \
--seed=42 \
--output_dir="train_text_to_image_lora" \
--validation_prompt="cute dragon creature" \
--report_to="tensorboard" \
19 changes: 19 additions & 0 deletions examples/pytorch/sdxl/scripts/run_train_text_to_image_lora_sdxl.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
PYTHONPATH=../../../ \
accelerate launch train_text_to_image_lora_sdxl.py \
--pretrained_model_name_or_path="AI-ModelScope/stable-diffusion-xl-base-1.0" \
--pretrained_vae_model_name_or_path="AI-ModelScope/sdxl-vae-fp16-fix" \
--dataset_name="AI-ModelScope/pokemon-blip-captions" \
--caption_column="text" \
--resolution=1024 \
--random_flip \
--train_batch_size=1 \
--num_train_epochs=2 \
--checkpointing_steps=500 \
--learning_rate=1e-04 \
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
--mixed_precision="fp16" \
--seed=42 \
--output_dir="train_text_to_image_lora_sdxl" \
--validation_prompt="cute dragon creature" \
--report_to="tensorboard" \
24 changes: 24 additions & 0 deletions examples/pytorch/sdxl/scripts/run_train_text_to_image_sdxl.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
PYTHONPATH=../../../ \
accelerate launch train_text_to_image_sdxl.py \
--pretrained_model_name_or_path="AI-ModelScope/stable-diffusion-xl-base-1.0" \
--pretrained_vae_model_name_or_path="AI-ModelScope/sdxl-vae-fp16-fix" \
--dataset_name="AI-ModelScope/pokemon-blip-captions" \
--enable_xformers_memory_efficient_attention \
--resolution=512 \
--center_crop \
--random_flip \
--proportion_empty_prompts=0.2 \
--train_batch_size=1 \
--gradient_accumulation_steps=4 \
--gradient_checkpointing \
--max_train_steps=10000 \
--use_8bit_adam \
--learning_rate=1e-06 \
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
--mixed_precision="fp16" \
--report_to="tensorboard" \
--validation_prompt="a cute Sundar Pichai creature" \
--validation_epochs 5 \
--checkpointing_steps=5000 \
--output_dir="train_text_to_image_sdxl" \
6 changes: 6 additions & 0 deletions examples/pytorch/sdxl/train_text_to_image.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

from swift.aigc import train_text_to_image

if __name__ == '__main__':
train_text_to_image()
6 changes: 6 additions & 0 deletions examples/pytorch/sdxl/train_text_to_image_lora.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

from swift.aigc import train_text_to_image_lora

if __name__ == '__main__':
train_text_to_image_lora()
6 changes: 6 additions & 0 deletions examples/pytorch/sdxl/train_text_to_image_lora_sdxl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

from swift.aigc import train_text_to_image_lora_sdxl

if __name__ == '__main__':
train_text_to_image_lora_sdxl()
6 changes: 6 additions & 0 deletions examples/pytorch/sdxl/train_text_to_image_sdxl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

from swift.aigc import train_text_to_image_sdxl

if __name__ == '__main__':
train_text_to_image_sdxl()
2 changes: 1 addition & 1 deletion requirements/aigc.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
decord
diffusers>=0.18.0
diffusers==0.25.0
einops
torchvision
9 changes: 9 additions & 0 deletions swift/aigc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,20 @@
# Recommend using `xxx_main`
from .animatediff import animatediff_sft, animatediff_main
from .animatediff_infer import animatediff_infer, animatediff_infer_main
from .diffusers import train_text_to_image, train_text_to_image_lora, train_text_to_image_lora_sdxl, \
train_text_to_image_sdxl, infer_text_to_image, infer_text_to_image_lora, infer_text_to_image_sdxl, \
infer_text_to_image_lora_sdxl
from .utils import AnimateDiffArguments, AnimateDiffInferArguments
else:
_import_structure = {
'animatediff': ['animatediff_sft', 'animatediff_main'],
'animatediff_infer': ['animatediff_infer', 'animatediff_infer_main'],
'diffusers': [
'train_text_to_image', 'train_text_to_image_lora',
'train_text_to_image_lora_sdxl', 'train_text_to_image_sdxl',
'infer_text_to_image', 'infer_text_to_image_lora',
'infer_text_to_image_sdxl', 'infer_text_to_image_lora_sdxl'
],
'utils': ['AnimateDiffArguments', 'AnimateDiffInferArguments'],
}

Expand Down
10 changes: 10 additions & 0 deletions swift/aigc/diffusers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from .infer_text_to_image import main as infer_text_to_image
from .infer_text_to_image_lora import main as infer_text_to_image_lora
from .infer_text_to_image_lora_sdxl import \
main as infer_text_to_image_lora_sdxl
from .infer_text_to_image_sdxl import main as infer_text_to_image_sdxl
from .train_text_to_image import main as train_text_to_image
from .train_text_to_image_lora import main as train_text_to_image_lora
from .train_text_to_image_lora_sdxl import \
main as train_text_to_image_lora_sdxl
from .train_text_to_image_sdxl import main as train_text_to_image_sdxl
111 changes: 111 additions & 0 deletions swift/aigc/diffusers/infer_text_to_image.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import argparse
import os

import torch
from diffusers import StableDiffusionPipeline, UNet2DConditionModel
from modelscope import snapshot_download


def parse_args():
parser = argparse.ArgumentParser(
description='Simple example of a text to image inference.')
parser.add_argument(
'--pretrained_model_name_or_path',
type=str,
default='AI-ModelScope/stable-diffusion-v1-5',
required=True,
help=
'Path to pretrained model or model identifier from modelscope.cn/models.',
)
parser.add_argument(
'--revision',
type=str,
default=None,
required=False,
help=
'Revision of pretrained model identifier from modelscope.cn/models.',
)
parser.add_argument(
'--unet_model_path',
type=str,
default=None,
required=False,
help='The path to trained unet model.',
)
parser.add_argument(
'--prompt',
type=str,
default=None,
required=True,
help=
'The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`',
)
parser.add_argument(
'--image_save_path',
type=str,
default=None,
required=True,
help='The path to save generated image',
)
parser.add_argument(
'--torch_dtype',
type=str,
default=None,
choices=['no', 'fp16', 'bf16'],
help=
('Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >='
' 1.10.and an Nvidia Ampere GPU. Default to the value of the'
' mixed_precision passed with the `accelerate.launch` command in training script.'
),
)
parser.add_argument(
'--num_inference_steps',
type=int,
default=50,
help=
('The number of denoising steps. More denoising steps usually lead to a higher quality image at the \
expense of slower inference.'),
)
parser.add_argument(
'--guidance_scale',
type=float,
default=7.5,
choices=['no', 'fp16', 'bf16'],
help=
('A higher guidance scale value encourages the model to generate images closely linked to the text \
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.'
),
)

args = parser.parse_args()
return args


def main():
args = parse_args()

if os.path.exists(args.pretrained_model_name_or_path):
model_path = args.pretrained_model_name_or_path
else:
model_path = snapshot_download(
args.pretrained_model_name_or_path, revision=args.revision)

if args.torch_dtype == 'fp16':
torch_dtype = torch.float16
elif args.torch_dtype == 'bf16':
torch_dtype = torch.bfloat16
else:
torch_dtype = torch.float32

pipe = StableDiffusionPipeline.from_pretrained(
model_path, torch_dtype=torch_dtype)
if args.unet_model_path is not None:
pipe.unet = UNet2DConditionModel.from_pretrained(
args.unet_model_path, torch_dtype=torch_dtype)
pipe.to('cuda')
image = pipe(
prompt=args.prompt,
num_inference_steps=args.num_inference_steps,
guidance_scale=args.guidance_scale).images[0]
image.save(args.image_save_path)
Loading

0 comments on commit 163bbdd

Please sign in to comment.