-
Notifications
You must be signed in to change notification settings - Fork 1.7k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #50 from deepset-ai/improved-type-matching
feat: improved type matching
- Loading branch information
Showing
8 changed files
with
286 additions
and
277 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,7 @@ | ||
# SPDX-FileCopyrightText: 2022-present deepset GmbH <[email protected]> | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 | ||
from typing import Tuple, Optional, List, Any, get_args | ||
from typing import Tuple, Optional, Union, List, Any, get_args, get_origin | ||
|
||
import logging | ||
import itertools | ||
|
@@ -23,26 +23,83 @@ def parse_connection_name(connection: str) -> Tuple[str, Optional[str]]: | |
return connection, None | ||
|
||
|
||
def _types_are_compatible(sender, receiver):  # pylint: disable=too-many-return-statements
    """
    Checks whether the sender type is equal to, or a subtype of, the receiver type. Used to validate pipeline
    connections.

    Note: this method has no pretense to perform proper type matching. It especially does not deal with aliasing of
    typing classes such as `List` or `Dict` to their runtime counterparts `list` and `dict`. It also does not deal well
    with "bare" types, so `List` is treated differently from `List[Any]`, even though they should be the same.
    Consider simplifying the typing of your components if you observe unexpected errors during component connection.

    :param sender: the type declared by the output (sending) socket.
    :param receiver: the type declared by the input (receiving) socket.
    :return: True if a value of type `sender` can be passed to a socket expecting `receiver`, False otherwise.
    """
    # Identical types always match, and a receiver declared as Any accepts everything.
    if sender == receiver or receiver is Any:
        return True

    # A sender declared as Any gives no guarantee, so it only matches an Any receiver (handled above).
    if sender is Any:
        return False

    try:
        if issubclass(sender, receiver):
            return True
    except TypeError:  # typing classes can't be used with issubclass, so we deal with them below
        pass

    sender_origin = get_origin(sender)
    receiver_origin = get_origin(receiver)

    # A Union sender may emit any of its members, so every member must be accepted by the receiver.
    # Comparing Union members by position would wrongly reject e.g. Union[str, int] -> Union[int, str, float].
    if sender_origin is Union:
        return all(_types_are_compatible(union_arg, receiver) for union_arg in get_args(sender))

    # A non-Union sender only needs to match one of the receiver's Union members.
    if receiver_origin is Union:
        return any(_types_are_compatible(sender, union_arg) for union_arg in get_args(receiver))

    # From here on both sides must be parametrized generics built on the same origin (e.g. both `list`).
    if not sender_origin or not receiver_origin or sender_origin != receiver_origin:
        return False

    sender_args = get_args(sender)
    receiver_args = get_args(receiver)
    # A sender carrying more type parameters than the receiver cannot be matched parameter by parameter.
    if len(sender_args) > len(receiver_args):
        return False

    # Compare the type parameters pairwise, recursing into nested generics.
    return all(_types_are_compatible(*args) for args in zip(sender_args, receiver_args))
