Skip to content

Commit

Permalink
feat: add global middlewares, which are middleware executed before or… (
Browse files Browse the repository at this point in the history
sparckles#498)

* feat: add global middlewares, which are middleware executed before or after every request

* Update src/server.rs

* Update src/server.rs

---------

Co-authored-by: Sanskar Jethi <[email protected]>
  • Loading branch information
AntoineRR and sansyrox authored May 21, 2023
1 parent 405c761 commit 7867f3e
Show file tree
Hide file tree
Showing 18 changed files with 419 additions and 179 deletions.
16 changes: 15 additions & 1 deletion docs/features.md
Original file line number Diff line number Diff line change
Expand Up @@ -351,7 +351,7 @@ You can use both sync and async functions for middlewares!
```python
@app.before_request("/")
async def hello_before_request(request: Request):
request.headers["before"] = "sync_before_request"
request.headers["before"] = "async_before_request"
print(request)


Expand All @@ -361,6 +361,20 @@ def hello_after_request(response: Response):
print(response)
```

Middlewares can be bound to a route or run before/after every request:

```python
# This middleware runs before all requests
@app.before_request()
async def global_before_request(request: Request):
request.headers["before"] = "global_before_request"

# This middleware runs only before requests to "/your/route"
@app.before_request("/your/route")
async def route_before_request(request: Request):
request.headers["before"] = "route_before_request"
```

## MultiCore Scaling

To run Robyn across multiple cores, you can use the following command:
Expand Down
24 changes: 24 additions & 0 deletions integration_tests/base_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,30 @@ def shutdown_handler():

# ===== Middlewares =====

# --- Global ---


@app.before_request()
def global_before_request(request: Request):
request.headers["global_before"] = "global_before_request"
return request


@app.after_request()
def global_after_request(response: Response):
response.headers["global_after"] = "global_after_request"
return response


@app.get("/sync/global/middlewares")
def sync_global_middlewares(request: Request):
assert "global_before" in request.headers
assert request.headers["global_before"] == "global_before_request"
return "sync global middlewares"


# --- Route specific ---


@app.before_request("/sync/middlewares")
def sync_before_request(request: Request):
Expand Down
2 changes: 2 additions & 0 deletions integration_tests/helpers/http_methods_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

def check_response(response: requests.Response, expected_status_code: int):
assert response.status_code == expected_status_code
assert "global_after" in response.headers
assert response.headers["global_after"] == "global_after_request"
assert "server" in response.headers
assert response.headers["server"] == "robyn"

Expand Down
9 changes: 9 additions & 0 deletions integration_tests/test_middlewares.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,12 @@ def test_middlewares(function_type: str, session):
assert "after" in r.headers
assert r.headers["after"] == f"{function_type}_after_request"
assert r.text == f"{function_type} middlewares after"


@pytest.mark.benchmark
def test_global_middleware(session):
r = get("/sync/global/middlewares")
assert "global_before" not in r.headers
assert "global_after" in r.headers
assert r.headers["global_after"] == "global_after_request"
assert r.text == "sync global middlewares"
62 changes: 39 additions & 23 deletions robyn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import logging
import multiprocess as mp
import os
from typing import Callable, List, Optional
from typing import Callable, List, Optional, Tuple
from nestd import get_all_nested

from robyn.argument_parser import Config
Expand All @@ -13,8 +13,8 @@
from robyn.logger import logger
from robyn.processpool import run_processes
from robyn.responses import jsonify, serve_file, serve_html
from robyn.robyn import FunctionInfo, Request, Response, get_version
from robyn.router import MiddlewareRouter, Router, WebSocketRouter
from robyn.robyn import FunctionInfo, HttpMethod, Request, Response, get_version
from robyn.router import MiddlewareRouter, MiddlewareType, Router, WebSocketRouter
from robyn.types import Directory, Header
from robyn.status_codes import StatusCodes
from robyn.ws import WS
Expand Down Expand Up @@ -54,7 +54,9 @@ def __init__(self, file_object: str, config: Config = Config()) -> None:
self.directories: List[Directory] = []
self.event_handlers = {}

def _add_route(self, route_type, endpoint, handler, is_const=False):
def _add_route(
self, route_type: HttpMethod, endpoint: str, handler: Callable, is_const=False
):
"""
This is base handler for all the decorators
Expand All @@ -67,23 +69,27 @@ def _add_route(self, route_type, endpoint, handler, is_const=False):
"""
return self.router.add_route(route_type, endpoint, handler, is_const)

def before_request(self, endpoint: str) -> Callable[..., None]:
def before_request(self, endpoint: Optional[str] = None) -> Callable[..., None]:
"""
You can use the @app.before_request decorator to call a method before routing to the specified endpoint
:param endpoint str: endpoint to server the route
"""

return self.middleware_router.add_before_request(endpoint)
return self.middleware_router.add_middleware(
MiddlewareType.BEFORE_REQUEST, endpoint
)

def after_request(self, endpoint: str) -> Callable[..., None]:
def after_request(self, endpoint: Optional[str] = None) -> Callable[..., None]:
"""
You can use the @app.after_request decorator to call a method after routing to the specified endpoint
:param endpoint str: endpoint to server the route
"""

return self.middleware_router.add_after_request(endpoint)
return self.middleware_router.add_middleware(
MiddlewareType.AFTER_REQUEST, endpoint
)

def add_directory(
self,
Expand Down Expand Up @@ -140,7 +146,8 @@ def start(self, url: str = "127.0.0.1", port: int = 8080):
self.directories,
self.request_headers,
self.router.get_routes(),
self.middleware_router.get_routes(),
self.middleware_router.get_global_middlewares(),
self.middleware_router.get_route_middlewares(),
self.web_socket_router.get_routes(),
self.event_handlers,
self.config.workers,
Expand All @@ -155,15 +162,24 @@ def add_view(self, endpoint: str, view: Callable, const: bool = False):
:param endpoint str: endpoint for the route added
:param handler function: represents the function passed as a parent handler for single route with different route types
"""
http_methods = {"GET", "POST", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS"}

def get_functions(view):
http_methods = {
"GET": HttpMethod.GET,
"POST": HttpMethod.POST,
"PUT": HttpMethod.PUT,
"DELETE": HttpMethod.DELETE,
"PATCH": HttpMethod.PATCH,
"HEAD": HttpMethod.HEAD,
"OPTIONS": HttpMethod.OPTIONS,
}

def get_functions(view) -> List[Tuple[HttpMethod, Callable]]:
functions = get_all_nested(view)
output = []
for name, handler in functions:
route_type = name.upper()
if route_type in http_methods:
output.append((route_type, handler))
method = http_methods.get(route_type)
if method is not None:
output.append((method, handler))
return output

handlers = get_functions(view)
Expand All @@ -190,7 +206,7 @@ def get(self, endpoint: str, const: bool = False):
"""

def inner(handler):
return self._add_route("GET", endpoint, handler, const)
return self._add_route(HttpMethod.GET, endpoint, handler, const)

return inner

Expand All @@ -202,7 +218,7 @@ def post(self, endpoint: str):
"""

def inner(handler):
return self._add_route("POST", endpoint, handler)
return self._add_route(HttpMethod.POST, endpoint, handler)

return inner

Expand All @@ -214,7 +230,7 @@ def put(self, endpoint: str):
"""

def inner(handler):
return self._add_route("PUT", endpoint, handler)
return self._add_route(HttpMethod.PUT, endpoint, handler)

return inner

Expand All @@ -226,7 +242,7 @@ def delete(self, endpoint: str):
"""

def inner(handler):
return self._add_route("DELETE", endpoint, handler)
return self._add_route(HttpMethod.DELETE, endpoint, handler)

return inner

Expand All @@ -238,7 +254,7 @@ def patch(self, endpoint: str):
"""

def inner(handler):
return self._add_route("PATCH", endpoint, handler)
return self._add_route(HttpMethod.PATCH, endpoint, handler)

return inner

Expand All @@ -250,7 +266,7 @@ def head(self, endpoint: str):
"""

def inner(handler):
return self._add_route("HEAD", endpoint, handler)
return self._add_route(HttpMethod.HEAD, endpoint, handler)

return inner

Expand All @@ -262,7 +278,7 @@ def options(self, endpoint: str):
"""

def inner(handler):
return self._add_route("OPTIONS", endpoint, handler)
return self._add_route(HttpMethod.OPTIONS, endpoint, handler)

return inner

Expand All @@ -274,7 +290,7 @@ def connect(self, endpoint: str):
"""

def inner(handler):
return self._add_route("CONNECT", endpoint, handler)
return self._add_route(HttpMethod.CONNECT, endpoint, handler)

return inner

Expand All @@ -286,7 +302,7 @@ def trace(self, endpoint: str):
"""

def inner(handler):
return self._add_route("TRACE", endpoint, handler)
return self._add_route(HttpMethod.TRACE, endpoint, handler)

return inner

Expand Down
26 changes: 17 additions & 9 deletions robyn/processpool.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from robyn.events import Events
from robyn.robyn import FunctionInfo, Server, SocketHeld
from robyn.router import MiddlewareRoute, Route
from robyn.router import GlobalMiddleware, RouteMiddleware, Route
from robyn.types import Directory, Header
from robyn.ws import WS

Expand All @@ -18,7 +18,8 @@ def run_processes(
directories: List[Directory],
request_headers: List[Header],
routes: List[Route],
middlewares: List[MiddlewareRoute],
global_middlewares: List[GlobalMiddleware],
route_middlewares: List[RouteMiddleware],
web_sockets: Dict[str, WS],
event_handlers: Dict[Events, FunctionInfo],
workers: int,
Expand All @@ -31,7 +32,8 @@ def run_processes(
directories,
request_headers,
routes,
middlewares,
global_middlewares,
route_middlewares,
web_sockets,
event_handlers,
socket,
Expand Down Expand Up @@ -59,7 +61,8 @@ def init_processpool(
directories: List[Directory],
request_headers: List[Header],
routes: List[Route],
middlewares: List[MiddlewareRoute],
global_middlewares: List[GlobalMiddleware],
route_middlewares: List[RouteMiddleware],
web_sockets: Dict[str, WS],
event_handlers: Dict[Events, FunctionInfo],
socket: SocketHeld,
Expand All @@ -73,7 +76,8 @@ def init_processpool(
directories,
request_headers,
routes,
middlewares,
global_middlewares,
route_middlewares,
web_sockets,
event_handlers,
socket,
Expand All @@ -91,7 +95,8 @@ def init_processpool(
directories,
request_headers,
routes,
middlewares,
global_middlewares,
route_middlewares,
web_sockets,
event_handlers,
copied_socket,
Expand Down Expand Up @@ -125,7 +130,8 @@ def spawn_process(
directories: List[Directory],
request_headers: List[Header],
routes: List[Route],
middlewares: List[MiddlewareRoute],
global_middlewares: List[GlobalMiddleware],
route_middlewares: List[RouteMiddleware],
web_sockets: Dict[str, WS],
event_handlers: Dict[Events, FunctionInfo],
socket: SocketHeld,
Expand Down Expand Up @@ -166,8 +172,10 @@ def spawn_process(
route_type, endpoint, function, is_const = route
server.add_route(route_type, endpoint, function, is_const)

for middleware_route in middlewares:
route_type, endpoint, function = middleware_route
for middleware_type, middleware_function in global_middlewares:
server.add_global_middleware(middleware_type, middleware_function)

for route_type, endpoint, function in route_middlewares:
server.add_middleware_route(route_type, endpoint, function)

if Events.STARTUP in event_handlers:
Expand Down
Loading

0 comments on commit 7867f3e

Please sign in to comment.