Skip to content

Commit

Permalink
refactor sync engine
Browse files Browse the repository at this point in the history
  • Loading branch information
michaelfeil committed Jun 22, 2024
1 parent 6e6d89c commit a4d16c5
Showing 1 changed file with 61 additions and 75 deletions.
136 changes: 61 additions & 75 deletions libs/infinity_emb/infinity_emb/sync_engine.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
import asyncio
import threading
import time
from concurrent.futures import Future
from typing import Iterator
from typing import TYPE_CHECKING, Awaitable, Callable, Iterator, TypeVar

from infinity_emb.engine import AsyncEmbeddingEngine, AsyncEngineArray, EngineArgs
from infinity_emb.log_handler import logger
from infinity_emb.primitives import ClassifyReturnType, EmbeddingDtype, ReRankReturnType

if TYPE_CHECKING:
from infinity_emb import AsyncEmbeddingEngine

# from infinity_emb.primitives import ClassifyReturnType, EmbeddingDtype, ReRankReturnType


def add_start_docstrings(*docstr):
Expand All @@ -17,110 +20,93 @@ def docstring_decorator(fn):
return docstring_decorator


def threaded_asyncio_executor():
def decorator(fn):
funcname = fn.__name__ # e.g. `embed`
T = TypeVar("T")

def wrapper(self: "SyncEngineArray", **kwargs) -> "Future":
future: Future = Future()

assert self.is_running, "SyncEngineArray is not running"
class AsyncLifeMixin:
def __init__(self) -> None:
self._start_event: Future = Future()
self._stop_event = threading.Event()
self._is_closed: Future = Future()
threading.Thread(target=self._lifetime, daemon=True).start()
self._start_event.result()

def execute():
async_function = getattr(self.async_engine_array, funcname)
try:
# async_function is e.g. `self.async_engine_array.embed`
# get async future object
result = asyncio.run_coroutine_threadsafe(
async_function(**kwargs), self._loop
)
# block until the result is available
future.set_result(result.result())
except Exception as e:
future.set_exception(e)
def _lifetime(self):
"""takes care of starting, stopping event loop"""
self._loop = asyncio.new_event_loop()
asyncio.set_event_loop(self._loop)

threading.Thread(target=execute).start()
return future # return the future object immediately
async def block_until_engine_stop():
logger.info("Started Background Event Loop")
self._start_event.set_result(None) # signal that the event loop has started
while not self._stop_event.is_set():
await asyncio.sleep(0.2)

wrapper.__doc__ = fn.__doc__
return wrapper
self._loop.run_until_complete(block_until_engine_stop())
self._loop.close()
self._is_closed.set_result(None)
logger.info("Closed Background Event Loop")

return decorator
def async_close_loop(self):
    """Signal the background event loop to stop and wait until it is closed.

    Sets `_stop_event`, which the loop's keep-alive coroutine polls, and then
    blocks on the `_is_closed` future, which `_lifetime` resolves after it
    has closed the loop.
    """
    self._stop_event.set()
    # No timeout: wait for as long as the loop thread needs to shut down.
    self._is_closed.result(timeout=None)

def async_run(
    self, async_function: Callable[..., Awaitable[T]], *args, **kwargs
) -> Future[T]:
    """Schedule a coroutine on the background event loop.

    Calls `async_function(*args, **kwargs)` and submits the resulting
    awaitable to `self._loop` with `asyncio.run_coroutine_threadsafe`.
    This method does not wait for the result itself.

    Args:
        async_function: callable returning an awaitable.
        *args: positional arguments forwarded to `async_function`.
        **kwargs: keyword arguments forwarded to `async_function`.

    Returns:
        The future returned by `asyncio.run_coroutine_threadsafe`.

    Raises:
        RuntimeError: if the loop is not running or the stop event is set.
    """
    loop_is_live = self._loop.is_running() and not self._stop_event.is_set()
    if not loop_is_live:
        raise RuntimeError("Loop is not running")
    return asyncio.run_coroutine_threadsafe(
        async_function(*args, **kwargs), self._loop
    )


@add_start_docstrings(AsyncEngineArray.__doc__)
class SyncEngineArray:
class SyncEngineArray(AsyncLifeMixin):
def __init__(self, engine_args: list[EngineArgs]):
self._start_event = threading.Event()
self._stop_event = threading.Event()
super().__init__()
self.async_engine_array = AsyncEngineArray.from_args(engine_args)
threading.Thread(target=self._lifetime).start()
self._start_event.wait() # wait until the event loop has started
self.async_run(self.async_engine_array.astart).result()

@classmethod
def from_args(cls, engine_args: list[EngineArgs]) -> "SyncEngineArray":
    """Alternate constructor: build an instance by calling `cls(engine_args)`."""
    instance = cls(engine_args)
    return instance

@property
def is_running(self):
return (
not self._stop_event.is_set()
and self._loop.is_running()
and self.async_engine_array.is_running
)
return self.async_engine_array.is_running

def __iter__(self) -> Iterator["AsyncEmbeddingEngine"]:
return iter(self.async_engine_array)

def stop(self):
"""blocks until the engine is stopped"""
self._stop_event.set()
while self._loop.is_running():
time.sleep(0.05)

def _lifetime(self):
"""takes care of starting, stopping (engine and event loop)"""
self._loop = asyncio.new_event_loop()
asyncio.set_event_loop(self._loop)

async def block_until_engine_stop():
logger.info("Started SyncEngineArray Background Event Loop")
self._start_event.set() # signal that the event loop has started
try:
await self.async_engine_array.astart()
while not self._stop_event.is_set():
await asyncio.sleep(0.2)
finally:
await self.async_engine_array.astop()
# additional delay to ensure that the engine is stopped
await asyncio.sleep(2.0)

self._loop.run_until_complete(block_until_engine_stop())
self._loop.close()
logger.info("Closed SyncEngineArray Background Event Loop")
self.async_run(self.async_engine_array.astop).result()
self.async_close_loop()

@add_start_docstrings(AsyncEngineArray.embed.__doc__)
@threaded_asyncio_executor()
def embed(self, *, model: str, sentences: list[str]) -> Future[EmbeddingDtype]:
def embed(self, *, model: str, sentences: list[str]):
"""sync interface of AsyncEngineArray"""
return None # type: ignore
return self.async_run(
self.async_engine_array.embed, model=model, sentences=sentences
)

@add_start_docstrings(AsyncEngineArray.rerank.__doc__)
@threaded_asyncio_executor()
def rerank(
self, *, model: str, query: str, docs: list[str]
) -> Future[ReRankReturnType]:
def rerank(self, *, model: str, query: str, docs: list[str]):
"""sync interface of AsyncEngineArray"""
return None # type: ignore
return self.async_run(
self.async_engine_array.rerank, model=model, query=query, docs=docs
)

@add_start_docstrings(AsyncEngineArray.classify.__doc__)
@threaded_asyncio_executor()
def classify(self, *, model: str, text: str) -> Future[ClassifyReturnType]:
def classify(self, *, model: str, text: str):
"""sync interface of AsyncEngineArray"""
return None # type: ignore
return self.async_run(self.async_engine_array.classify, model=model, text=text)

@add_start_docstrings(AsyncEngineArray.image_embed.__doc__)
@threaded_asyncio_executor()
def image_embed(self, *, model: str, images: list[str]) -> Future[EmbeddingDtype]:
def image_embed(self, *, model: str, images: list[str]):
"""sync interface of AsyncEngineArray"""
return None # type: ignore
return self.async_run(
self.async_engine_array.image_embed, model=model, images=images
)

0 comments on commit a4d16c5

Please sign in to comment.