From f6db2386035c0f37623227065e7983083382ad21 Mon Sep 17 00:00:00 2001 From: Helder Sepulveda Date: Wed, 8 May 2024 12:09:05 -0400 Subject: [PATCH] Refactor RESTManager and add config.write when port changes Refactor RESTManager and add config.write when port changes --- .../components/restapi/rest/rest_manager.py | 31 +++++++++---------- .../rest/tests/test_events_endpoint.py | 5 +-- .../restapi/rest/tests/test_rest_manager.py | 5 +-- .../components/restapi/restapi_component.py | 2 +- 4 files changed, 22 insertions(+), 21 deletions(-) diff --git a/src/tribler/core/components/restapi/rest/rest_manager.py b/src/tribler/core/components/restapi/rest/rest_manager.py index 60c9a3cad3f..46db9b4dffa 100644 --- a/src/tribler/core/components/restapi/rest/rest_manager.py +++ b/src/tribler/core/components/restapi/rest/rest_manager.py @@ -18,7 +18,7 @@ RESTResponse, ) from tribler.core.components.restapi.rest.root_endpoint import RootEndpoint -from tribler.core.components.restapi.rest.settings import APISettings +from tribler.core.config.tribler_config import TriblerConfig from tribler.core.utilities.network_utils import default_network_utils from tribler.core.utilities.process_manager import get_global_process_manager from tribler.core.version import version_id @@ -83,7 +83,7 @@ class RESTManager: This class is responsible for managing the startup and closing of the Tribler HTTP API. """ - def __init__(self, config: APISettings, root_endpoint: RootEndpoint, state_dir=None, shutdown_timeout: int = 10): + def __init__(self, config: TriblerConfig, root_endpoint: RootEndpoint, shutdown_timeout: int = 10): super().__init__() self._logger = logging.getLogger(self.__class__.__name__) self.root_endpoint = root_endpoint @@ -91,8 +91,6 @@ def __init__(self, config: APISettings, root_endpoint: RootEndpoint, state_dir=N self.site: Optional[web.TCPSite] = None self.site_https: Optional[web.TCPSite] = None self.config = config - self.state_dir = state_dir - self.shutdown_timeout = shutdown_timeout def get_endpoint(self, name): @@ -101,8 +99,9 @@ def get_endpoint(self, name): def set_api_port(self, api_port: int): default_network_utils.remember(api_port) - if self.config.http_port != api_port: - self.config.http_port = api_port + if self.config.api.http_port != api_port: + self.config.api.http_port = api_port + self.config.write() process_manager = get_global_process_manager() if process_manager: @@ -122,7 +121,7 @@ async def start(self): version=version_id, swagger_path='/docs' ) - if self.config.key: + if self.config.api.key: self._logger.info('Set security scheme and apply to all endpoints') aiohttp_apispec.spec.options['security'] = [{'apiKey': []}] @@ -136,21 +135,21 @@ async def start(self): self.runner = web.AppRunner(self.root_endpoint.app, access_log=None) await self.runner.setup() - if self.config.http_enabled: + if self.config.api.http_enabled: self._logger.info('Http enabled') await self.start_http_site() - if self.config.https_enabled: + if self.config.api.https_enabled: self._logger.info('Https enabled') await self.start_https_site() - self._logger.info(f'Swagger docs: http://{self.config.http_host}:{self.config.http_port}/docs') - self._logger.info(f'Swagger JSON: http://{self.config.http_host}:{self.config.http_port}/docs/swagger.json') + self._logger.info(f'Swagger docs: http://{self.config.api.http_host}:{self.config.api.http_port}/docs') + self._logger.info(f'Swagger JSON: http://{self.config.api.http_host}:{self.config.api.http_port}/docs/swagger.json') async def start_http_site(self): - api_port = max(self.config.http_port, 0) # if the value in config is -1 we convert it to 0 + api_port = max(self.config.api.http_port, 0) # if the value in config is -1 we convert it to 0 - self.site = web.TCPSite(self.runner, self.config.http_host, api_port, shutdown_timeout=self.shutdown_timeout) + self.site = web.TCPSite(self.runner, self.config.api.http_host, api_port, shutdown_timeout=self.shutdown_timeout) self._logger.info(f"Starting HTTP REST API server on port {api_port}...") try: @@ -168,11 +167,11 @@ async def start_http_site(self): async def start_https_site(self): ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) - cert = self.config.get_path_as_absolute('https_certfile', self.state_dir) + cert = self.config.api.get_path_as_absolute('https_certfile', self.config.state_dir) ssl_context.load_cert_chain(cert) - port = self.config.https_port - self.site_https = web.TCPSite(self.runner, self.config.https_host, port, ssl_context=ssl_context) + port = self.config.api.https_port + self.site_https = web.TCPSite(self.runner, self.config.api.https_host, port, ssl_context=ssl_context) await self.site_https.start() self._logger.info("Started HTTPS REST API: %s", self.site_https.name) diff --git a/src/tribler/core/components/restapi/rest/tests/test_events_endpoint.py b/src/tribler/core/components/restapi/rest/tests/test_events_endpoint.py index 3a19f6516c7..f5b27f3daec 100644 --- a/src/tribler/core/components/restapi/rest/tests/test_events_endpoint.py +++ b/src/tribler/core/components/restapi/rest/tests/test_events_endpoint.py @@ -51,9 +51,10 @@ async def fixture_rest_manager(api_port, tmp_path, events_endpoint): config = TriblerConfig() config.api.http_enabled = True config.api.http_port = api_port + config.set_state_dir(tmp_path) root_endpoint = RootEndpoint(middlewares=[ApiKeyMiddleware(config.api.key), error_middleware]) root_endpoint.add_endpoint('/events', events_endpoint) - rest_manager = RESTManager(config=config.api, root_endpoint=root_endpoint, state_dir=tmp_path) + rest_manager = RESTManager(config=config, root_endpoint=root_endpoint) await rest_manager.start() yield rest_manager @@ -62,7 +63,7 @@ async def fixture_rest_manager(api_port, tmp_path, events_endpoint): async def open_events_socket(rest_manager_, connected_event, events_up): global messages_to_wait_for - port = rest_manager_.config.http_port + port = rest_manager_.config.api.http_port url = f'http://localhost:{port}/events' headers = {'User-Agent': 'Tribler ' + version_id} diff --git a/src/tribler/core/components/restapi/rest/tests/test_rest_manager.py b/src/tribler/core/components/restapi/rest/tests/test_rest_manager.py index 7328c8bca06..a920d596ee5 100644 --- a/src/tribler/core/components/restapi/rest/tests/test_rest_manager.py +++ b/src/tribler/core/components/restapi/rest/tests/test_rest_manager.py @@ -34,6 +34,7 @@ def api_port_fixture(free_port): @pytest.fixture(name='rest_manager') async def rest_manager_fixture(request, tribler_config, api_port, tmp_path): config = tribler_config + config.set_state_dir(tmp_path) api_key_marker = request.node.get_closest_marker("api_key") if api_key_marker is not None: tribler_config.api.key = api_key_marker.args[0] @@ -49,7 +50,7 @@ async def rest_manager_fixture(request, tribler_config, api_port, tmp_path): tribler_config.api.http_port = api_port root_endpoint = RootEndpoint(middlewares=[ApiKeyMiddleware(config.api.key), error_middleware]) root_endpoint.add_endpoint('/settings', SettingsEndpoint(config)) - rest_manager = RESTManager(config=config.api, root_endpoint=root_endpoint, state_dir=tmp_path) + rest_manager = RESTManager(config=config, root_endpoint=root_endpoint) await rest_manager.start() yield rest_manager await rest_manager.stop() @@ -69,7 +70,7 @@ async def test_api_key_disabled(rest_manager, api_port): @pytest.mark.api_key('0' * 32) async def test_api_key_success(rest_manager, api_port): - api_key = rest_manager.config.key + api_key = rest_manager.config.api.key await do_real_request(api_port, 'settings?apikey=' + api_key) await do_real_request(api_port, 'settings', headers={'X-Api-Key': api_key}) diff --git a/src/tribler/core/components/restapi/restapi_component.py b/src/tribler/core/components/restapi/restapi_component.py index 2cb104f2053..129baa4628d 100644 --- a/src/tribler/core/components/restapi/restapi_component.py +++ b/src/tribler/core/components/restapi/restapi_component.py @@ -109,7 +109,7 @@ async def run(self): self.root_endpoint.add_endpoint('/ipv8', ipv8_root_endpoint) # Note: AIOHTTP endpoints cannot be added after the app has been started! - rest_manager = RESTManager(config=config.api, root_endpoint=self.root_endpoint, state_dir=config.state_dir) + rest_manager = RESTManager(config=config, root_endpoint=self.root_endpoint) await rest_manager.start() self.rest_manager = rest_manager