From ed91ae3f6f525b6b4ec4002cb0de3e398aea0668 Mon Sep 17 00:00:00 2001 From: JurgenR <1249228+JurgenR@users.noreply.github.com> Date: Thu, 5 Oct 2023 16:54:17 +0200 Subject: [PATCH 1/2] feat: split peer module --- README.rst | 6 +- src/aioslsk/client.py | 136 ++++-- src/aioslsk/distributed.py | 423 +++++++++++++++++ src/aioslsk/events.py | 2 +- src/aioslsk/network/connection.py | 11 +- src/aioslsk/peer.py | 516 +-------------------- src/aioslsk/search/__init__.py | 0 src/aioslsk/search/manager.py | 328 +++++++++++++ src/aioslsk/{search.py => search/model.py} | 2 +- src/aioslsk/server.py | 109 +---- src/aioslsk/shares/manager.py | 2 +- src/aioslsk/state.py | 8 +- src/aioslsk/transfer/manager.py | 12 +- tests/__init__.py | 0 tests/unit/search/test_search_manager.py | 131 ++++++ tests/unit/test_peer.py | 51 +- tests/unit/test_server_manager.py | 37 +- 17 files changed, 1034 insertions(+), 740 deletions(-) create mode 100644 src/aioslsk/distributed.py create mode 100644 src/aioslsk/search/__init__.py create mode 100644 src/aioslsk/search/manager.py rename src/aioslsk/{search.py => search/model.py} (98%) create mode 100644 tests/__init__.py create mode 100644 tests/unit/search/test_search_manager.py diff --git a/README.rst b/README.rst index 4a4d2f2c..93138082 100644 --- a/README.rst +++ b/README.rst @@ -53,7 +53,7 @@ User +------------------------------+--------+------------------------------------------------------------------------+-----------+ | credentials.info.description | string | Personal description, will be returned when a peer request info on you | | +------------------------------+--------+------------------------------------------------------------------------+-----------+ -| credentials.info.pciture | string | Picture, will be returned when a peer request info on you | | +| credentials.info.picture | string | Picture, will be returned when a peer request info on you | | +------------------------------+--------+------------------------------------------------------------------------+-----------+ @@ -99,9 +99,7 @@ Sharing +------------------------------------+---------------+-----------------------------------------------------------------------------------+-----------+ | sharing.download | string | Directory to which files will be downloaded to | | +------------------------------------+---------------+-----------------------------------------------------------------------------------+-----------+ -| sharing.directories | array[object] | List of shared directories: | | -+------------------------------------+---------------+-----------------------------------------------------------------------------------+-----------+ -| sharing.index.store_interval | integer | Shared items index automatically gets stored, this parameter defines the interval | 120 | +| sharing.directories | array[object] | List of shared directories (see structure for each entry below) | | +------------------------------------+---------------+-----------------------------------------------------------------------------------+-----------+ The `sharing.directories` list contains objects which have the following parameters: diff --git a/src/aioslsk/client.py b/src/aioslsk/client.py index 5a9253e1..357e538a 100644 --- a/src/aioslsk/client.py +++ b/src/aioslsk/client.py @@ -4,6 +4,7 @@ from typing import List, Union from .configuration import Configuration +from .distributed import DistributedNetwork from .events import EventBus, InternalEventBus from .shares.cache import ( SharesShelveCache, @@ -14,7 +15,8 @@ from .network.network import Network from .peer import PeerManager from .server import ServerManager -from .search import SearchRequest, SearchResult +from .search.manager import SearchManager +from .search.model import SearchRequest, SearchResult from .state import State from .settings import Settings from .transfer.cache import TransferCache, TransferShelveCache @@ -47,47 +49,21 @@ def __init__(self, configuration: Configuration, settings_name: str = None, even self.state: State = State() - self.network: Network = Network( - self.state, - self.settings, - self._internal_events, - self._stop_event - ) + self.network: Network = self.create_network() shares_cache: SharesCache = SharesShelveCache(self.configuration.data_directory) - self.shares_manager: SharesManager = SharesManager( - self.settings, - self._internal_events, - cache=shares_cache + self.shares_manager: SharesManager = self.create_shares_manager( + shares_cache ) transfer_cache: TransferCache = TransferShelveCache(self.configuration.data_directory) - self.transfer_manager: TransferManager = TransferManager( - self.state, - self.settings, - self.events, - self._internal_events, - self.shares_manager, - self.network, - cache=transfer_cache - ) - self.peer_manager: PeerManager = PeerManager( - self.state, - self.settings, - self.events, - self._internal_events, - self.shares_manager, - self.transfer_manager, - self.network - ) - self.server_manager: ServerManager = ServerManager( - self.state, - self.settings, - self.events, - self._internal_events, - self.shares_manager, - self.network + self.transfer_manager: TransferManager = self.create_transfer_manager( + transfer_cache ) + self.peer_manager: PeerManager = self.create_peer_manager() + self.search_manager: SearchManager = self.create_search_manager() + self.distributed_network: DistributedNetwork = self.create_distributed_network() + self.server_manager: ServerManager = self.create_server_manager() @property def event_loop(self): @@ -160,7 +136,11 @@ async def stop(self): await self.network.disconnect() - cancelled_tasks = self.peer_manager.stop() + self.transfer_manager.stop() + cancelled_tasks = ( + self.transfer_manager.stop() + + self.search_manager.stop() + + self.distributed_network.stop() + ) await asyncio.gather(*cancelled_tasks, return_exceptions=True) self.shares_manager.write_cache() @@ -296,27 +276,27 @@ async def search(self, query: str) -> SearchRequest: """Performs a search, returns the generated ticket number for the search """ logger.info(f"Starting search for query: {query}") - return await self.server_manager.search(query) + return await self.search_manager.search(query) async def search_user(self, query: str, user: Union[str, User]) -> SearchRequest: username = user.name if isinstance(user, User) else user - return await self.server_manager.search_user(username, query) + return await self.search_manager.search_user(username, query) async def search_room(self, query: str, room: Union[str, Room]) -> SearchRequest: room_name = room.name if isinstance(room, Room) else room - return await self.server_manager.search_room(room_name, query) + return await self.search_manager.search_room(room_name, query) def get_search_request_by_ticket(self, ticket: int) -> SearchRequest: """Returns a search request with given ticket""" - return self.state.search_requests[ticket] + return self.search_manager.search_requests[ticket] def get_search_results_by_ticket(self, ticket: int) -> List[SearchResult]: """Returns all search results for given ticket""" - return self.state.search_requests[ticket].results + return self.search_manager.search_requests[ticket].results def remove_search_request_by_ticket(self, ticket: int) -> SearchRequest: """Removes a search request for given ticket""" - return self.state.search_requests.pop(ticket) + return self.search_manager.search_requests.pop(ticket) async def get_user_stats(self, user: Union[str, User]): username = user.name if isinstance(user, User) else user @@ -338,3 +318,73 @@ async def get_user_shares(self, user: Union[str, User]): async def get_user_directory(self, user: Union[str, User], directory: List[str]): username = user.name if isinstance(user, User) else user await self.peer_manager.get_user_directory(username, directory) + + # Creation methods + + def create_network(self) -> Network: + return Network( + self.state, + self.settings, + self._internal_events, + self._stop_event + ) + + def create_shares_manager(self, cache: SharesCache) -> SharesManager: + return SharesManager( + self.settings, + self._internal_events, + cache=cache + ) + + def create_transfer_manager(self, cache: TransferCache) -> TransferManager: + return TransferManager( + self.state, + self.settings, + self.events, + self._internal_events, + self.shares_manager, + self.network, + cache=cache + ) + + def create_search_manager(self) -> SearchManager: + return SearchManager( + self.state, + self.settings, + self.events, + self._internal_events, + self.shares_manager, + self.transfer_manager, + self.network + ) + + def create_server_manager(self) -> ServerManager: + return ServerManager( + self.state, + self.settings, + self.events, + self._internal_events, + self.shares_manager, + self.network + ) + + def create_peer_manager(self) -> PeerManager: + return PeerManager( + self.state, + self.settings, + self.events, + self._internal_events, + self.shares_manager, + self.transfer_manager, + self.network + ) + + def create_distributed_network(self) -> DistributedNetwork: + return DistributedNetwork( + self.state, + self.settings, + self.events, + self._internal_events, + self.shares_manager, + self.network + ) diff --git a/src/aioslsk/distributed.py b/src/aioslsk/distributed.py new file mode 100644 index 00000000..0776cfe3 --- /dev/null +++ b/src/aioslsk/distributed.py @@ -0,0 +1,423 @@ +import asyncio +from dataclasses import dataclass +from functools import partial +import logging +from typing import List, Optional, Union, Tuple + +from .network.connection import ( + CloseReason, + ConnectionState, + PeerConnection, + PeerConnectionType, + ServerConnection, +) +from .events import ( + on_message, + build_message_map, + EventBus, + InternalEventBus, + ConnectionStateChangedEvent, + LoginSuccessEvent, + PeerInitializedEvent, + MessageReceivedEvent, +) +from .protocol.messages import ( + AcceptChildren, + BranchLevel, + BranchRoot, + ToggleParentSearch, + DistributedChildDepth, + DistributedBranchLevel, + DistributedBranchRoot, + DistributedSearchRequest, + DistributedServerSearchRequest, + MessageDataclass, + PotentialParents, + ServerSearchRequest, +) +from .network.network import Network +from .settings import Settings +from .shares.manager import SharesManager +from .state import State +from .utils import task_counter, ticket_generator + + +logger = logging.getLogger(__name__) + + +@dataclass +class DistributedPeer: + username: str + connection: PeerConnection + branch_level: int = None + branch_root: str = None + child_depth: int = None + + +class DistributedNetwork: + """Class responsible for handling the distributed network""" + + def __init__( + self, state: State, settings: Settings, + event_bus: EventBus, internal_event_bus: InternalEventBus, + shares_manager: SharesManager, + network: Network): + self._state: State = state + self._settings: Settings = settings + self._event_bus: EventBus = event_bus + self._internal_event_bus: InternalEventBus = internal_event_bus + self._network: Network = network + self._shares_manager: SharesManager = shares_manager + + self._ticket_generator = ticket_generator() + + self.parent: DistributedPeer = None + """Distributed parent. This variable is `None` if we are looking for + parents + """ + self.children: List[DistributedPeer] = [] + self.potential_parents: List[str] = [] + self.distributed_peers: List[DistributedPeer] = [] + + self._internal_event_bus.register( + PeerInitializedEvent, self._on_peer_connection_initialized) + self._internal_event_bus.register( + ConnectionStateChangedEvent, self._on_state_changed) + self._internal_event_bus.register( + MessageReceivedEvent, self._on_message_received) + self._internal_event_bus.register( + LoginSuccessEvent, self._on_login_success) + + self.MESSAGE_MAP = build_message_map(self) + + self._potential_parent_tasks: List[asyncio.Task] = [] + + def _get_advertised_branch_values(self) -> Tuple[str, int]: + """Returns the advertised branch values. These values are to be sent to + the children and the server to let them know where we are in the + distributed tree. + + If no parent: + * level = 0 + * root = our own username + + If we are the root: + * level = 0 + * root = our own username + + If we have a parent: + * level = level parent advertised + 1 + * root = whatever our parent sent us initially + """ + username = self._settings.get('credentials.username') + if self.parent: + # We are the branch root + if self.parent.branch_root == username: + return username, 0 + else: + return self.parent.branch_root, self.parent.branch_level + 1 + + return username, 0 + + def get_distributed_peer(self, username: str, connection: PeerConnection) -> Optional[DistributedPeer]: + for peer in self.distributed_peers: + if peer.username == username and peer.connection == connection: + return peer + + async def _set_parent(self, peer: DistributedPeer): + logger.info(f"set parent : {peer}") + self.parent = peer + + self._cancel_potential_parent_tasks() + # Cancel all tasks related to potential parents and disconnect all other + # distributed connections except for children and the parent connection + # Other distributed connection from the parent that we have should also + # be disconnected + distributed_connections = [ + distributed_peer.connection for distributed_peer in self.distributed_peers + if distributed_peer in [self.parent, ] + self.children + ] + for peer_connection in self._network.peer_connections: + if peer_connection.connection_type == PeerConnectionType.DISTRIBUTED: + if peer_connection not in distributed_connections: + asyncio.create_task( + peer_connection.disconnect(reason=CloseReason.REQUESTED), + name=f'disconnect-distributed-{task_counter()}' + ) + + await self._notify_server_of_parent() + await self._notify_children_of_branch_values() + + async def _check_if_new_parent(self, peer: DistributedPeer): + """Called after BranchRoot or BranchLevel, checks if all information is + complete for this peer/connection to become a parent and makes it a + parent if we don't have one, otherwise just close the connection. + """ + # Explicit None checks because we can get 0 as branch level + if peer.branch_level is not None and peer.branch_root is not None: + if not self.parent: + await self._set_parent(peer) + else: + await peer.connection.disconnect(reason=CloseReason.REQUESTED) + else: + logger.debug(f"{self._settings.get('credentials.username')} : not enough info for parent : {peer}") + + async def _unset_parent(self): + logger.debug(f"unset parent {self.parent!r}") + + self.parent = None + + username = self._settings.get('credentials.username') + await self._notify_server_of_parent() + + # TODO: What happens to the children when we lose our parent is still + # unclear + await self.send_messages_to_children( + DistributedBranchLevel.Request(0), + DistributedBranchRoot.Request(username) + ) + + async def _notify_server_of_parent(self): + """Notifies the server of our parent or if we don't have any, notify the + server that we are looking for one + """ + root, level = self._get_advertised_branch_values() + logger.info(f"notifying server of our parent : level={level} root={root}") + + messages = [ + BranchLevel.Request(level), + BranchRoot.Request(root) + ] + + if self.parent: + logger.info("notifying server we are not looking for parent") + messages.extend([ + ToggleParentSearch.Request(False), + AcceptChildren.Request(True) + ]) + else: + logger.info("notifying server we are looking for parent") + # The original Windows client sends out the child depth (=0) and the + # ParentIP + messages.extend([ + ToggleParentSearch.Request(True), + AcceptChildren.Request(True) + ]) + + await self._network.send_server_messages(*messages) + + async def _notify_children_of_branch_values(self): + root, level = self._get_advertised_branch_values() + await self.send_messages_to_children( + DistributedBranchLevel.Request(level), + DistributedBranchRoot.Request(root) + ) + + async def _check_if_new_child(self, peer: DistributedPeer): + """Potentially adds a distributed connection to our list of children. + """ + if peer.username in self.potential_parents: + return + + await self._add_child(peer) + + async def _add_child(self, peer: DistributedPeer): + logger.debug(f"adding distributed connection as child : {peer!r}") + self.children.append(peer) + # Let the child know where we are in the distributed tree + root, level = self._get_advertised_branch_values() + await peer.connection.send_message(DistributedBranchLevel.Request(level)) + await peer.connection.send_message(DistributedBranchRoot.Request(root)) + + def _potential_parent_task_callback(self, username: str, task: asyncio.Task): + """Callback for potential parent handling task. This callback simply + logs the results and removes the task from the list + """ + try: + task.result() + + except asyncio.CancelledError: + logger.debug(f"request for potential parent cancelled (username={username})") + except Exception as exc: + logger.warning(f"request for potential parent failed : {exc!r} (username={username})") + else: + logger.info(f"request for potential parent successful (username={username})") + finally: + self._potential_parent_tasks.remove(task) + + # Server messages + + @on_message(PotentialParents.Response) + async def _on_potential_parents(self, message: PotentialParents.Response, connection: ServerConnection): + if not self._settings.get('debug.search_for_parent'): + logger.debug("ignoring PotentialParents message : searching for parent is disabled") + return + + self.potential_parents = [ + entry.username for entry in message.entries + ] + + for entry in message.entries: + task = asyncio.create_task( + self._network.create_peer_connection( + entry.username, + PeerConnectionType.DISTRIBUTED, + ip=entry.ip, + port=entry.port + ), + name=f'potential-parent-{task_counter()}' + ) + task.add_done_callback( + partial(self._potential_parent_task_callback, entry.username) + ) + self._potential_parent_tasks.append(task) + + @on_message(ServerSearchRequest.Response) + async def _on_server_search_request(self, message: ServerSearchRequest.Response, connection): + username = self._settings.get('credentials.username') + if message.username == username: + return + + if not self.parent: + # Set ourself as parent + parent = DistributedPeer( + username, + None, + branch_root=username, + branch_level=0 + ) + await self._set_parent(parent) + + for child in self.children: + child.connection.queue_messages(message) + + # Distributed messages + + @on_message(DistributedBranchLevel.Request) + async def _on_distributed_branch_level(self, message: DistributedBranchLevel.Request, connection: PeerConnection): + logger.info(f"branch level {message.level!r}: {connection!r}") + + peer = self.get_distributed_peer(connection.username, connection) + peer.branch_level = message.level + + # Branch root is not always sent in case the peer advertises branch + # level 0 because he himself is the root + if message.level == 0: + peer.branch_root = peer.username + + if peer != self.parent: + await self._check_if_new_parent(peer) + else: + logger.info(f"parent advertised new branch level : {message.level}") + await self._notify_children_of_branch_values() + + @on_message(DistributedBranchRoot.Request) + async def _on_distributed_branch_root(self, message: DistributedBranchRoot.Request, connection: PeerConnection): + logger.info(f"branch root {message.username!r}: {connection!r}") + + peer = self.get_distributed_peer(connection.username, connection) + + # When we receive branch level 0 we automatically assume the root is the + # peer who sent the sender + # Don't do anything if the branch root is what we expected it to be + if peer.branch_root == message.username: + logger.debug(f"{self._settings.get('credentials.username')} : skipping parent check") + return + + peer.branch_root = message.username + if peer != self.parent: + await self._check_if_new_parent(peer) + else: + logger.info(f"parent advertised new branch root : {message.username}") + await self._notify_children_of_branch_values() + + @on_message(DistributedChildDepth) + async def _on_distributed_child_depth(self, message: DistributedChildDepth.Request, connection: PeerConnection): + peer = self.get_distributed_peer(connection.username, connection) + peer.child_depth = message.depth + + @on_message(DistributedSearchRequest.Request) + async def _on_distributed_search_request(self, message: DistributedSearchRequest.Request, connection: PeerConnection): + await self.send_messages_to_children(message) + + @on_message(DistributedServerSearchRequest.Request) + async def _on_distributed_server_search_request(self, message: DistributedServerSearchRequest.Request, connection: PeerConnection): + if message.distributed_code != DistributedSearchRequest.Request.MESSAGE_ID: + logger.warning(f"no handling for server search request with code {message.distributed_code}") + return + + dmessage = DistributedSearchRequest.Request( + unknown=0x31, + username=message.username, + ticket=message.ticket, + query=message.query + ) + await self.send_messages_to_children(dmessage) + + async def _on_peer_connection_initialized(self, event: PeerInitializedEvent): + if event.connection.connection_type == PeerConnectionType.DISTRIBUTED: + peer = DistributedPeer(event.connection.username, event.connection) + self.distributed_peers.append(peer) + + # Only check if the peer is a potential child if the connection + # was not requested by us + if not event.requested: + await self._check_if_new_child(peer) + + async def _on_message_received(self, event: MessageReceivedEvent): + message = event.message + if message.__class__ in self.MESSAGE_MAP: + await self.MESSAGE_MAP[message.__class__](message, event.connection) + + async def _on_state_changed(self, event: ConnectionStateChangedEvent): + if not isinstance(event.connection, PeerConnection): + return + + if event.connection.connection_type != PeerConnectionType.DISTRIBUTED: + return + + if event.state == ConnectionState.CLOSED: + # Check if it was the parent that was disconnected + parent = self.parent + if parent and event.connection == parent.connection: + await self._unset_parent() + return + + # Check if it was a child + new_children = [] + for child in self.children: + if child.connection == event.connection: + logger.debug(f"removing child {child!r}") + else: + new_children.append(child) + self.children = new_children + + # Remove from the distributed connections + self.distributed_peers = [ + peer for peer in self.distributed_peers + if peer.connection != event.connection + ] + + async def _on_login_success(self, event: LoginSuccessEvent): + await self._notify_server_of_parent() + + async def send_messages_to_children(self, *messages: Union[MessageDataclass, bytes]): + for child in self.children: + child.connection.queue_messages(*messages) + + def stop(self) -> List[asyncio.Task]: + """Cancels all pending tasks + + :return: a list of tasks that have been cancelled so that they can be + awaited + """ + return self._cancel_potential_parent_tasks() + + def _cancel_potential_parent_tasks(self) -> List[asyncio.Task]: + cancelled_tasks = [] + + for task in self._potential_parent_tasks: + task.cancel() + cancelled_tasks.append(task) + + return cancelled_tasks diff --git a/src/aioslsk/events.py b/src/aioslsk/events.py index 284d504b..072dd36e 100644 --- a/src/aioslsk/events.py +++ b/src/aioslsk/events.py @@ -11,7 +11,7 @@ MessageDataclass, ItemRecommendation, ) -from .search import SearchRequest, SearchResult +from .search.model import SearchRequest, SearchResult if TYPE_CHECKING: from .network.connection import ( diff --git a/src/aioslsk/network/connection.py b/src/aioslsk/network/connection.py index f50c0911..f931feef 100644 --- a/src/aioslsk/network/connection.py +++ b/src/aioslsk/network/connection.py @@ -183,6 +183,7 @@ def __init__(self, hostname: str, port: int, network: Network, obfuscated: bool self._writer: asyncio.StreamWriter = None self._reader_task: asyncio.Task = None self.read_timeout: float = None + self._queued_messages: List[asyncio.Task] = [] def get_connecting_ip(self) -> str: """Gets the IP address being used to connect to the server/peer. @@ -233,6 +234,7 @@ async def disconnect(self, reason: CloseReason = CloseReason.UNKNOWN): await self.set_state(ConnectionState.CLOSING, close_reason=reason) logger.debug(f"{self.hostname}:{self.port} : disconnecting : {reason.name}") + self._cancel_queued_messages() try: if self._writer is not None: if not self._writer.is_closing(): @@ -383,10 +385,13 @@ async def _read(self, reader_func, timeout: float = None) -> Optional[bytes]: raise ConnectionReadError(f"{self.hostname}:{self.port} : exception during reading") from exc def queue_message(self, message: Union[bytes, MessageDataclass]) -> asyncio.Task: - return asyncio.create_task( + task = asyncio.create_task( self.send_message(message), name=f'queue-message-task-{task_counter()}' ) + self._queued_messages.append(task) + task.add_done_callback(self._queued_messages.remove) + return task def queue_messages(self, *messages: List[Union[bytes, MessageDataclass]]) -> List[asyncio.Task]: return [ @@ -394,6 +399,10 @@ def queue_messages(self, *messages: List[Union[bytes, MessageDataclass]]) -> Lis for message in messages ] + def _cancel_queued_messages(self): + for qmessage_task in self._queued_messages: + qmessage_task.cancel() + async def send_message(self, message: Union[bytes, MessageDataclass]): """Sends a message or a set of bytes over the connection. In case an object of `MessageDataClass` is provided the object will first be diff --git a/src/aioslsk/peer.py b/src/aioslsk/peer.py index 0cee6add..93159ab4 100644 --- a/src/aioslsk/peer.py +++ b/src/aioslsk/peer.py @@ -1,78 +1,37 @@ -import asyncio -from dataclasses import dataclass -from functools import partial import logging -from typing import List, Union, Tuple -from .network.connection import ( - CloseReason, - ConnectionState, - PeerConnection, - PeerConnectionType, - ServerConnection, -) +from .network.connection import PeerConnection from .events import ( on_message, build_message_map, EventBus, InternalEventBus, - ConnectionStateChangedEvent, - LoginSuccessEvent, - PeerInitializedEvent, MessageReceivedEvent, UserDirectoryEvent, UserInfoEvent, UserSharesReplyEvent, - SearchResultEvent, ) from .protocol.messages import ( - AcceptChildren, - BranchLevel, - BranchRoot, - ToggleParentSearch, - DistributedChildDepth, - DistributedBranchLevel, - DistributedBranchRoot, - DistributedSearchRequest, - DistributedServerSearchRequest, - MessageDataclass, PeerDirectoryContentsRequest, PeerDirectoryContentsReply, - PeerSearchReply, PeerSharesRequest, PeerSharesReply, PeerUserInfoReply, PeerUserInfoRequest, - PeerUploadQueueNotification, - PotentialParents, - ServerSearchRequest, ) from .network.network import Network -from .search import ReceivedSearch, SearchResult from .settings import Settings from .shares.manager import SharesManager -from .shares.utils import convert_items_to_file_data from .state import State from .transfer.manager import TransferManager -from .utils import task_counter, ticket_generator +from .utils import ticket_generator logger = logging.getLogger(__name__) -@dataclass -class DistributedPeer: - username: str - connection: PeerConnection - branch_level: int = None - branch_root: str = None - child_depth: int = None - - class PeerManager: - """Peer manager is responsible for handling peer messages and the - distributed network - """ + """Peer manager is responsible for handling peer messages""" def __init__( self, state: State, settings: Settings, @@ -89,38 +48,34 @@ def __init__( self._ticket_generator = ticket_generator() - self.parent: DistributedPeer = None - """Distributed parent. This variable is `None` if we are looking for - parents - """ - self.children: List[DistributedPeer] = [] - self.potential_parents: List[str] = [] - self.distributed_peers: List[DistributedPeer] = [] - - self._internal_event_bus.register( - PeerInitializedEvent, self._on_peer_connection_initialized) - self._internal_event_bus.register( - ConnectionStateChangedEvent, self._on_state_changed) self._internal_event_bus.register( MessageReceivedEvent, self._on_message_received) - self._internal_event_bus.register( - LoginSuccessEvent, self._on_login_success) self.MESSAGE_MAP = build_message_map(self) - self._potential_parent_tasks: List[asyncio.Task] = [] - self._search_reply_tasks: List[asyncio.Task] = [] - # External methods async def get_user_info(self, username: str): + """Requests user info from the peer itself + + :param username: name of the peer + """ await self._network.send_peer_messages( username, PeerUserInfoRequest.Request()) async def get_user_shares(self, username: str): + """Requests the shares of the peer + + :param username: name of the peer + """ await self._network.send_peer_messages( username, PeerSharesRequest.Request()) async def get_user_directory(self, username: str, directory: str) -> int: + """Requests details for a single directory from a peer + + :param username: name of the peer + :param directory: directory to request details for + """ ticket = next(self._ticket_generator) await self._network.send_peer_messages( username, @@ -131,271 +86,12 @@ async def get_user_directory(self, username: str, directory: str) -> int: ) return ticket - def _get_advertised_branch_values(self) -> Tuple[str, int]: - """Returns the advertised branch values. These values are to be sent to - the children and the server to let them know where we are in the - distributed tree. - - If no parent: - * level = 0 - * root = our own username - - If we are the root: - * level = 0 - * root = our own username - - If we have a parent: - * level = level parent advertised + 1 - * root = whatever our parent sent us initially - """ - username = self._settings.get('credentials.username') - if self.parent: - # We are the branch root - if self.parent.branch_root == username: - return username, 0 - else: - return self.parent.branch_root, self.parent.branch_level + 1 - - return username, 0 - - def get_distributed_peer(self, username: str, connection: PeerConnection) -> DistributedPeer: - for peer in self.distributed_peers: - if peer.username == username and peer.connection == connection: - return peer - - async def _set_parent(self, peer: DistributedPeer): - logger.info(f"set parent : {peer}") - self.parent = peer - - self._cancel_potential_parent_tasks() - # Cancel all tasks related to potential parents and disconnect all other - # distributed connections except for children and the parent connection - # Other distributed connection from the parent that we have should also - # be disconnected - distributed_connections = [ - distributed_peer.connection for distributed_peer in self.distributed_peers - if distributed_peer in [self.parent, ] + self.children - ] - for peer_connection in self._network.peer_connections: - if peer_connection.connection_type == PeerConnectionType.DISTRIBUTED: - if peer_connection not in distributed_connections: - asyncio.create_task( - peer_connection.disconnect(reason=CloseReason.REQUESTED), - name=f'disconnect-distributed-{task_counter()}' - ) - - await self._notify_server_of_parent() - await self._notify_children_of_branch_values() - - async def _check_if_new_parent(self, peer: DistributedPeer): - """Called after BranchRoot or BranchLevel, checks if all information is - complete for this peer/connection to become a parent and makes it a - parent if we don't have one, otherwise just close the connection. - """ - # Explicit None checks because we can get 0 as branch level - if peer.branch_level is not None and peer.branch_root is not None: - if not self.parent: - await self._set_parent(peer) - else: - await peer.connection.disconnect(reason=CloseReason.REQUESTED) - else: - logger.debug(f"{self._settings.get('credentials.username')} : not enough info for parent : {peer}") - - async def _unset_parent(self): - logger.debug(f"unset parent {self.parent!r}") - - self.parent = None - - username = self._settings.get('credentials.username') - await self._notify_server_of_parent() - - # TODO: What happens to the children when we lose our parent is still - # unclear - await self.send_messages_to_children( - DistributedBranchLevel.Request(0), - DistributedBranchRoot.Request(username) - ) - - async def _notify_server_of_parent(self): - """Notifies the server of our parent or if we don't have any, notify the - server that we are looking for one - """ - root, level = self._get_advertised_branch_values() - logger.info(f"notifying server of our parent : level={level} root={root}") - - messages = [ - BranchLevel.Request(level), - BranchRoot.Request(root) - ] - - if self.parent: - logger.info("notifying server we are not looking for parent") - messages.extend([ - ToggleParentSearch.Request(False), - AcceptChildren.Request(True) - ]) - else: - logger.info("notifying server we are looking for parent") - # The original Windows client sends out the child depth (=0) and the - # ParentIP - messages.extend([ - ToggleParentSearch.Request(True), - AcceptChildren.Request(True) - ]) - - await self._network.send_server_messages(*messages) - - async def _notify_children_of_branch_values(self): - root, level = self._get_advertised_branch_values() - await self.send_messages_to_children( - DistributedBranchLevel.Request(level), - DistributedBranchRoot.Request(root) - ) - - async def _check_if_new_child(self, peer: DistributedPeer): - """Potentially adds a distributed connection to our list of children. - """ - if peer.username in self.potential_parents: - return - - await self._add_child(peer) - - async def _add_child(self, peer: DistributedPeer): - logger.debug(f"adding distributed connection as child : {peer!r}") - self.children.append(peer) - # Let the child know where we are in the distributed tree - root, level = self._get_advertised_branch_values() - await peer.connection.send_message(DistributedBranchLevel.Request(level)) - await peer.connection.send_message(DistributedBranchRoot.Request(root)) - - def _search_reply_task_callback(self, ticket: int, username: str, query: str, task: asyncio.Task): - """Callback for a search reply task. This callback simply logs the - results and removes the task from the list - """ - try: - task.result() - - except asyncio.CancelledError: - logger.debug( - f"cancelled delivery of search results (ticket={ticket}, username={username}, query={query})") - except Exception as exc: - logger.warning( - f"failed to deliver search results : {exc!r} (ticket={ticket}, username={username}, query={query})") - else: - logger.info( - f"delivered search results (ticket={ticket}, username={username}, query={query})") - finally: - self._search_reply_tasks.remove(task) - - def _potential_parent_task_callback(self, username: str, task: asyncio.Task): - """Callback for potential parent handling task. This callback simply - logs the results and removes the task from the list - """ - try: - task.result() - - except asyncio.CancelledError: - logger.debug(f"request for potential parent cancelled (username={username})") - except Exception as exc: - logger.warning(f"request for potential parent failed : {exc!r} (username={username})") - else: - logger.info(f"request for potential parent successful (username={username})") - finally: - self._potential_parent_tasks.remove(task) - - async def _query_shares_and_reply(self, ticket: int, username: str, query: str): - """Performs a query on the shares manager and reports the results to the - user - """ - visible, locked = self._shares_manager.query(query, username=username) - - self._state.received_searches.append( - ReceivedSearch( - username=username, - query=query, - matched_files=len(visible) + len(locked) - ) - ) - - if len(visible) + len(locked) == 0: - return - - logger.info(f"found {len(visible)}/{len(locked)} results for query {query!r} (username={username!r})") - - task = asyncio.create_task( - self._network.send_peer_messages( - username, - PeerSearchReply.Request( - username=self._settings.get('credentials.username'), - ticket=ticket, - results=convert_items_to_file_data(visible, use_full_path=True), - has_slots_free=self._transfer_manager.has_slots_free(), - avg_speed=int(self._transfer_manager.get_average_upload_speed()), - queue_size=self._transfer_manager.get_queue_size(), - locked_results=convert_items_to_file_data(locked, use_full_path=True) - ) - ), - name=f'search-reply-{task_counter()}' - ) - task.add_done_callback( - partial(self._search_reply_task_callback, ticket, username, query)) - self._search_reply_tasks.append(task) - - # Server messages - - @on_message(PotentialParents.Response) - async def _on_potential_parents(self, message: PotentialParents.Response, connection: ServerConnection): - if not self._settings.get('debug.search_for_parent'): - logger.debug("ignoring PotentialParents message : searching for parent is disabled") - return - - self.potential_parents = [ - entry.username for entry in message.entries - ] - - for entry in message.entries: - task = asyncio.create_task( - self._network.create_peer_connection( - entry.username, - PeerConnectionType.DISTRIBUTED, - ip=entry.ip, - port=entry.port - ), - name=f'potential-parent-{task_counter()}' - ) - task.add_done_callback( - partial(self._potential_parent_task_callback, entry.username) - ) - self._potential_parent_tasks.append(task) - - @on_message(ServerSearchRequest.Response) - async def _on_server_search_request(self, message: ServerSearchRequest.Response, connection): - username = self._settings.get('credentials.username') - if message.username == username: - return - - if not self.parent: - # Set ourself as parent - parent = DistributedPeer( - username, - None, - branch_root=username, - branch_level=0 - ) - await self._set_parent(parent) - - await self._query_shares_and_reply( - message.ticket, message.username, message.query) - - for child in self.children: - child.connection.queue_messages(message) - # Peer messages @on_message(PeerSharesRequest.Request) async def _on_peer_shares_request(self, message: PeerSharesRequest.Request, connection: PeerConnection): visible, locked = self._shares_manager.create_shares_reply(connection.username) - connection.queue_message( + await connection.send_message( PeerSharesReply.Request( directories=visible, locked_directories=locked @@ -416,7 +112,7 @@ async def _on_peer_shares_reply(self, message: PeerSharesReply.Request, connecti @on_message(PeerDirectoryContentsRequest.Request) async def _on_peer_directory_contents_req(self, message: PeerDirectoryContentsRequest.Request, connection: PeerConnection): directories = self._shares_manager.create_directory_reply(message.directory) - connection.queue_message( + await connection.send_message( PeerDirectoryContentsReply.Request( ticket=message.ticket, directory=message.directory, @@ -431,34 +127,6 @@ async def _on_peer_directory_contents_reply(self, message: PeerDirectoryContents UserDirectoryEvent(user, message.directory, message.directories) ) - @on_message(PeerSearchReply.Request) - async def _on_peer_search_reply(self, message: PeerSearchReply.Request, connection: PeerConnection): - search_result = SearchResult( - ticket=message.ticket, - username=message.username, - has_free_slots=message.has_slots_free, - avg_speed=message.avg_speed, - queue_size=message.queue_size, - shared_items=message.results, - locked_results=message.locked_results - ) - try: - query = self._state.search_requests[message.ticket] - except KeyError: - logger.warning(f"search reply ticket does not match any search query : {message.ticket}") - else: - query.results.append(search_result) - await self._event_bus.emit(SearchResultEvent(query, search_result)) - - await connection.disconnect(reason=CloseReason.REQUESTED) - - # Update the user info - user = self._state.get_or_create_user(message.username) - user.avg_speed = message.avg_speed - user.queue_length = message.queue_size - user.has_slots_free = message.has_slots_free - await self._event_bus.emit(UserInfoEvent(user)) - @on_message(PeerUserInfoReply.Request) async def _on_peer_user_info_reply(self, message: PeerUserInfoReply.Request, connection: PeerConnection): user = self._state.get_or_create_user(connection.username) @@ -487,159 +155,13 @@ async def _on_peer_user_info_request(self, message: PeerUserInfoRequest.Request, description=description, has_picture=bool(picture), picture=picture, - upload_slots=self._transfer_manager.upload_slots, + upload_slots=self._transfer_manager.get_upload_slots(), queue_size=self._transfer_manager.get_queue_size(), has_slots_free=self._transfer_manager.has_slots_free() ) ) - @on_message(PeerUploadQueueNotification.Request) - async def _on_peer_upload_queue_notification(self, message: PeerUploadQueueNotification.Request, connection: PeerConnection): - logger.info("PeerUploadQueueNotification") - await connection.send_message( - PeerUploadQueueNotification.Request(), - ) - - # Distributed messages - - @on_message(DistributedBranchLevel.Request) - async def _on_distributed_branch_level(self, message: DistributedBranchLevel.Request, connection: PeerConnection): - logger.info(f"branch level {message.level!r}: {connection!r}") - - peer = self.get_distributed_peer(connection.username, connection) - peer.branch_level = message.level - - # Branch root is not always sent in case the peer advertises branch - # level 0 because he himself is the root - if message.level == 0: - peer.branch_root = peer.username - - if peer != self.parent: - await self._check_if_new_parent(peer) - else: - logger.info(f"parent advertised new branch level : {message.level}") - await self._notify_children_of_branch_values() - - @on_message(DistributedBranchRoot.Request) - async def _on_distributed_branch_root(self, message: DistributedBranchRoot.Request, connection: PeerConnection): - logger.info(f"branch root {message.username!r}: {connection!r}") - - peer = self.get_distributed_peer(connection.username, connection) - - # When we receive branch level 0 we automatically assume the root is the - # peer who sent the sender - # Don't do anything if the branch root is what we expected it to be - if peer.branch_root == message.username: - logger.debug(f"{self._settings.get('credentials.username')} : skipping parent check") - return - - peer.branch_root = message.username - if peer != self.parent: - await self._check_if_new_parent(peer) - else: - logger.info(f"parent advertised new branch root : {message.username}") - await self._notify_children_of_branch_values() - - @on_message(DistributedChildDepth) - async def _on_distributed_child_depth(self, message: DistributedChildDepth.Request, connection: PeerConnection): - peer = self.get_distributed_peer(connection.username, connection) - peer.child_depth = message.depth - - @on_message(DistributedSearchRequest.Request) - async def _on_distributed_search_request(self, message: DistributedSearchRequest.Request, connection: PeerConnection): - await self._query_shares_and_reply(message.ticket, message.username, message.query) - - await self.send_messages_to_children(message) - - @on_message(DistributedServerSearchRequest.Request) - async def _on_distributed_server_search_request(self, message: DistributedServerSearchRequest.Request, connection: PeerConnection): - if message.distributed_code != DistributedSearchRequest.Request.MESSAGE_ID: - logger.warning(f"no handling for server search request with code {message.distributed_code}") - return - - await self._query_shares_and_reply(message.ticket, message.username, message.query) - - dmessage = DistributedSearchRequest.Request( - unknown=0x31, - username=message.username, - ticket=message.ticket, - query=message.query - ) - await self.send_messages_to_children(dmessage) - - async def _on_peer_connection_initialized(self, event: PeerInitializedEvent): - if event.connection.connection_type == PeerConnectionType.DISTRIBUTED: - peer = DistributedPeer(event.connection.username, event.connection) - self.distributed_peers.append(peer) - - # Only check if the peer is a potential child if the connection - # was not requested by us - if not event.requested: - await self._check_if_new_child(peer) - async def _on_message_received(self, event: MessageReceivedEvent): message = event.message if message.__class__ in self.MESSAGE_MAP: await self.MESSAGE_MAP[message.__class__](message, event.connection) - - async def _on_state_changed(self, event: ConnectionStateChangedEvent): - if not isinstance(event.connection, PeerConnection): - return - - if event.state == ConnectionState.CLOSED: - if event.connection.connection_type == PeerConnectionType.DISTRIBUTED: - # Check if it was the parent that was disconnected - parent = self.parent - if parent and event.connection == parent.connection: - await self._unset_parent() - return - - # Check if it was a child - new_children = [] - for child in self.children: - if child.connection == event.connection: - logger.debug(f"removing child {child!r}") - else: - new_children.append(child) - self.children = new_children - - # Remove from the distributed connections - self.distributed_peers = [ - peer for peer in self.distributed_peers - if peer.connection != event.connection - ] - - async def _on_login_success(self, event: LoginSuccessEvent): - await self._notify_server_of_parent() - - async def send_messages_to_children(self, *messages: Union[MessageDataclass, bytes]): - for child in self.children: - child.connection.queue_messages(*messages) - - def stop(self) -> List[asyncio.Task]: - """Cancels all pending tasks - - :return: a list of tasks that have been cancelled so that they can be - awaited - """ - search_tasks = self._cancel_search_reply_tasks() - potential_parent_tasks = self._cancel_potential_parent_tasks() - return search_tasks + potential_parent_tasks - - def _cancel_search_reply_tasks(self) -> List[asyncio.Task]: - cancelled_tasks = [] - - for task in self._search_reply_tasks: - task.cancel() - cancelled_tasks.append(task) - - return cancelled_tasks - - def _cancel_potential_parent_tasks(self) -> List[asyncio.Task]: - cancelled_tasks = [] - - for task in self._potential_parent_tasks: - task.cancel() - cancelled_tasks.append(task) - - return cancelled_tasks diff --git a/src/aioslsk/search/__init__.py b/src/aioslsk/search/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/aioslsk/search/manager.py b/src/aioslsk/search/manager.py new file mode 100644 index 00000000..e91184cf --- /dev/null +++ b/src/aioslsk/search/manager.py @@ -0,0 +1,328 @@ +import asyncio +from collections import deque +from functools import partial +import logging +from typing import Deque, Dict, List, Optional + +from ..network.connection import ( + CloseReason, + ConnectionState, + PeerConnection, + ServerConnection, +) +from ..events import ( + on_message, + build_message_map, + EventBus, + InternalEventBus, + ConnectionStateChangedEvent, + MessageReceivedEvent, + UserDirectoryEvent, + UserInfoEvent, + UserSharesReplyEvent, + SearchResultEvent, +) +from ..protocol.messages import ( + ChatRoomSearch, + DistributedSearchRequest, + DistributedServerSearchRequest, + FileSearch, + MessageDataclass, + PeerDirectoryContentsRequest, + PeerDirectoryContentsReply, + PeerSearchReply, + PeerSharesRequest, + PeerSharesReply, + PeerUserInfoReply, + PeerUserInfoRequest, + ServerSearchRequest, + UserSearch, + WishlistInterval, + WishlistSearch, +) +from ..network.network import Network +from ..settings import Settings +from ..shares.manager import SharesManager +from ..shares.utils import convert_items_to_file_data +from ..state import State +from ..transfer.manager import TransferManager +from ..utils import task_counter, ticket_generator +from .model import ReceivedSearch, SearchResult, SearchRequest, SearchType + + +logger = logging.getLogger(__name__) + + +class SearchManager: + """Handler for searches requests""" + + def __init__( + self, state: State, settings: Settings, + event_bus: EventBus, internal_event_bus: InternalEventBus, + shares_manager: SharesManager, transfer_manager: TransferManager, + network: Network): + self._state: State = state + self._settings: Settings = settings + self._event_bus: EventBus = event_bus + self._internal_event_bus: InternalEventBus = internal_event_bus + self._network: Network = network + self._shares_manager: SharesManager = shares_manager + self._transfer_manager: TransferManager = transfer_manager + + self._ticket_generator = ticket_generator() + + self.received_searches: Deque[ReceivedSearch] = deque(list(), 500) + self.search_requests: Dict[int, SearchRequest] = {} + + self._internal_event_bus.register( + ConnectionStateChangedEvent, self._on_state_changed) + self._internal_event_bus.register( + MessageReceivedEvent, self._on_message_received) + + self.MESSAGE_MAP = build_message_map(self) + + self._search_reply_tasks: List[asyncio.Task] = [] + self._wishlist_task: asyncio.Task = None + + async def search_room(self, room: str, query: str) -> SearchRequest: + """Performs a search query on all users in a room + + :param room: name of the room to query + :param query: search query + """ + ticket = next(self._ticket_generator) + + await self._network.send_server_messages( + ChatRoomSearch.Request(room, ticket, query) + ) + self.search_requests[ticket] = SearchRequest( + ticket=ticket, + query=query, + search_type=SearchType.ROOM, + room=room + ) + return self.search_requests[ticket] + + async def search_user(self, username: str, query: str) -> SearchRequest: + """Performs a search query on a user + + :param username: username of the user to query + :param query: search query + """ + ticket = next(self._ticket_generator) + + await self._network.send_server_messages( + UserSearch.Request(username, ticket, query) + ) + self.search_requests[ticket] = SearchRequest( + ticket=ticket, + query=query, + search_type=SearchType.USER, + username=username + ) + return self.search_requests[ticket] + + async def search(self, query: str) -> SearchRequest: + """Performs a global search query + + :param query: search query + """ + ticket = next(self._ticket_generator) + + await self._network.send_server_messages( + FileSearch.Request(ticket, query) + ) + self.search_requests[ticket] = SearchRequest( + ticket=ticket, + query=query, + search_type=SearchType.NETWORK + ) + return self.search_requests[ticket] + + async def _query_shares_and_reply(self, ticket: int, username: str, query: str): + """Performs a query on the shares manager and reports the results to the + user + """ + visible, locked = self._shares_manager.query(query, username=username) + + self.received_searches.append( + ReceivedSearch( + username=username, + query=query, + matched_files=len(visible) + len(locked) + ) + ) + + if len(visible) + len(locked) == 0: + return + + logger.info(f"found {len(visible)}/{len(locked)} results for query {query!r} (username={username!r})") + + task = asyncio.create_task( + self._network.send_peer_messages( + username, + PeerSearchReply.Request( + username=self._settings.get('credentials.username'), + ticket=ticket, + results=convert_items_to_file_data(visible, use_full_path=True), + has_slots_free=self._transfer_manager.has_slots_free(), + avg_speed=int(self._transfer_manager.get_average_upload_speed()), + queue_size=self._transfer_manager.get_queue_size(), + locked_results=convert_items_to_file_data(locked, use_full_path=True) + ) + ), + name=f'search-reply-{task_counter()}' + ) + task.add_done_callback( + partial(self._search_reply_task_callback, ticket, username, query)) + self._search_reply_tasks.append(task) + + def _search_reply_task_callback(self, ticket: int, username: str, query: str, task: asyncio.Task): + """Callback for a search reply task. This callback simply logs the + results and removes the task from the list + """ + try: + task.result() + + except asyncio.CancelledError: + logger.debug( + f"cancelled delivery of search results (ticket={ticket}, username={username}, query={query})") + except Exception as exc: + logger.warning( + f"failed to deliver search results : {exc!r} (ticket={ticket}, username={username}, query={query})") + else: + logger.info( + f"delivered search results (ticket={ticket}, username={username}, query={query})") + finally: + self._search_reply_tasks.remove(task) + + async def _wishlist_job(self, interval: int): + """Job handling wishlist queries, this method is intended to be run as + a task. This method will run at the given `interval` (returned by the + server on start up). + """ + while True: + items = self._settings.get('search.wishlist') + + # Remove all current wishlist searches + self.search_requests = { + ticket: qry for ticket, qry in self.search_requests.items() + if qry.search_type != SearchType.WISHLIST + } + + logger.info(f"starting wishlist search of {len(items)} items") + # Recreate + for item in items: + if not item['enabled']: + continue + + ticket = next(self._ticket_generator) + self.search_requests[ticket] = SearchRequest( + ticket, + item['query'], + search_type=SearchType.WISHLIST + ) + self._network.queue_server_messages( + WishlistSearch.Request(ticket, item['query']) + ) + + await asyncio.sleep(interval) + + async def _on_message_received(self, event: MessageReceivedEvent): + message = event.message + if message.__class__ in self.MESSAGE_MAP: + await self.MESSAGE_MAP[message.__class__](message, event.connection) + + @on_message(DistributedSearchRequest.Request) + async def _on_distributed_search_request( + self, message: DistributedSearchRequest.Request, connection: PeerConnection): + + await self._query_shares_and_reply(message.ticket, message.username, message.query) + + @on_message(DistributedServerSearchRequest.Request) + async def _on_distributed_server_search_request( + self, message: DistributedServerSearchRequest.Request, connection: PeerConnection): + + if message.distributed_code != DistributedSearchRequest.Request.MESSAGE_ID: + logger.warning(f"no handling for server search request with code {message.distributed_code}") + return + + await self._query_shares_and_reply(message.ticket, message.username, message.query) + + @on_message(ServerSearchRequest.Response) + async def _on_server_search_request(self, message: ServerSearchRequest.Response, connection): + username = self._settings.get('credentials.username') + if message.username == username: + return + + await self._query_shares_and_reply( + message.ticket, message.username, message.query) + + @on_message(PeerSearchReply.Request) + async def _on_peer_search_reply(self, message: PeerSearchReply.Request, connection: PeerConnection): + search_result = SearchResult( + ticket=message.ticket, + username=message.username, + has_free_slots=message.has_slots_free, + avg_speed=message.avg_speed, + queue_size=message.queue_size, + shared_items=message.results, + locked_results=message.locked_results + ) + try: + query = self.search_requests[message.ticket] + except KeyError: + logger.warning(f"search reply ticket does not match any search query : {message.ticket}") + else: + query.results.append(search_result) + await self._event_bus.emit(SearchResultEvent(query, search_result)) + + await connection.disconnect(reason=CloseReason.REQUESTED) + + # Update the user info + user = self._state.get_or_create_user(message.username) + user.avg_speed = message.avg_speed + user.queue_length = message.queue_size + user.has_slots_free = message.has_slots_free + await self._event_bus.emit(UserInfoEvent(user)) + + @on_message(WishlistInterval.Response) + async def _on_wish_list_interval(self, message: WishlistInterval.Response, connection): + self._cancel_wishlist_task() + + self._wishlist_task = asyncio.create_task( + self._wishlist_job(message.interval), + name=f'wishlist-job-{task_counter()}' + ) + + async def _on_state_changed(self, event: ConnectionStateChangedEvent): + if not isinstance(event.connection, ServerConnection): + return + + if event.state == ConnectionState.CLOSING: + self._cancel_wishlist_task() + + def _cancel_wishlist_task(self) -> Optional[asyncio.Task]: + task = self._wishlist_task + if self._wishlist_task is not None: + self._wishlist_task.cancel() + self._wishlist_task = None + return task + return None + + def stop(self) -> List[asyncio.Task]: + """Cancels all pending tasks + + :return: a list of tasks that have been cancelled so that they can be + awaited + """ + cancelled_tasks = [] + + for task in self._search_reply_tasks: + task.cancel() + cancelled_tasks.append(task) + + if (wishlist_task := self._cancel_wishlist_task()) is not None: + cancelled_tasks.append(wishlist_task) + + return cancelled_tasks diff --git a/src/aioslsk/search.py b/src/aioslsk/search/model.py similarity index 98% rename from src/aioslsk/search.py rename to src/aioslsk/search/model.py index 06120b8e..f9c5233c 100644 --- a/src/aioslsk/search.py +++ b/src/aioslsk/search/model.py @@ -4,7 +4,7 @@ import re from typing import List -from .protocol.primitives import FileData +from ..protocol.primitives import FileData class SearchType(Enum): diff --git a/src/aioslsk/server.py b/src/aioslsk/server.py index b3f0bf5f..506017d1 100644 --- a/src/aioslsk/server.py +++ b/src/aioslsk/server.py @@ -53,7 +53,6 @@ ChatLeaveRoom, ChatPrivateMessage, ChatAckPrivateMessage, - ChatRoomSearch, ChatRoomTickers, ChatRoomTickerAdded, ChatRoomTickerRemoved, @@ -62,7 +61,6 @@ ChatUserLeftRoom, CheckPrivileges, DistributedAliveInterval, - FileSearch, GetGlobalRecommendations, GetItemRecommendations, GetItemSimilarUsers, @@ -101,14 +99,11 @@ SetListenPort, SetStatus, SharedFoldersFiles, - UserSearch, WishlistInterval, - WishlistSearch, ) from .model import ChatMessage, RoomMessage, User, UserStatus, TrackingFlag from .network.network import Network from .shares.manager import SharesManager -from .search import SearchRequest, SearchType from .settings import Settings from .state import State from .utils import task_counter, ticket_generator @@ -119,7 +114,10 @@ class ServerManager: - def __init__(self, state: State, settings: Settings, event_bus: EventBus, internal_event_bus: InternalEventBus, shares_manager: SharesManager, network: Network): + def __init__( + self, state: State, settings: Settings, + event_bus: EventBus, internal_event_bus: InternalEventBus, + shares_manager: SharesManager, network: Network): self._state: State = state self._settings: Settings = settings self._event_bus: EventBus = event_bus @@ -130,7 +128,6 @@ def __init__(self, state: State, settings: Settings, event_bus: EventBus, intern self._ticket_generator = ticket_generator() self._ping_task: asyncio.Task = None - self._wishlist_task: asyncio.Task = None self._post_login_task: asyncio.Task = None self._connection_watchdog_task: asyncio.Task = None @@ -179,7 +176,10 @@ async def login(self, username: str, password: str, version: int = 157): async def track_user(self, username: str, flag: TrackingFlag): """Starts tracking a user. The method sends an `AddUser` only if the `is_tracking` variable is set to False. Updates to the user will be - omitted through the `UserInfoEvent` + emitted through the `UserInfoEvent` event + + :param user: user to track + :param flag: tracking flag to add from the user """ user = self._state.get_or_create_user(username) @@ -189,11 +189,10 @@ async def track_user(self, username: str, flag: TrackingFlag): await self._network.send_server_messages(AddUser.Request(username)) async def track_friends(self): + """Starts tracking the users defined defined in the friends list""" tasks = [] for friend in self._settings.get('users.friends'): - tasks.append( - asyncio.create_task(self.track_user(friend, TrackingFlag.FRIEND)) - ) + tasks.append(self.track_user(friend, TrackingFlag.FRIEND)) asyncio.gather(*tasks, return_exceptions=True) @@ -230,50 +229,6 @@ async def auto_join_rooms(self): *[ChatJoinRoom.Request(room) for room in rooms] ) - async def search_room(self, room: str, query: str) -> SearchRequest: - """Performs a search query on all users in a room""" - ticket = next(self._ticket_generator) - - await self._network.send_server_messages( - ChatRoomSearch.Request(room, ticket, query) - ) - self._state.search_requests[ticket] = SearchRequest( - ticket=ticket, - query=query, - search_type=SearchType.ROOM, - room=room - ) - return self._state.search_requests[ticket] - - async def search_user(self, username: str, query: str) -> SearchRequest: - """Performs a search query on a user""" - ticket = next(self._ticket_generator) - - await self._network.send_server_messages( - UserSearch.Request(username, ticket, query) - ) - self._state.search_requests[ticket] = SearchRequest( - ticket=ticket, - query=query, - search_type=SearchType.USER, - username=username - ) - return self._state.search_requests[ticket] - - async def search(self, query: str) -> SearchRequest: - """Performs a global search query""" - ticket = next(self._ticket_generator) - - await self._network.send_server_messages( - FileSearch.Request(ticket, query) - ) - self._state.search_requests[ticket] = SearchRequest( - ticket=ticket, - query=query, - search_type=SearchType.NETWORK - ) - return self._state.search_requests[ticket] - async def get_user_stats(self, username: str): # pragma: no cover await self._network.send_server_messages(GetUserStats.Request(username)) @@ -713,12 +668,6 @@ async def _on_add_privileged_user(self, message: AddPrivilegedUser.Response, con @on_message(WishlistInterval.Response) async def _on_wish_list_interval(self, message: WishlistInterval.Response, connection): self._state.wishlist_interval = message.interval - self._cancel_wishlist_task() - - self._wishlist_task = asyncio.create_task( - self._wishlist_job(message.interval), - name=f'wishlist-job-{task_counter()}' - ) @on_message(AddUser.Response) async def _on_add_user(self, message: AddUser.Response, connection): @@ -840,43 +789,6 @@ def _cancel_connection_watchdog_task(self): self._connection_watchdog_task.cancel() self._connection_watchdog_task = None - async def _wishlist_job(self, interval: int): - """Job handling wishlist queries, this method is intended to be run as - a task. This method will run at the given `interval` (returned by the - server on start up). - """ - while True: - items = self._settings.get('search.wishlist') - - # Remove all current wishlist searches - self._state.search_requests = { - ticket: qry for ticket, qry in self._state.search_requests.items() - if qry.search_type != SearchType.WISHLIST - } - - logger.info(f"starting wishlist search of {len(items)} items") - # Recreate - for item in items: - if not item['enabled']: - continue - - ticket = next(self._ticket_generator) - self._state.search_requests[ticket] = SearchRequest( - ticket, - item['query'], - search_type=SearchType.WISHLIST - ) - self._network.queue_server_messages( - WishlistSearch.Request(ticket, item['query']) - ) - - await asyncio.sleep(interval) - - def _cancel_wishlist_task(self): - if self._wishlist_task is not None: - self._wishlist_task.cancel() - self._wishlist_task = None - # Listeners async def _on_track_user(self, event: TrackUserEvent): @@ -912,7 +824,6 @@ async def _on_state_changed(self, event: ConnectionStateChangedEvent): elif event.state == ConnectionState.CLOSING: - self._cancel_wishlist_task() self._cancel_ping_task() # When `disconnect` is called on the connection it will always first # go into the CLOSING state. The watchdog will only attempt to diff --git a/src/aioslsk/shares/manager.py b/src/aioslsk/shares/manager.py index d65bdc9e..08483251 100644 --- a/src/aioslsk/shares/manager.py +++ b/src/aioslsk/shares/manager.py @@ -26,7 +26,7 @@ NumberDuplicateStrategy, ) from ..protocol.primitives import DirectoryData -from ..search import SearchQuery +from ..search.model import SearchQuery from ..settings import Settings from .utils import create_term_pattern, convert_items_to_file_data diff --git a/src/aioslsk/state.py b/src/aioslsk/state.py index b6bc2e32..89a03945 100644 --- a/src/aioslsk/state.py +++ b/src/aioslsk/state.py @@ -1,10 +1,8 @@ from __future__ import annotations -from collections import deque from dataclasses import dataclass, field -from typing import Deque, Dict, List, Union +from typing import Dict, List, Union from .model import Room, User, UserStatus, TrackingFlag -from .search import ReceivedSearch, SearchRequest @dataclass @@ -24,10 +22,6 @@ class State: distributed_alive_interval: int = 0 wishlist_interval: int = 0 - # Search related - received_searches: Deque[ReceivedSearch] = field(default_factory=lambda: deque(list(), 500)) - search_requests: Dict[int, SearchRequest] = field(default_factory=dict) - def get_joined_rooms(self) -> List[Room]: return [room for room in self.rooms.values() if room.joined] diff --git a/src/aioslsk/transfer/manager.py b/src/aioslsk/transfer/manager.py index a8b0660c..432f9744 100644 --- a/src/aioslsk/transfer/manager.py +++ b/src/aioslsk/transfer/manager.py @@ -119,7 +119,7 @@ async def read_cache(self) -> List[Transfer]: await self._add_transfer(transfer) def write_cache(self): - """Write all current transfers back to the """ + """Write all current transfers back to the cache""" self.cache.write(self._transfers) def stop(self) -> List[asyncio.Task]: @@ -137,10 +137,6 @@ def stop(self) -> List[asyncio.Task]: def transfers(self): return self._transfers - @property - def upload_slots(self): - return self._settings.get('sharing.limits.upload_slots') - async def abort(self, transfer: Transfer): """Aborts the given transfer. This will cancel all pending transfers and remove the file (in case of download) @@ -208,6 +204,10 @@ def get_uploads(self) -> List[Transfer]: def get_downloads(self) -> List[Transfer]: return [transfer for transfer in self._transfers if transfer.is_download()] + def get_upload_slots(self) -> int: + """Returns the total amount of upload slots""" + return self._settings.get('sharing.limits.upload_slots') + def has_slots_free(self) -> bool: return self.get_free_upload_slots() > 0 @@ -217,7 +217,7 @@ def get_free_upload_slots(self) -> int: if transfer.is_upload() and transfer.is_processing(): uploading_transfers.append(transfer) - available_slots = self.upload_slots - len(uploading_transfers) + available_slots = self.get_upload_slots() - len(uploading_transfers) return max(0, available_slots) def get_queue_size(self) -> int: diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit/search/test_search_manager.py b/tests/unit/search/test_search_manager.py new file mode 100644 index 00000000..e5663c9a --- /dev/null +++ b/tests/unit/search/test_search_manager.py @@ -0,0 +1,131 @@ +from aioslsk.events import SearchResultEvent, UserInfoEvent +from aioslsk.protocol.primitives import FileData +from aioslsk.protocol.messages import ( + UserSearch, + FileSearch, + ChatRoomSearch, + PeerSearchReply, +) +from aioslsk.state import State +from aioslsk.search.manager import SearchManager +from aioslsk.search.model import SearchType, SearchRequest +from aioslsk.settings import Settings +import pytest +from unittest.mock import AsyncMock, Mock, call + + +DEFAULT_USER = 'testuser0' +DEFAULT_SETTINGS = { + 'credentials': { + 'username': DEFAULT_USER + } +} + + +@pytest.fixture +def manager() -> SearchManager: + network = Mock() + network.send_server_messages = AsyncMock() + event_bus = Mock() + event_bus.emit = AsyncMock() + event_bus.register = Mock() + internal_event_bus = Mock() + internal_event_bus.emit = AsyncMock() + internal_event_bus.register = Mock() + shares_manager = Mock() + transfer_manager = Mock() + + return SearchManager( + State(), + Settings(DEFAULT_SETTINGS), + event_bus, + internal_event_bus, + shares_manager, + transfer_manager, + network + ) + + + +class TestSearchManager: + + @pytest.mark.asyncio + async def test_search(self, manager: SearchManager): + request = await manager.search('query') + + assert request.search_type == SearchType.NETWORK + assert request.ticket is not None + assert request.query == 'query' + assert request.username is None + assert request.room is None + + assert request.ticket in manager.search_requests + + manager._network.send_server_messages.assert_awaited_once_with( + FileSearch.Request(request.ticket, 'query') + ) + + @pytest.mark.asyncio + async def test_searchUser(self, manager: SearchManager): + request = await manager.search_user('user0', 'query') + + assert request.search_type == SearchType.USER + assert request.ticket is not None + assert request.query == 'query' + assert request.username == 'user0' + assert request.room is None + + assert request.ticket in manager.search_requests + + manager._network.send_server_messages.assert_awaited_once_with( + UserSearch.Request('user0', request.ticket, 'query') + ) + + @pytest.mark.asyncio + async def test_searchRoom(self, manager: SearchManager): + request = await manager.search_room('room0', 'query') + + assert request.search_type == SearchType.ROOM + assert request.ticket is not None + assert request.query == 'query' + assert request.username is None + assert request.room == 'room0' + + assert request.ticket in manager.search_requests + + manager._network.send_server_messages.assert_awaited_once_with( + ChatRoomSearch.Request('room0', request.ticket, 'query') + ) + + @pytest.mark.asyncio + async def test_onPeerSearchReply_shouldStoreResultsAndEmit(self, manager: SearchManager): + TICKET = 1234 + connection = AsyncMock() + + manager.search_requests[TICKET] = SearchRequest( + TICKET, 'search', SearchType.NETWORK) + + reply_message = PeerSearchReply.Request( + 'user0', + TICKET, + results=[FileData(1, 'myfile.mp3', 10000, 'mp3', attributes=[])], + has_slots_free=True, + avg_speed=100, + queue_size=2, + locked_results=[FileData(1, 'locked.mp3', 10000, 'mp3', attributes=[])] + ) + await manager._on_peer_search_reply(reply_message, connection) + + assert 1 == len(manager.search_requests[TICKET].results) + + manager._event_bus.emit.assert_has_awaits( + [ + call( + SearchResultEvent( + manager.search_requests[TICKET], + manager.search_requests[TICKET].results[0] + ) + ), + call(UserInfoEvent(manager._state.get_or_create_user('user0'))) + ] + ) diff --git a/tests/unit/test_peer.py b/tests/unit/test_peer.py index 2b3bae0a..eafb9522 100644 --- a/tests/unit/test_peer.py +++ b/tests/unit/test_peer.py @@ -1,19 +1,17 @@ -from aioslsk.events import UserDirectoryEvent, UserInfoEvent, SearchResultEvent -from aioslsk.protocol.primitives import DirectoryData, FileData +from aioslsk.events import UserDirectoryEvent +from aioslsk.protocol.primitives import DirectoryData from aioslsk.protocol.messages import ( PeerDirectoryContentsRequest, PeerDirectoryContentsReply, PeerUserInfoReply, PeerUserInfoRequest, - PeerSearchReply, ) from aioslsk.peer import PeerManager -from aioslsk.search import SearchRequest, SearchType from aioslsk.settings import Settings from aioslsk.state import State import pytest -from unittest.mock import ANY, AsyncMock, call, Mock, PropertyMock +from unittest.mock import ANY, AsyncMock, Mock USER_DESCRIPTION = 'describes the user' @@ -71,7 +69,7 @@ def _create_peer_manager(self, settings: dict = DEFAULT_SETTINGS) -> PeerManager async def test_onPeerInfoRequest_withInfo_shouldSendPeerInfoReply(self): manager = self._create_peer_manager(SETTINGS_WITH_INFO) connection = AsyncMock() - type(manager._transfer_manager).upload_slots = PropertyMock(return_value=UPLOAD_SLOTS) + manager._transfer_manager.get_upload_slots = Mock(return_value=UPLOAD_SLOTS) manager._transfer_manager.get_queue_size = Mock(return_value=QUEUE_SIZE) manager._transfer_manager.has_slots_free = Mock(return_value=HAS_SLOTS_FREE) @@ -91,7 +89,7 @@ async def test_onPeerInfoRequest_withInfo_shouldSendPeerInfoReply(self): async def test_onPeerInfoRequest_withoutInfo_shouldSendPeerInfoReply(self): manager = self._create_peer_manager(DEFAULT_SETTINGS) connection = AsyncMock() - type(manager._transfer_manager).upload_slots = PropertyMock(return_value=UPLOAD_SLOTS) + manager._transfer_manager.get_upload_slots = Mock(return_value=UPLOAD_SLOTS) manager._transfer_manager.get_queue_size = Mock(return_value=QUEUE_SIZE) manager._transfer_manager.has_slots_free = Mock(return_value=HAS_SLOTS_FREE) @@ -126,7 +124,7 @@ async def test_whenDirectoryRequestReceived_shouldRespond(self): manager = self._create_peer_manager() manager._shares_manager.create_directory_reply.return_value = DIRECTORY_DATA - connection = Mock() + connection = AsyncMock() connection.username = USER await manager._on_peer_directory_contents_req( @@ -134,7 +132,7 @@ async def test_whenDirectoryRequestReceived_shouldRespond(self): ) manager._shares_manager.create_directory_reply.assert_called_once_with(DIRECTORY) - connection.queue_message.assert_called_once_with( + connection.send_message.assert_called_once_with( PeerDirectoryContentsReply.Request(TICKET, DIRECTORY, DIRECTORY_DATA) ) @@ -158,38 +156,3 @@ async def test_whenDirectoryReplyReceived_shouldEmitEvent(self): manager._event_bus.emit.assert_awaited_once_with( UserDirectoryEvent(user, DIRECTORY, DIRECTORIES) ) - - @pytest.mark.asyncio - async def test_onPeerSearchReply_shouldStoreResultsAndEmit(self): - manager = self._create_peer_manager() - TICKET = 1234 - connection = AsyncMock() - - manager._state.search_requests[TICKET] = SearchRequest( - TICKET, 'search', SearchType.NETWORK) - - reply_message = PeerSearchReply.Request( - 'user0', - TICKET, - results=[FileData(1, 'myfile.mp3', 10000, 'mp3', attributes=[])], - has_slots_free=True, - avg_speed=100, - queue_size=2, - locked_results=[FileData(1, 'locked.mp3', 10000, 'mp3', attributes=[])] - ) - await manager._on_peer_search_reply(reply_message, connection) - - assert 1 == len(manager._state.search_requests[TICKET].results) - - manager._event_bus.emit.assert_has_awaits( - [ - call( - SearchResultEvent( - manager._state.search_requests[TICKET], - manager._state.search_requests[TICKET].results[0] - ) - ), - call(UserInfoEvent(manager._state.get_or_create_user('user0'))) - ] - - ) diff --git a/tests/unit/test_server_manager.py b/tests/unit/test_server_manager.py index d0a2d0de..fb75523e 100644 --- a/tests/unit/test_server_manager.py +++ b/tests/unit/test_server_manager.py @@ -26,7 +26,7 @@ RemoveUser, ) from aioslsk.protocol.primitives import RoomTicker, UserStats -from aioslsk.search import SearchType +from aioslsk.search.model import SearchType from aioslsk.settings import Settings from aioslsk.server import ServerManager from aioslsk.state import State @@ -398,38 +398,3 @@ async def test_onUserLeftRoom_shouldRemoveUserFromRoom(self, manager: ServerMana async def test_whenSetRoomTicker_shouldSetRoomTicker(self, manager: ServerManager): await manager.set_room_ticker('room0', 'hello') manager._network.send_server_messages.assert_awaited_once() - - @pytest.mark.asyncio - async def test_searchNetwork_shouldSearchAndCreateEntry(self, manager: ServerManager): - search_query = await manager.search('my query') - assert 'my query' == search_query.query - assert isinstance(search_query.ticket, int) - assert SearchType.NETWORK == search_query.search_type - - manager._network.send_server_messages.assert_awaited_once() - - @pytest.mark.asyncio - async def test_searchRoom_shouldSearchAndCreateEntry(self, manager: ServerManager): - query = 'my query' - room_name = 'room0' - - search_query = await manager.search_room(room_name, query) - assert query == search_query.query - assert isinstance(search_query.ticket, int) - assert SearchType.ROOM == search_query.search_type - assert room_name == search_query.room - - manager._network.send_server_messages.assert_awaited_once() - - @pytest.mark.asyncio - async def test_searchUser_shouldSearchAndCreateEntry(self, manager: ServerManager): - query = 'my query' - username = 'room0' - - search_query = await manager.search_user(username, query) - assert query == search_query.query - assert isinstance(search_query.ticket, int) - assert SearchType.USER == search_query.search_type - assert username == search_query.username - - manager._network.send_server_messages.assert_awaited_once() From 6dab4823b3885e5e0b786c46ae26a6627a3e07ae Mon Sep 17 00:00:00 2001 From: JurgenR <1249228+JurgenR@users.noreply.github.com> Date: Thu, 5 Oct 2023 17:46:22 +0200 Subject: [PATCH 2/2] feat: remove unused dependencies, add event for received searches --- src/aioslsk/events.py | 8 ++++++++ src/aioslsk/search/manager.py | 20 ++++++++++---------- src/aioslsk/search/model.py | 6 ++++-- 3 files changed, 22 insertions(+), 12 deletions(-) diff --git a/src/aioslsk/events.py b/src/aioslsk/events.py index 072dd36e..2a478be6 100644 --- a/src/aioslsk/events.py +++ b/src/aioslsk/events.py @@ -200,6 +200,14 @@ class SearchResultEvent(Event): result: SearchResult +@dataclass(frozen=True) +class SearchRequestReceivedEvent(Event): + """Emitted when a search request by another user has been received""" + username: str + query: str + result_count: int + + @dataclass(frozen=True) class SimilarUsersEvent(Event): users: List[User] diff --git a/src/aioslsk/search/manager.py b/src/aioslsk/search/manager.py index e91184cf..a5e7809b 100644 --- a/src/aioslsk/search/manager.py +++ b/src/aioslsk/search/manager.py @@ -17,9 +17,8 @@ InternalEventBus, ConnectionStateChangedEvent, MessageReceivedEvent, - UserDirectoryEvent, UserInfoEvent, - UserSharesReplyEvent, + SearchRequestReceivedEvent, SearchResultEvent, ) from ..protocol.messages import ( @@ -27,14 +26,7 @@ DistributedSearchRequest, DistributedServerSearchRequest, FileSearch, - MessageDataclass, - PeerDirectoryContentsRequest, - PeerDirectoryContentsReply, PeerSearchReply, - PeerSharesRequest, - PeerSharesReply, - PeerUserInfoReply, - PeerUserInfoRequest, ServerSearchRequest, UserSearch, WishlistInterval, @@ -145,11 +137,19 @@ async def _query_shares_and_reply(self, ticket: int, username: str, query: str): """ visible, locked = self._shares_manager.query(query, username=username) + result_count = len(visible) + len(locked) self.received_searches.append( ReceivedSearch( username=username, query=query, - matched_files=len(visible) + len(locked) + result_count=result_count + ) + ) + await self._event_bus.emit( + SearchRequestReceivedEvent( + username=username, + query=query, + result_count=result_count ) ) diff --git a/src/aioslsk/search/model.py b/src/aioslsk/search/model.py index f9c5233c..79bf758f 100644 --- a/src/aioslsk/search/model.py +++ b/src/aioslsk/search/model.py @@ -16,10 +16,12 @@ class SearchType(Enum): @dataclass class ReceivedSearch: - """Used for keeping track of searches received from the distributed parent""" + """Used for keeping track of searches received from the distributed parent + or server + """ username: str query: str - matched_files: int + result_count: int @dataclass