From a4d16c504fbbe326d86a185c1fe69841a41fa46b Mon Sep 17 00:00:00 2001 From: michaelfeil Date: Sat, 22 Jun 2024 16:42:02 -0700 Subject: [PATCH] refactor sync engine --- libs/infinity_emb/infinity_emb/sync_engine.py | 136 ++++++++---------- 1 file changed, 61 insertions(+), 75 deletions(-) diff --git a/libs/infinity_emb/infinity_emb/sync_engine.py b/libs/infinity_emb/infinity_emb/sync_engine.py index 73c20feb..6951258b 100644 --- a/libs/infinity_emb/infinity_emb/sync_engine.py +++ b/libs/infinity_emb/infinity_emb/sync_engine.py @@ -1,12 +1,15 @@ import asyncio import threading -import time from concurrent.futures import Future -from typing import Iterator +from typing import TYPE_CHECKING, Awaitable, Callable, Iterator, TypeVar from infinity_emb.engine import AsyncEmbeddingEngine, AsyncEngineArray, EngineArgs from infinity_emb.log_handler import logger -from infinity_emb.primitives import ClassifyReturnType, EmbeddingDtype, ReRankReturnType + +if TYPE_CHECKING: + from infinity_emb import AsyncEmbeddingEngine + + # from infinity_emb.primitives import ClassifyReturnType, EmbeddingDtype, ReRankReturnType def add_start_docstrings(*docstr): @@ -17,45 +20,54 @@ def docstring_decorator(fn): return docstring_decorator -def threaded_asyncio_executor(): - def decorator(fn): - funcname = fn.__name__ # e.g. `embed` +T = TypeVar("T") - def wrapper(self: "SyncEngineArray", **kwargs) -> "Future": - future: Future = Future() - - assert self.is_running, "SyncEngineArray is not running" +class AsyncLifeMixin: + def __init__(self) -> None: + self._start_event: Future = Future() + self._stop_event = threading.Event() + self._is_closed: Future = Future() + threading.Thread(target=self._lifetime, daemon=True).start() + self._start_event.result() - def execute(): - async_function = getattr(self.async_engine_array, funcname) - try: - # async_function is e.g. `self.async_engine_array.embed` - # get async future object - result = asyncio.run_coroutine_threadsafe( - async_function(**kwargs), self._loop - ) - # block until the result is available - future.set_result(result.result()) - except Exception as e: - future.set_exception(e) + def _lifetime(self): + """takes care of starting, stopping event loop""" + self._loop = asyncio.new_event_loop() + asyncio.set_event_loop(self._loop) - threading.Thread(target=execute).start() - return future # return the future object immediately + async def block_until_engine_stop(): + logger.info("Started Background Event Loop") + self._start_event.set_result(None) # signal that the event loop has started + while not self._stop_event.is_set(): + await asyncio.sleep(0.2) - wrapper.__doc__ = fn.__doc__ - return wrapper + self._loop.run_until_complete(block_until_engine_stop()) + self._loop.close() + self._is_closed.set_result(None) + logger.info("Closed Background Event Loop") - return decorator + def async_close_loop(self): + self._stop_event.set() + self._is_closed.result() + + def async_run( + self, async_function: Callable[..., Awaitable[T]], *args, **kwargs + ) -> Future[T]: + """blocks until the engine is running""" + if not self._loop.is_running() or self._stop_event.is_set(): + raise RuntimeError("Loop is not running") + future = asyncio.run_coroutine_threadsafe( + async_function(*args, **kwargs), self._loop + ) + return future @add_start_docstrings(AsyncEngineArray.__doc__) -class SyncEngineArray: +class SyncEngineArray(AsyncLifeMixin): def __init__(self, engine_args: list[EngineArgs]): - self._start_event = threading.Event() - self._stop_event = threading.Event() + super().__init__() self.async_engine_array = AsyncEngineArray.from_args(engine_args) - threading.Thread(target=self._lifetime).start() - self._start_event.wait() # wait until the event loop has started + self.async_run(self.async_engine_array.astart).result() @classmethod def from_args(cls, engine_args: list[EngineArgs]) -> "SyncEngineArray": @@ -63,64 +75,38 @@ def from_args(cls, engine_args: list[EngineArgs]) -> "SyncEngineArray": @property def is_running(self): - return ( - not self._stop_event.is_set() - and self._loop.is_running() - and self.async_engine_array.is_running - ) + return self.async_engine_array.is_running def __iter__(self) -> Iterator["AsyncEmbeddingEngine"]: return iter(self.async_engine_array) def stop(self): """blocks until the engine is stopped""" - self._stop_event.set() - while self._loop.is_running(): - time.sleep(0.05) - - def _lifetime(self): - """takes care of starting, stopping (engine and event loop)""" - self._loop = asyncio.new_event_loop() - asyncio.set_event_loop(self._loop) - - async def block_until_engine_stop(): - logger.info("Started SyncEngineArray Background Event Loop") - self._start_event.set() # signal that the event loop has started - try: - await self.async_engine_array.astart() - while not self._stop_event.is_set(): - await asyncio.sleep(0.2) - finally: - await self.async_engine_array.astop() - # additional delay to ensure that the engine is stopped - await asyncio.sleep(2.0) - - self._loop.run_until_complete(block_until_engine_stop()) - self._loop.close() - logger.info("Closed SyncEngineArray Background Event Loop") + self.async_run(self.async_engine_array.astop).result() + self.async_close_loop() @add_start_docstrings(AsyncEngineArray.embed.__doc__) - @threaded_asyncio_executor() - def embed(self, *, model: str, sentences: list[str]) -> Future[EmbeddingDtype]: + def embed(self, *, model: str, sentences: list[str]): """sync interface of AsyncEngineArray""" - return None # type: ignore + return self.async_run( + self.async_engine_array.embed, model=model, sentences=sentences + ) @add_start_docstrings(AsyncEngineArray.rerank.__doc__) - @threaded_asyncio_executor() - def rerank( - self, *, model: str, query: str, docs: list[str] - ) -> Future[ReRankReturnType]: + def rerank(self, *, model: str, query: str, docs: list[str]): """sync interface of AsyncEngineArray""" - return None # type: ignore + return self.async_run( + self.async_engine_array.rerank, model=model, query=query, docs=docs + ) @add_start_docstrings(AsyncEngineArray.classify.__doc__) - @threaded_asyncio_executor() - def classify(self, *, model: str, text: str) -> Future[ClassifyReturnType]: + def classify(self, *, model: str, text: str): """sync interface of AsyncEngineArray""" - return None # type: ignore + return self.async_run(self.async_engine_array.classify, model=model, text=text) @add_start_docstrings(AsyncEngineArray.image_embed.__doc__) - @threaded_asyncio_executor() - def image_embed(self, *, model: str, images: list[str]) -> Future[EmbeddingDtype]: + def image_embed(self, *, model: str, images: list[str]): """sync interface of AsyncEngineArray""" - return None # type: ignore + return self.async_run( + self.async_engine_array.image_embed, model=model, images=images + )