-
Notifications
You must be signed in to change notification settings - Fork 213
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
f2c9238
commit 163bbdd
Showing
27 changed files
with
5,790 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
# Copyright (c) Alibaba, Inc. and its affiliates.
"""Script entry point that runs ``swift.aigc.infer_text_to_image_lora``."""
from swift.aigc import infer_text_to_image_lora

if __name__ == '__main__':
    # Only run when executed as a script, never on import.
    infer_text_to_image_lora()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
# Copyright (c) Alibaba, Inc. and its affiliates.
"""Script entry point that runs ``swift.aigc.infer_text_to_image``."""
from swift.aigc import infer_text_to_image

if __name__ == '__main__':
    # Only run when executed as a script, never on import.
    infer_text_to_image()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
# Copyright (c) Alibaba, Inc. and its affiliates.
"""Script entry point that runs ``swift.aigc.infer_text_to_image_lora_sdxl``."""
from swift.aigc import infer_text_to_image_lora_sdxl

if __name__ == '__main__':
    # Only run when executed as a script, never on import.
    infer_text_to_image_lora_sdxl()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
# Copyright (c) Alibaba, Inc. and its affiliates.
"""Script entry point that runs ``swift.aigc.infer_text_to_image_sdxl``."""
from swift.aigc import infer_text_to_image_sdxl

if __name__ == '__main__':
    # Only run when executed as a script, never on import.
    infer_text_to_image_sdxl()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
# Runs infer_text_to_image.py on GPU 0 (CUDA_VISIBLE_DEVICES=0) with the
# stable-diffusion-v1-5 base model and the UNet weights stored under
# train_text_to_image/checkpoint-15000/unet, saving the image to yoda-pokemon.png.
# PYTHONPATH, the Python script and the checkpoint path are all relative to the
# current working directory.
# The last argument deliberately has no trailing backslash, so nothing that
# follows this command can be swallowed into it.
PYTHONPATH=../../.. \
CUDA_VISIBLE_DEVICES=0 \
python infer_text_to_image.py \
    --pretrained_model_name_or_path "AI-ModelScope/stable-diffusion-v1-5" \
    --unet_model_path "train_text_to_image/checkpoint-15000/unet" \
    --prompt "yoda" \
    --image_save_path "yoda-pokemon.png" \
    --torch_dtype "fp16"
8 changes: 8 additions & 0 deletions
8
examples/pytorch/sdxl/scripts/run_infer_text_to_image_lora.sh
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
# Runs infer_text_to_image_lora.py on GPU 0 (CUDA_VISIBLE_DEVICES=0) with the
# stable-diffusion-v1-5 base model and the LoRA weights stored under
# train_text_to_image_lora/checkpoint-80000, saving the image to lora_pokemon.png.
# PYTHONPATH, the Python script and the checkpoint path are all relative to the
# current working directory.
# The last argument deliberately has no trailing backslash, so nothing that
# follows this command can be swallowed into it.
PYTHONPATH=../../.. \
CUDA_VISIBLE_DEVICES=0 \
python infer_text_to_image_lora.py \
    --pretrained_model_name_or_path "AI-ModelScope/stable-diffusion-v1-5" \
    --lora_model_path "train_text_to_image_lora/checkpoint-80000" \
    --prompt "A pokemon with green eyes and red legs." \
    --image_save_path "lora_pokemon.png" \
    --torch_dtype "fp16"
8 changes: 8 additions & 0 deletions
8
examples/pytorch/sdxl/scripts/run_infer_text_to_image_lora_sdxl.sh
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
# Runs infer_text_to_image_lora_sdxl.py on GPU 0 (CUDA_VISIBLE_DEVICES=0) with
# the stable-diffusion-xl-base-1.0 base model and the LoRA weights stored under
# train_text_to_image_lora_sdxl/unet, saving the image to sdxl_lora_pokemon.png.
# PYTHONPATH, the Python script and the weights path are all relative to the
# current working directory.
# The last argument deliberately has no trailing backslash, so nothing that
# follows this command can be swallowed into it.
PYTHONPATH=../../.. \
CUDA_VISIBLE_DEVICES=0 \
python infer_text_to_image_lora_sdxl.py \
    --pretrained_model_name_or_path "AI-ModelScope/stable-diffusion-xl-base-1.0" \
    --lora_model_path "train_text_to_image_lora_sdxl/unet" \
    --prompt "A pokemon with green eyes and red legs." \
    --image_save_path "sdxl_lora_pokemon.png" \
    --torch_dtype "fp16"
