Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: adds stable_cascade_2pass #253

Merged
merged 4 commits into from
May 12, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
fix: clean up hires_fix resize logic; tests: 2pass tests
  • Loading branch information
tazlin committed May 12, 2024
commit 6a927172bc1a2f04e56ff2141cd9366b0dbc611c
55 changes: 37 additions & 18 deletions hordelib/horde.py
Original file line number Diff line number Diff line change
Expand Up @@ -775,25 +775,44 @@ def _final_pipeline_adjustments(self, payload, pipeline_data) -> tuple[dict, lis

# For hires fix, change the image sizes as we create an intermediate image first
if payload.get("hires_fix", False):
optimal_size = 512
model_details = None
if SharedModelManager.manager.compvis:
model_details = SharedModelManager.manager.compvis.get_model_reference_info(payload["model_name"])
model_details = (
SharedModelManager.manager.compvis.get_model_reference_info(payload["model_name"])
if SharedModelManager.manager.compvis
else None
)

original_width = pipeline_params.get("empty_latent_image.width")
original_height = pipeline_params.get("empty_latent_image.height")

if original_width is None or original_height is None:
logger.error("empty_latent_image.width or empty_latent_image.height not found. Using 512x512.")
original_width, original_height = (512, 512)

new_width, new_height = (None, None)

if model_details and model_details.get("baseline") == "stable_cascade":
optimal_size = 1024
width = pipeline_params.get("empty_latent_image.width", 0)
height = pipeline_params.get("empty_latent_image.height", 0)
recalculate_size = False
if width > optimal_size and height > optimal_size:
recalculate_size = True
elif optimal_size == 1024 and (width > optimal_size or height > optimal_size):
recalculate_size = True
if recalculate_size:
newwidth, newheight = ImageUtils.calculate_source_image_size(width, height, optimal_size)
pipeline_params["latent_upscale.width"] = width
pipeline_params["latent_upscale.height"] = height
pipeline_params["empty_latent_image.width"] = newwidth
pipeline_params["empty_latent_image.height"] = newheight
new_width, new_height = ImageUtils.get_first_pass_image_resolution_max(
original_width,
original_height,
)
else:
new_width, new_height = ImageUtils.get_first_pass_image_resolution_min(
original_width,
original_height,
)

# This is the *target* resolution
pipeline_params["latent_upscale.width"] = original_width
pipeline_params["latent_upscale.height"] = original_height

if new_width and new_height:
# This is the *first pass* resolution
pipeline_params["empty_latent_image.width"] = new_width
pipeline_params["empty_latent_image.height"] = new_height
else:
logger.error("Could not determine new image size for hires fix. Using 1024x1024.")
pipeline_params["empty_latent_image.width"] = 1024
pipeline_params["empty_latent_image.height"] = 1024

if payload.get("control_type"):
# Inject control net model manager
Expand Down
118 changes: 102 additions & 16 deletions hordelib/utils/image_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,25 +5,111 @@
from loguru import logger
from PIL import Image, ImageOps, PngImagePlugin, UnidentifiedImageError

IMAGE_CHUNK_SIZE = 64

DEFAULT_IMAGE_MIN_RESOLUTION = 512
DEFAULT_HIGHER_RES_MAX_RESOLUTION = 1024


class ImageUtils:

@classmethod
def calculate_source_image_size(cls, width, height, optimal_size=512):
final_width = width
final_height = height
# For SD 1.5 we don't want the image to be generated lower than 512
# So we only upscale, if the image had both parts higher than 512
if optimal_size == 512:
first_pass_ratio = min(final_height / optimal_size, final_width / optimal_size)
# For Stable Cascade, it can handle generating one part below 1024 well enough
# So to save generation time, we instead downgrade the initial image so that it's always
# a max of 1024 in any dimention
# then upscale it with the second pass
def resize_image_dimensions(
cls,
width: int,
height: int,
desired_dimension: int,
use_min: bool,
) -> tuple[int, int]:
"""Resize the image dimensions to have one side equal to the desired resolution, keeping the aspect ratio.

- If use_min is True, the side with the minimum length will be resized to the desired resolution.
- For example, if the image is 1024x2048 and the desired resolution is 512, the image will be
resized to 512x1024. (As desired for 512x trained models)
- If use_min is False, the side with the maximum length will be resized to the desired resolution.
- For example, if the image is 1024x2048 and the desired resolution is 1024, the image will be
resized to 512x1024. (As desired for 1024x trained models)
- If the image is smaller than the desired resolution, the image will not be resized.

Args:
width (int): The width of the image.
height (int): The height of the image.
desired_dimension (int): The desired single side resolution.
use_min (bool): Whether to use the minimum or maximum side.

Returns:
tuple[int, int]: The target first pass width and height of the image.
"""
if desired_dimension is None or desired_dimension <= 0:
raise ValueError("desired_resolution must be a positive integer.")

if width <= 0 or height <= 0:
raise ValueError("width and height must be positive integers.")

if width < desired_dimension and height < desired_dimension:
return width, height

if use_min:
ratio = min(
height / desired_dimension,
width / desired_dimension,
)
else:
first_pass_ratio = max(final_height / optimal_size, final_width / optimal_size)
width = (int(final_width / first_pass_ratio) // 64) * 64
height = (int(final_height / first_pass_ratio) // 64) * 64
return (width, height)
ratio = max(
height / desired_dimension,
width / desired_dimension,
)

new_width = int(width // (ratio * IMAGE_CHUNK_SIZE)) * IMAGE_CHUNK_SIZE
new_height = int(height // (ratio * IMAGE_CHUNK_SIZE)) * IMAGE_CHUNK_SIZE

return new_width, new_height

@classmethod
def get_first_pass_image_resolution_min(
cls,
width: int,
height: int,
min_dimension: int = DEFAULT_IMAGE_MIN_RESOLUTION,
):
"""Resize the image dimensions to have one side equal to the desired resolution, keeping the aspect ratio.

- If the image is larger than the desired resolution, the side with the minimum length will be resized to the
desired resolution.
- If the image is smaller than the desired resolution, the image will not be resized.

"""
if width > min_dimension and height > min_dimension:
return cls.resize_image_dimensions(
width,
height,
desired_dimension=min_dimension,
use_min=True,
)
return width, height

@classmethod
def get_first_pass_image_resolution_max(
cls,
width: int,
height: int,
max_dimension: int = DEFAULT_HIGHER_RES_MAX_RESOLUTION,
):
"""Resize the image dimensions to have one side equal to the desired resolution, keeping the aspect ratio.

- If the image is larger than the desired resolution, the side with the maximum length will be resized to the
desired resolution.
- If the image is smaller than the desired resolution, the image will not be resized.
"""

if max(width, height) > max_dimension:
return cls.resize_image_dimensions(
width,
height,
desired_dimension=max_dimension,
use_min=False,
)
return width, height

@classmethod
def add_image_alpha_channel(cls, source_image, alpha_image):
Expand Down Expand Up @@ -52,7 +138,7 @@ def resize_sources_to_request(cls, payload):
newwidth = payload["width"]
newheight = payload["height"]
if payload.get("hires_fix") or payload.get("control_type"):
newwidth, newheight = cls.calculate_source_image_size(payload["width"], payload["height"])
newwidth, newheight = cls.get_first_pass_image_resolution_min(payload["width"], payload["height"])
if source_image.size != (newwidth, newheight):
payload["source_image"] = source_image.resize(
(newwidth, newheight),
Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
27 changes: 20 additions & 7 deletions tests/test_horde_inference_cascade.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# test_horde.py

import pytest
from PIL import Image

from hordelib.horde import HordeLib
Expand Down Expand Up @@ -320,7 +321,7 @@ def test_cascade_image_remix_triple(
pil_image,
)

def test_cascade_text_to_image_2pass(
def test_cascade_text_to_image_hires_2pass(
self,
hordelib_instance: HordeLib,
stable_cascade_base_model_name: str,
Expand Down Expand Up @@ -361,11 +362,23 @@ def test_cascade_text_to_image_2pass(
pil_image2 = hordelib_instance.basic_inference_single_image(data).image
assert pil_image2 is not None
assert isinstance(pil_image2, Image.Image)
assert not check_single_inference_image_similarity(
pil_image2,

img_filename_denoise_0 = "stable_cascade_text_to_image_2pass_denoise_0.png"
pil_image2.save(f"images/{img_filename_denoise_0}", quality=100)

assert pil_image2 is not None
assert isinstance(pil_image2, Image.Image)
with pytest.raises(AssertionError):
check_single_inference_image_similarity(
pil_image2,
pil_image,
exception_on_fail=True,
)
assert check_single_inference_image_similarity(
f"images_expected/{img_filename}",
pil_image,
)
# assert check_single_inference_image_similarity(
# f"images_expected/{img_filename}",
# pil_image,
# )
assert check_single_inference_image_similarity(
f"images_expected/{img_filename_denoise_0}",
pil_image2,
)
83 changes: 83 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
HistogramDistanceResultCode,
)
from hordelib.utils.gpuinfo import GPUInfo
from hordelib.utils.image_utils import ImageUtils


def test_worker_settings_singleton():
Expand Down Expand Up @@ -64,3 +65,85 @@ def test_gpuinfo_init(self):
assert info.vram_total[0] > 0
assert info.vram_free[0] > 0
assert info.vram_used[0] > 0


class TestImageUtils:
def test_get_first_pass_image_resolution_min(self):
expected = (512, 512)
calculated = ImageUtils.get_first_pass_image_resolution_min(512, 512)

assert calculated == expected

def test_under_sized_both_dimensions_min(self):
expected = (256, 256)
calculated = ImageUtils.get_first_pass_image_resolution_min(256, 256)

assert calculated == expected

def test_under_sized_one_dimension_min(self):
expected = (512, 256)
calculated = ImageUtils.get_first_pass_image_resolution_min(512, 256)

assert calculated == expected

def test_oversized_one_dimension_min(self):
expected = (1024, 512)
calculated = ImageUtils.get_first_pass_image_resolution_min(1024, 512)

assert calculated == expected

def test_oversized_other_dimension_min(self):
expected = (512, 1024)
calculated = ImageUtils.get_first_pass_image_resolution_min(512, 1024)

assert calculated == expected

def test_both_dimensions_oversized_evenly_min(self):
expected = (512, 512)
calculated = ImageUtils.get_first_pass_image_resolution_min(1024, 1024)

assert calculated == expected

def test_both_dimensions_oversized_unevenly_min(self):
expected = (512, 768)
calculated = ImageUtils.get_first_pass_image_resolution_min(1024, 1536)

assert calculated == expected

def test_get_first_pass_image_resolution_max(self):
expected = (1024, 1024)
calculated = ImageUtils.get_first_pass_image_resolution_max(1024, 1024)

assert calculated == expected

def test_under_sized_both_dimensions_max(self):
expected = (512, 512)
calculated = ImageUtils.get_first_pass_image_resolution_max(512, 512)

assert calculated == expected

def test_oversized_one_dimension_max(self):
expected = (1024, 512)
calculated = ImageUtils.get_first_pass_image_resolution_max(2048, 1024)

assert calculated == expected

def test_oversized_other_dimension_max(self):
expected = (512, 1024)
calculated = ImageUtils.get_first_pass_image_resolution_max(1024, 2048)

assert calculated == expected

def test_both_dimensions_oversized_evenly_max(self):
expected = (1024, 1024)
calculated = ImageUtils.get_first_pass_image_resolution_max(2048, 2048)

assert calculated == expected

def test_both_dimensions_oversized_unevenly_max(self):
expected = (640, 1024)
calculated_cascade = ImageUtils.get_first_pass_image_resolution_max(2048, 3072)
calculated_default = ImageUtils.get_first_pass_image_resolution_min(2048, 3072)

assert calculated_cascade != calculated_default
assert calculated_cascade == expected
Loading