Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix: 新增 Lifespan.on_ready() 供适配器使用 #2483

Merged
merged 10 commits into from
Dec 10, 2023
12 changes: 0 additions & 12 deletions nonebot/drivers/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,6 @@
from nonebot.drivers import WebSocket as BaseWebSocket
from nonebot.drivers import HTTPServerSetup, WebSocketServerSetup

from ._lifespan import LIFESPAN_FUNC, Lifespan

try:
import uvicorn
from fastapi.responses import Response
Expand Down Expand Up @@ -97,8 +95,6 @@ def __init__(self, env: Env, config: NoneBotConfig):

self.fastapi_config: Config = Config(**config.dict())

self._lifespan = Lifespan()

self._server_app = FastAPI(
lifespan=self._lifespan_manager,
openapi_url=self.fastapi_config.fastapi_openapi_url,
Expand Down Expand Up @@ -155,14 +151,6 @@ async def _handle(websocket: WebSocket) -> None:
name=setup.name,
)

@override
def on_startup(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
return self._lifespan.on_startup(func)

@override
def on_shutdown(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
return self._lifespan.on_shutdown(func)

@contextlib.asynccontextmanager
async def _lifespan_manager(self, app: FastAPI):
await self._lifespan.startup()
Expand Down
14 changes: 0 additions & 14 deletions nonebot/drivers/none.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,6 @@
from nonebot.config import Env, Config
from nonebot.drivers import Driver as BaseDriver

from ._lifespan import LIFESPAN_FUNC, Lifespan

HANDLED_SIGNALS = (
signal.SIGINT, # Unix signal 2. Sent by Ctrl+C.
signal.SIGTERM, # Unix signal 15. Sent by `kill <pid>`.
Expand All @@ -35,8 +33,6 @@ class Driver(BaseDriver):
def __init__(self, env: Env, config: Config):
super().__init__(env, config)

self._lifespan = Lifespan()

self.should_exit: asyncio.Event = asyncio.Event()
self.force_exit: bool = False

Expand All @@ -52,16 +48,6 @@ def logger(self):
"""none driver 使用的 logger"""
return logger

@override
def on_startup(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
"""注册一个启动时执行的函数"""
return self._lifespan.on_startup(func)

@override
def on_shutdown(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
"""注册一个停止时执行的函数"""
return self._lifespan.on_shutdown(func)

@override
def run(self, *args, **kwargs):
"""启动 none driver"""
Expand Down
27 changes: 3 additions & 24 deletions nonebot/drivers/quart.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,7 @@
import asyncio
from functools import wraps
from typing_extensions import override
from typing import (
Any,
Dict,
List,
Tuple,
Union,
TypeVar,
Callable,
Optional,
Coroutine,
cast,
)
from typing import Any, Dict, List, Tuple, Union, Optional, cast

from pydantic import BaseSettings

Expand Down Expand Up @@ -57,8 +46,6 @@
"Install with pip: `pip install nonebot2[quart]`"
) from e

_AsyncCallable = TypeVar("_AsyncCallable", bound=Callable[..., Coroutine])


def catch_closed(func):
@wraps(func)
Expand Down Expand Up @@ -102,6 +89,8 @@ def __init__(self, env: Env, config: NoneBotConfig):
self._server_app = Quart(
self.__class__.__qualname__, **self.quart_config.quart_extra
)
self._server_app.before_serving(self._lifespan.startup)
self._server_app.after_serving(self._lifespan.shutdown)

@property
@override
Expand Down Expand Up @@ -150,16 +139,6 @@ async def _handle() -> None:
view_func=_handle,
)

@override
def on_startup(self, func: _AsyncCallable) -> _AsyncCallable:
"""参考文档: [`Startup and Shutdown`](https://pgjones.gitlab.io/quart/how_to_guides/startup_shutdown.html)"""
return self.server_app.before_serving(func) # type: ignore

@override
def on_shutdown(self, func: _AsyncCallable) -> _AsyncCallable:
"""参考文档: [`Startup and Shutdown`](https://pgjones.gitlab.io/quart/how_to_guides/startup_shutdown.html)"""
return self.server_app.after_serving(func) # type: ignore

@override
def run(
self,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
class Lifespan:
def __init__(self) -> None:
self._startup_funcs: List[LIFESPAN_FUNC] = []
self._ready_funcs: List[LIFESPAN_FUNC] = []
self._shutdown_funcs: List[LIFESPAN_FUNC] = []

def on_startup(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
Expand All @@ -21,6 +22,10 @@
self._shutdown_funcs.append(func)
return func

def _on_ready(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
ProgramRipper marked this conversation as resolved.
Show resolved Hide resolved
self._ready_funcs.append(func)
return func

Check warning on line 27 in nonebot/internal/driver/_lifespan.py

View check run for this annotation

Codecov / codecov/patch

nonebot/internal/driver/_lifespan.py#L26-L27

Added lines #L26 - L27 were not covered by tests

@staticmethod
async def _run_lifespan_func(
funcs: List[LIFESPAN_FUNC],
Expand All @@ -35,6 +40,9 @@
if self._startup_funcs:
await self._run_lifespan_func(self._startup_funcs)

if self._ready_funcs:
await self._run_lifespan_func(self._ready_funcs)

Check warning on line 44 in nonebot/internal/driver/_lifespan.py

View check run for this annotation

Codecov / codecov/patch

nonebot/internal/driver/_lifespan.py#L44

Added line #L44 was not covered by tests

async def shutdown(self) -> None:
if self._shutdown_funcs:
await self._run_lifespan_func(self._shutdown_funcs)
Expand Down
18 changes: 9 additions & 9 deletions nonebot/internal/driver/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import asyncio
from typing_extensions import TypeAlias
from contextlib import AsyncExitStack, asynccontextmanager
from typing import TYPE_CHECKING, Any, Set, Dict, Type, Callable, AsyncGenerator
from typing import TYPE_CHECKING, Any, Set, Dict, Type, AsyncGenerator

from nonebot.log import logger
from nonebot.config import Env, Config
Expand All @@ -16,6 +16,7 @@
T_BotDisconnectionHook,
)

from ._lifespan import LIFESPAN_FUNC, Lifespan
from .model import Request, Response, WebSocket, HTTPServerSetup, WebSocketServerSetup

if TYPE_CHECKING:
Expand Down Expand Up @@ -49,6 +50,7 @@ def __init__(self, env: Env, config: Config):
"""全局配置对象"""
self._bots: Dict[str, "Bot"] = {}
self._bot_tasks: Set[asyncio.Task] = set()
self._lifespan = Lifespan()

def __repr__(self) -> str:
return (
Expand Down Expand Up @@ -100,15 +102,13 @@ def run(self, *args, **kwargs):

self.on_shutdown(self._cleanup)

@abc.abstractmethod
def on_startup(self, func: Callable) -> Callable:
"""注册一个在驱动器启动时执行的函数"""
raise NotImplementedError
def on_startup(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
"""注册一个启动时执行的函数"""
return self._lifespan.on_startup(func)

@abc.abstractmethod
def on_shutdown(self, func: Callable) -> Callable:
"""注册一个在驱动器停止时执行的函数"""
raise NotImplementedError
def on_shutdown(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
"""注册一个停止时执行的函数"""
return self._lifespan.on_shutdown(func)

@classmethod
def on_bot_connect(cls, func: T_BotConnectionHook) -> T_BotConnectionHook:
Expand Down
26 changes: 19 additions & 7 deletions tests/test_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from nonebot.params import Depends
from nonebot.dependencies import Dependent
from nonebot.exception import WebSocketClosed
from nonebot.drivers._lifespan import Lifespan
from nonebot.drivers import (
URL,
Driver,
Expand All @@ -25,34 +24,47 @@


@pytest.mark.asyncio
async def test_lifespan():
lifespan = Lifespan()
async def test_lifespan(driver: Driver):
lifespan = driver._lifespan
ProgramRipper marked this conversation as resolved.
Show resolved Hide resolved

start_log = []
ready_log = []
shutdown_log = []

@lifespan.on_startup
@driver.on_startup
async def _startup1():
assert start_log == []
start_log.append(1)

@lifespan.on_startup
@driver.on_startup
async def _startup2():
assert start_log == [1]
start_log.append(2)

@lifespan.on_shutdown
@lifespan.on_startup
def _ready1():
assert start_log == [1, 2]
assert ready_log == []
ready_log.append(1)

@lifespan.on_startup
def _ready2():
assert ready_log == [1]
ready_log.append(2)
ProgramRipper marked this conversation as resolved.
Show resolved Hide resolved

@driver.on_shutdown
async def _shutdown1():
assert shutdown_log == []
shutdown_log.append(1)

@lifespan.on_shutdown
@driver.on_shutdown
async def _shutdown2():
assert shutdown_log == [1]
shutdown_log.append(2)

async with lifespan:
assert start_log == [1, 2]
assert ready_log == [1, 2]

assert shutdown_log == [1, 2]

Expand Down
Loading