8 changes: 8 additions & 0 deletions
8
examples/pytorch/sdxl/scripts/run_infer_text_to_image_sdxl.sh
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
# Runs infer_text_to_image_sdxl.py on GPU 0 (CUDA_VISIBLE_DEVICES=0) with the
# stable-diffusion-xl-base-1.0 base model and the UNet weights stored under
# train_text_to_image_sdxl/checkpoint-10000/unet, saving the image to
# sdxl_pokemon.png.
# PYTHONPATH, the Python script and the checkpoint path are all relative to the
# current working directory.
# The last argument deliberately has no trailing backslash, so nothing that
# follows this command can be swallowed into it.
PYTHONPATH=../../.. \
CUDA_VISIBLE_DEVICES=0 \
python infer_text_to_image_sdxl.py \
    --pretrained_model_name_or_path "AI-ModelScope/stable-diffusion-xl-base-1.0" \
    --unet_model_path "train_text_to_image_sdxl/checkpoint-10000/unet" \
    --prompt "A pokemon with green eyes and red legs." \
    --image_save_path "sdxl_pokemon.png" \
    --torch_dtype "fp16"
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
# Launches train_text_to_image.py through `accelerate launch` with fp16 mixed
# precision, training from the stable-diffusion-v1-5 base model on the
# AI-ModelScope/pokemon-blip-captions dataset for 15000 steps and writing to
# train_text_to_image/.
# PYTHONPATH, the Python script and the output directory are relative to the
# current working directory.
# The last argument deliberately has no trailing backslash, so nothing that
# follows this command can be swallowed into it.
PYTHONPATH=../../../ \
accelerate launch --mixed_precision="fp16" train_text_to_image.py \
    --pretrained_model_name_or_path="AI-ModelScope/stable-diffusion-v1-5" \
    --dataset_name="AI-ModelScope/pokemon-blip-captions" \
    --use_ema \
    --resolution=512 \
    --center_crop \
    --random_flip \
    --train_batch_size=1 \
    --gradient_accumulation_steps=4 \
    --gradient_checkpointing \
    --max_train_steps=15000 \
    --learning_rate=1e-05 \
    --max_grad_norm=1 \
    --lr_scheduler="constant" \
    --lr_warmup_steps=0 \
    --output_dir="train_text_to_image"
18 changes: 18 additions & 0 deletions
18
examples/pytorch/sdxl/scripts/run_train_text_to_image_lora.sh
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
# Launches train_text_to_image_lora.py through `accelerate launch`, training
# from the stable-diffusion-v1-5 base model on the
# AI-ModelScope/pokemon-blip-captions dataset for 100 epochs with fp16 mixed
# precision, checkpointing every 5000 steps and writing to
# train_text_to_image_lora/.
# PYTHONPATH, the Python script and the output directory are relative to the
# current working directory.
# The last argument deliberately has no trailing backslash, so nothing that
# follows this command can be swallowed into it.
PYTHONPATH=../../../ \
accelerate launch train_text_to_image_lora.py \
    --pretrained_model_name_or_path="AI-ModelScope/stable-diffusion-v1-5" \
    --dataset_name="AI-ModelScope/pokemon-blip-captions" \
    --caption_column="text" \
    --resolution=512 \
    --random_flip \
    --train_batch_size=1 \
    --num_train_epochs=100 \
    --checkpointing_steps=5000 \
    --learning_rate=1e-04 \
    --lr_scheduler="constant" \
    --lr_warmup_steps=0 \
    --mixed_precision="fp16" \
    --seed=42 \
    --output_dir="train_text_to_image_lora" \
    --validation_prompt="cute dragon creature" \
    --report_to="tensorboard"
19 changes: 19 additions & 0 deletions
19
examples/pytorch/sdxl/scripts/run_train_text_to_image_lora_sdxl.sh
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
# Launches train_text_to_image_lora_sdxl.py through `accelerate launch`,
# training from the stable-diffusion-xl-base-1.0 base model (with the
# sdxl-vae-fp16-fix VAE) on the AI-ModelScope/pokemon-blip-captions dataset for
# 2 epochs at resolution 1024 with fp16 mixed precision, checkpointing every
# 500 steps and writing to train_text_to_image_lora_sdxl/.
# PYTHONPATH, the Python script and the output directory are relative to the
# current working directory.
# The last argument deliberately has no trailing backslash, so nothing that
# follows this command can be swallowed into it.
PYTHONPATH=../../../ \
accelerate launch train_text_to_image_lora_sdxl.py \
    --pretrained_model_name_or_path="AI-ModelScope/stable-diffusion-xl-base-1.0" \
    --pretrained_vae_model_name_or_path="AI-ModelScope/sdxl-vae-fp16-fix" \
    --dataset_name="AI-ModelScope/pokemon-blip-captions" \
    --caption_column="text" \
    --resolution=1024 \
    --random_flip \
    --train_batch_size=1 \
    --num_train_epochs=2 \
    --checkpointing_steps=500 \
    --learning_rate=1e-04 \
    --lr_scheduler="constant" \
    --lr_warmup_steps=0 \
    --mixed_precision="fp16" \
    --seed=42 \
    --output_dir="train_text_to_image_lora_sdxl" \
    --validation_prompt="cute dragon creature" \
    --report_to="tensorboard"
