Skip to content

Commit

Permalink
feat: add async support per original repo
Browse files Browse the repository at this point in the history
  • Loading branch information
Hongbin Huang committed Jun 6, 2023
1 parent 25c6d34 commit faf4f8c
Show file tree
Hide file tree
Showing 4 changed files with 91 additions and 27 deletions.
14 changes: 1 addition & 13 deletions easyai/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,9 @@
InstructPix2PixInterface,
ModelKeywordInterface,
)
from .main import EasyAPI
from .main import EasyAI, EasyAPI, easyai
from .upscaler import HiResUpscaler, Upscaler


class EasyAI(EasyAPI):
def __init__(self):
super().__init__(
host="127.0.0.1",
port=80,
use_https=False,
)


easyai = EasyAI()

__version__ = "0.1.3"

__all__ = [
Expand Down
67 changes: 63 additions & 4 deletions easyai/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ def __init__(
sampler="Euler a",
steps=20,
use_https=False,
username=None,
password=None,
):
if baseurl is None:
if use_https:
Expand All @@ -33,7 +35,10 @@ def __init__(

self.session = requests.Session()

self.check_controlnet()
if username and password:
self.set_auth(username, password)
else:
self.check_controlnet()

def check_controlnet(self):
try:
Expand All @@ -45,6 +50,27 @@ def check_controlnet(self):
def set_auth(self, username, password):
self.session.auth = (username, password)

def post_and_get_api_result(self, url, json, use_async):
if use_async:
import asyncio

return asyncio.ensure_future(self.async_post(url=url, json=json))
else:
response = self.session.post(url=url, json=json)
return self._to_api_result(response)

async def async_post(self, url, json):
import aiohttp

async with aiohttp.ClientSession() as session:
auth = (
aiohttp.BasicAuth(self.session.auth[0], self.session.auth[1])
if self.session.auth
else None
)
async with session.post(url, json=json, auth=auth) as response:
return await self._to_api_result_async(response)

def _to_api_result(self, response):
if response.status_code != 200:
raise RuntimeError(response.status_code, response.text)
Expand Down Expand Up @@ -73,6 +99,34 @@ def _to_api_result(self, response):

return APIResult(images, parameters, info)

async def _to_api_result_async(self, response):
if response.status != 200:
raise RuntimeError(response.status, await response.text)

r = await response.json()
images = []
if "images" in r.keys():
images = [Image.open(io.BytesIO(base64.b64decode(i))) for i in r["images"]]
elif "image" in r.keys():
images = [Image.open(io.BytesIO(base64.b64decode(r["image"])))]

info = ""
if "info" in r.keys():
try:
info = json.loads(r["info"])
except: # NOQA
info = r["info"]
elif "html_info" in r.keys():
info = r["html_info"]
elif "caption" in r.keys():
info = r["caption"]

parameters = ""
if "parameters" in r.keys():
parameters = r["parameters"]

return APIResult(images, parameters, info)

# XXX 500 error (2022/12/26)
def png_info(self, image):
payload = {
Expand Down Expand Up @@ -183,10 +237,15 @@ def custom_get(self, endpoint, baseurl=False):
response = self.session.get(url=url)
return response.json()

def custom_post(self, endpoint, payload={}, baseurl=False):
def custom_post(self, endpoint, payload={}, baseurl=False, use_async=False):
url = self.get_endpoint(endpoint, baseurl)
response = self.session.post(url=url, json=payload)
return self._to_api_result(response)
if use_async:
import asyncio

return asyncio.ensure_future(self.async_post(url=url, json=payload))
else:
response = self.session.post(url=url, json=payload)
return self._to_api_result(response)

def controlnet_version(self):
r = self.custom_get("controlnet/version")
Expand Down
36 changes: 26 additions & 10 deletions easyai/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def txt2img(
controlnet_units: List[ControlNetUnit] = [],
sampler_index=None, # deprecated: use sampler_name
use_deprecated_controlnet=False,
use_async=False,
):
if sampler_index is None:
sampler_index = self.default_sampler
Expand Down Expand Up @@ -118,8 +119,9 @@ def txt2img(
# workaround : if not passed, webui will use previous args!
payload["alwayson_scripts"]["ControlNet"] = {"args": []}

response = self.session.post(url=f"{self.baseurl}/txt2img", json=payload)
return self._to_api_result(response)
return self.post_and_get_api_result(
f"{self.baseurl}/txt2img", payload, use_async
)

def img2img(
self,
Expand Down Expand Up @@ -169,6 +171,7 @@ def img2img(
alwayson_scripts={},
controlnet_units: List[ControlNetUnit] = [],
use_deprecated_controlnet=False,
use_async=False,
):
if sampler_name is None:
sampler_name = self.default_sampler
Expand Down Expand Up @@ -238,8 +241,9 @@ def img2img(
elif self.has_controlnet:
payload["alwayson_scripts"]["ControlNet"] = {"args": []}

response = self.session.post(url=f"{self.baseurl}/img2img", json=payload)
return self._to_api_result(response)
return self.post_and_get_api_result(
f"{self.baseurl}/img2img", payload, use_async
)

def extra_single_image(
self,
Expand All @@ -257,6 +261,7 @@ def extra_single_image(
upscaler_2="None",
extras_upscaler_2_visibility=0,
upscale_first=False,
use_async=False,
):
payload = {
"resize_mode": resize_mode,
Expand All @@ -275,10 +280,9 @@ def extra_single_image(
"image": b64_img(image),
}

response = self.session.post(
url=f"{self.baseurl}/extra-single-image", json=payload
return self.post_and_get_api_result(
f"{self.baseurl}/extra-single-image", payload, use_async
)
return self._to_api_result(response)

def extra_batch_images(
self,
Expand All @@ -297,6 +301,7 @@ def extra_batch_images(
upscaler_2="None",
extras_upscaler_2_visibility=0,
upscale_first=False,
use_async=False,
):
if name_list is not None:
if len(name_list) != len(images):
Expand Down Expand Up @@ -326,7 +331,18 @@ def extra_batch_images(
"imageList": image_list,
}

response = self.session.post(
url=f"{self.baseurl}/extra-batch-images", json=payload
return self.post_and_get_api_result(
f"{self.baseurl}/extra-batch-images", payload, use_async
)
return self._to_api_result(response)


class EasyAI(EasyAPI):
def __init__(self):
super().__init__(
host="127.0.0.1",
port=80,
use_https=False,
)


easyai = EasyAI()
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ classifiers = [
requires = [
"requests",
"Pillow",
"aiohttp",
]
description-file = "README.md"
requires-python = ">=3.7"
Expand Down

0 comments on commit faf4f8c

Please sign in to comment.