|
||
|
||
def find_unambiguous_connection( | ||
from_node: str, to_node: str, from_sockets: List[OutputSocket], to_sockets: List[InputSocket] | ||
sender_node: str, receiver_node: str, sender_sockets: List[OutputSocket], receiver_sockets: List[InputSocket] | ||
) -> Tuple[OutputSocket, InputSocket]: | ||
""" | ||
Find one single possible connection between two lists of sockets. | ||
""" | ||
# List all combinations of sockets that match by type | ||
possible_connections = [ | ||
(out_sock, in_sock) | ||
for out_sock, in_sock in itertools.product(from_sockets, to_sockets) | ||
if not in_sock.sender and (Any in in_sock.types or out_sock.types == in_sock.types) | ||
(sender_sock, receiver_sock) | ||
for sender_sock, receiver_sock in itertools.product(sender_sockets, receiver_sockets) | ||
if _types_are_compatible(sender_sock.type, receiver_sock.type) | ||
] | ||
|
||
# No connections seem to be possible | ||
if not possible_connections: | ||
connections_status_str = _connections_status( | ||
from_node=from_node, from_sockets=from_sockets, to_node=to_node, to_sockets=to_sockets | ||
sender_node=sender_node, | ||
sender_sockets=sender_sockets, | ||
receiver_node=receiver_node, | ||
receiver_sockets=receiver_sockets, | ||
) | ||
|
||
# Both sockets were specified: explain why the types don't match | ||
if len(sender_sockets) == len(receiver_sockets) and len(sender_sockets) == 1: | ||
raise PipelineConnectError( | ||
f"Cannot connect '{sender_node}.{sender_sockets[0].name}' with '{receiver_node}.{receiver_sockets[0].name}': " | ||
f"their declared input and output types do not match.\n{connections_status_str}" | ||
) | ||
|
||
# Not both sockets were specified: explain there's no possible match on any pair | ||
connections_status_str = _connections_status( | ||
sender_node=sender_node, | ||
sender_sockets=sender_sockets, | ||
receiver_node=receiver_node, | ||
receiver_sockets=receiver_sockets, | ||
) | ||
raise PipelineConnectError( | ||
f"Cannot connect '{from_node}' with '{to_node}': " | ||
f"Cannot connect '{sender_node}' with '{receiver_node}': " | ||
f"no matching connections available.\n{connections_status_str}" | ||
) | ||
|
||
|
@@ -56,40 +113,46 @@ def find_unambiguous_connection( | |
# TODO allow for multiple connections at once if there is no ambiguity? | ||
# TODO give priority to sockets that have no default values? | ||
connections_status_str = _connections_status( | ||
from_node=from_node, from_sockets=from_sockets, to_node=to_node, to_sockets=to_sockets | ||
sender_node=sender_node, | ||
sender_sockets=sender_sockets, | ||
receiver_node=receiver_node, | ||
receiver_sockets=receiver_sockets, | ||
) | ||
raise PipelineConnectError( | ||
f"Cannot connect '{from_node}' with '{to_node}': more than one connection is possible " | ||
f"Cannot connect '{sender_node}' with '{receiver_node}': more than one connection is possible " | ||
"between these components. Please specify the connection name, like: " | ||
f"pipeline.connect('{from_node}.{possible_connections[0][0].name}', " | ||
f"'{to_node}.{possible_connections[0][1].name}').\n{connections_status_str}" | ||
f"pipeline.connect('{sender_node}.{possible_connections[0][0].name}', " | ||
f"'{receiver_node}.{possible_connections[0][1].name}').\n{connections_status_str}" | ||
) | ||
|
||
return possible_connections[0] | ||
|
||
|
||
def _connections_status(from_node: str, to_node: str, from_sockets: List[OutputSocket], to_sockets: List[InputSocket]): | ||
def _connections_status( | ||
sender_node: str, receiver_node: str, sender_sockets: List[OutputSocket], receiver_sockets: List[InputSocket] | ||
): | ||
""" | ||
Lists the status of the sockets, for error messages. | ||
""" | ||
from_sockets_entries = [] | ||
for from_socket in from_sockets: | ||
socket_types = ", ".join([_get_socket_type_desc(t) for t in from_socket.types]) | ||
from_sockets_entries.append(f" - {from_socket.name} ({socket_types})") | ||
from_sockets_list = "\n".join(from_sockets_entries) | ||
|
||
to_sockets_entries = [] | ||
for to_socket in to_sockets: | ||
socket_types = ", ".join([_get_socket_type_desc(t) for t in to_socket.types]) | ||
to_sockets_entries.append( | ||
f" - {to_socket.name} ({socket_types}), {'sent by '+to_socket.sender if to_socket.sender else 'available'}" | ||
sender_sockets_entries = [] | ||
for sender_socket in sender_sockets: | ||
socket_types = get_socket_type_desc(sender_socket.type) | ||
sender_sockets_entries.append(f" - {sender_socket.name} ({socket_types})") | ||
sender_sockets_list = "\n".join(sender_sockets_entries) | ||
|
||
receiver_sockets_entries = [] | ||
for receiver_socket in receiver_sockets: | ||
socket_types = get_socket_type_desc(receiver_socket.type) | ||
receiver_sockets_entries.append( | ||
f" - {receiver_socket.name} ({socket_types}), " | ||
f"{'sent by '+receiver_socket.sender if receiver_socket.sender else 'available'}" | ||
) | ||
to_sockets_list = "\n".join(to_sockets_entries) | ||
receiver_sockets_list = "\n".join(receiver_sockets_entries) | ||
|
||
return f"'{from_node}':\n{from_sockets_list}\n'{to_node}':\n{to_sockets_list}" | ||
return f"'{sender_node}':\n{sender_sockets_list}\n'{receiver_node}':\n{receiver_sockets_list}" | ||
|
||
|
||
def _get_socket_type_desc(type_): | ||
def get_socket_type_desc(type_): | ||
""" | ||
Assembles a readable representation of the type of a connection. Can handle primitive types, classes, and | ||
arbitrarily nested structures of types from the typing module. | ||
|
@@ -126,5 +189,5 @@ def _get_socket_type_desc(type_): | |
else: | ||
type_name = type_.__name__ | ||
|
||
subtypes = ", ".join([_get_socket_type_desc(subtype) for subtype in args if subtype is not type(None)]) | ||
subtypes = ", ".join([get_socket_type_desc(subtype) for subtype in args if subtype is not type(None)]) | ||
return f"{type_name}[{subtypes}]" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,10 +1,10 @@ | ||
# SPDX-FileCopyrightText: 2022-present deepset GmbH <[email protected]> | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 | ||
from typing import Union, Optional, Dict, Set, Any, get_origin, get_args | ||
from typing import Optional, Dict | ||
|
||
import logging | ||
from dataclasses import dataclass, fields, Field | ||
from dataclasses import dataclass, fields | ||
|
||
|
||
logger = logging.getLogger(__name__) | ||
|
@@ -13,31 +13,25 @@ | |
@dataclass | ||
class OutputSocket: | ||
name: str | ||
types: Set[type] | ||
type: type | ||
|
||
|
||
@dataclass | ||
class InputSocket: | ||
name: str | ||
types: Set[type] | ||
type: type | ||
sender: Optional[str] = None | ||
|
||
|
||
def find_input_sockets(component) -> Dict[str, InputSocket]: | ||
""" | ||
Find a component's input sockets. | ||
""" | ||
return {f.name: InputSocket(name=f.name, types=_get_types(f)) for f in fields(component.__canals_input__)} | ||
return {f.name: InputSocket(name=f.name, type=f.type) for f in fields(component.__canals_input__)} | ||
|
||
|
||
def find_output_sockets(component) -> Dict[str, OutputSocket]: | ||
""" | ||
Find a component's output sockets. | ||
""" | ||
return {f.name: OutputSocket(name=f.name, types=_get_types(f)) for f in fields(component.__canals_output__)} | ||
|
||
|
||
def _get_types(field: Field) -> Set[Any]: | ||
if get_origin(field.type) is Union: | ||
return {t for t in get_args(field.type) if t is not type(None)} | ||
return {field.type} | ||
return {f.name: OutputSocket(name=f.name, type=f.type) for f in fields(component.__canals_output__)} |
Oops, something went wrong.