24 changes: 24 additions & 0 deletions
24
examples/pytorch/sdxl/scripts/run_train_text_to_image_sdxl.sh
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
# Launches train_text_to_image_sdxl.py through `accelerate launch`, training
# from the stable-diffusion-xl-base-1.0 base model (with the sdxl-vae-fp16-fix
# VAE) on the AI-ModelScope/pokemon-blip-captions dataset for 10000 steps with
# fp16 mixed precision, checkpointing every 5000 steps and writing to
# train_text_to_image_sdxl/.
# PYTHONPATH, the Python script and the output directory are relative to the
# current working directory.
# The last argument deliberately has no trailing backslash, so nothing that
# follows this command can be swallowed into it.
PYTHONPATH=../../../ \
accelerate launch train_text_to_image_sdxl.py \
    --pretrained_model_name_or_path="AI-ModelScope/stable-diffusion-xl-base-1.0" \
    --pretrained_vae_model_name_or_path="AI-ModelScope/sdxl-vae-fp16-fix" \
    --dataset_name="AI-ModelScope/pokemon-blip-captions" \
    --enable_xformers_memory_efficient_attention \
    --resolution=512 \
    --center_crop \
    --random_flip \
    --proportion_empty_prompts=0.2 \
    --train_batch_size=1 \
    --gradient_accumulation_steps=4 \
    --gradient_checkpointing \
    --max_train_steps=10000 \
    --use_8bit_adam \
    --learning_rate=1e-06 \
    --lr_scheduler="constant" \
    --lr_warmup_steps=0 \
    --mixed_precision="fp16" \
    --report_to="tensorboard" \
    --validation_prompt="a cute Sundar Pichai creature" \
    --validation_epochs 5 \
    --checkpointing_steps=5000 \
    --output_dir="train_text_to_image_sdxl"
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
# Copyright (c) Alibaba, Inc. and its affiliates.
"""Script entry point that runs ``swift.aigc.train_text_to_image``."""
from swift.aigc import train_text_to_image

if __name__ == '__main__':
    # Only run when executed as a script, never on import.
    train_text_to_image()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
# Copyright (c) Alibaba, Inc. and its affiliates.
"""Script entry point that runs ``swift.aigc.train_text_to_image_lora``."""
from swift.aigc import train_text_to_image_lora

if __name__ == '__main__':
    # Only run when executed as a script, never on import.
    train_text_to_image_lora()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
# Copyright (c) Alibaba, Inc. and its affiliates.
"""Script entry point that runs ``swift.aigc.train_text_to_image_lora_sdxl``."""
from swift.aigc import train_text_to_image_lora_sdxl

if __name__ == '__main__':
    # Only run when executed as a script, never on import.
    train_text_to_image_lora_sdxl()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
# Copyright (c) Alibaba, Inc. and its affiliates.
"""Script entry point that runs ``swift.aigc.train_text_to_image_sdxl``."""
from swift.aigc import train_text_to_image_sdxl

if __name__ == '__main__':
    # Only run when executed as a script, never on import.
    train_text_to_image_sdxl()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,4 @@ | ||
decord | ||
diffusers>=0.18.0 | ||
diffusers==0.25.0 | ||
einops | ||
torchvision |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
from .infer_text_to_image import main as infer_text_to_image | ||
from .infer_text_to_image_lora import main as infer_text_to_image_lora | ||
from .infer_text_to_image_lora_sdxl import \ | ||
main as infer_text_to_image_lora_sdxl | ||
from .infer_text_to_image_sdxl import main as infer_text_to_image_sdxl | ||
from .train_text_to_image import main as train_text_to_image | ||
from .train_text_to_image_lora import main as train_text_to_image_lora | ||
from .train_text_to_image_lora_sdxl import \ | ||
main as train_text_to_image_lora_sdxl | ||
from .train_text_to_image_sdxl import main as train_text_to_image_sdxl |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,111 @@ | ||
# Copyright (c) Alibaba, Inc. and its affiliates. | ||
import argparse | ||
import os | ||
|
||
import torch | ||
from diffusers import StableDiffusionPipeline, UNet2DConditionModel | ||
from modelscope import snapshot_download | ||
|
||
|
||
def parse_args():
    """Parse the command-line options for text-to-image inference.

    Options are read from ``sys.argv``. ``--pretrained_model_name_or_path``,
    ``--prompt`` and ``--image_save_path`` are mandatory; argparse exits with
    a usage error when one of them is missing or a value is invalid.

    Returns:
        argparse.Namespace: with attributes ``pretrained_model_name_or_path``,
        ``revision``, ``unet_model_path``, ``prompt``, ``image_save_path``,
        ``torch_dtype``, ``num_inference_steps`` and ``guidance_scale``.
    """
    parser = argparse.ArgumentParser(
        description='Simple example of a text to image inference.')
    # NOTE: this option is required, so the default below is never used; it is
    # kept only to preserve the original interface.
    parser.add_argument(
        '--pretrained_model_name_or_path',
        type=str,
        default='AI-ModelScope/stable-diffusion-v1-5',
        required=True,
        help=
        'Path to pretrained model or model identifier from modelscope.cn/models.',
    )
    parser.add_argument(
        '--revision',
        type=str,
        default=None,
        required=False,
        help=
        'Revision of pretrained model identifier from modelscope.cn/models.',
    )
    parser.add_argument(
        '--unet_model_path',
        type=str,
        default=None,
        required=False,
        help='The path to trained unet model.',
    )
    parser.add_argument(
        '--prompt',
        type=str,
        default=None,
        required=True,
        help=
        'The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`',
    )
    parser.add_argument(
        '--image_save_path',
        type=str,
        default=None,
        required=True,
        help='The path to save generated image',
    )
    parser.add_argument(
        '--torch_dtype',
        type=str,
        default=None,
        choices=['no', 'fp16', 'bf16'],
        help=
        ('Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >='
         ' 1.10.and an Nvidia Ampere GPU. Default to the value of the'
         ' mixed_precision passed with the `accelerate.launch` command in training script.'
         ),
    )
    # The two long help texts below use implicit string concatenation rather
    # than a backslash continuation inside the literal, which would embed the
    # continuation line's indentation in the help text.
    parser.add_argument(
        '--num_inference_steps',
        type=int,
        default=50,
        help=
        ('The number of denoising steps. More denoising steps usually lead to a higher quality image at the '
         'expense of slower inference.'),
    )
    # No `choices` here: the value is converted to float before the membership
    # check, so the dtype strings copied from --torch_dtype could never match
    # and every explicitly passed guidance scale was rejected.
    parser.add_argument(
        '--guidance_scale',
        type=float,
        default=7.5,
        help=
        ('A higher guidance scale value encourages the model to generate images closely linked to the text '
         '`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.'
         ),
    )

    args = parser.parse_args()
    return args
|
||
|
||
def main():
    """Generate one image from the command-line prompt and save it.

    Reads the options via ``parse_args()``, builds a
    ``StableDiffusionPipeline`` (optionally replacing its UNet with the one at
    ``--unet_model_path``), moves it to ``'cuda'``, runs the prompt and writes
    the first returned image to ``--image_save_path``.
    """
    args = parse_args()

    # An existing local path is used as-is; anything else is passed to
    # modelscope's snapshot_download together with the requested revision.
    model_id = args.pretrained_model_name_or_path
    model_path = (
        model_id if os.path.exists(model_id) else snapshot_download(
            model_id, revision=args.revision))

    # 'fp16' / 'bf16' select half precision; any other value (including the
    # default None and 'no') falls back to float32.
    dtype_names = {'fp16': 'float16', 'bf16': 'bfloat16'}
    torch_dtype = getattr(torch,
                          dtype_names.get(args.torch_dtype, 'float32'))

    pipe = StableDiffusionPipeline.from_pretrained(
        model_path, torch_dtype=torch_dtype)
    if args.unet_model_path is not None:
        pipe.unet = UNet2DConditionModel.from_pretrained(
            args.unet_model_path, torch_dtype=torch_dtype)
    pipe.to('cuda')

    output = pipe(
        prompt=args.prompt,
        num_inference_steps=args.num_inference_steps,
        guidance_scale=args.guidance_scale)
    first_image = output.images[0]
    first_image.save(args.image_save_path)
Oops, something went wrong.