Skip to content

Commit

Permalink
Merge pull request #50 from deepset-ai/improved-type-matching
Browse files Browse the repository at this point in the history
feat: improved type matching
  • Loading branch information
ZanSara committed Jul 12, 2023
2 parents d70af9d + a880cec commit 781af48
Show file tree
Hide file tree
Showing 8 changed files with 286 additions and 277 deletions.
117 changes: 90 additions & 27 deletions canals/pipeline/connections.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0
from typing import Tuple, Optional, List, Any, get_args
from typing import Tuple, Optional, Union, List, Any, get_args, get_origin

import logging
import itertools
Expand All @@ -23,26 +23,83 @@ def parse_connection_name(connection: str) -> Tuple[str, Optional[str]]:
return connection, None


def _types_are_compatible(sender, receiver): # pylint: disable=too-many-return-statements
"""
Checks whether the source type is equal or a subtype of the destination type. Used to validate pipeline connections.
Note: this method has no pretense to perform proper type matching. It especially does not deal with aliasing of
typing classes such as `List` or `Dict` to their runtime counterparts `list` and `dict`. It also does not deal well
with "bare" types, so `List` is treated differently from `List[Any]`, even though they should be the same.
Consider simplifying the typing of your components if you observe unexpected errors during component connection.
"""
if sender == receiver or receiver is Any:
return True

if sender is Any:
return False

try:
if issubclass(sender, receiver):
return True
except TypeError: # typing classes can't be used with issubclass, so we deal with them below
pass

sender_origin = get_origin(sender)
receiver_origin = get_origin(receiver)

if sender_origin is not Union and receiver_origin is Union:
return any(_types_are_compatible(sender, union_arg) for union_arg in get_args(receiver))

if not sender_origin or not receiver_origin or sender_origin != receiver_origin:
return False

sender_args = get_args(sender)
receiver_args = get_args(receiver)
if len(sender_args) > len(receiver_args):
return False

return all(_types_are_compatible(*args) for args in zip(sender_args, receiver_args))


def find_unambiguous_connection(
from_node: str, to_node: str, from_sockets: List[OutputSocket], to_sockets: List[InputSocket]
sender_node: str, receiver_node: str, sender_sockets: List[OutputSocket], receiver_sockets: List[InputSocket]
) -> Tuple[OutputSocket, InputSocket]:
"""
Find one single possible connection between two lists of sockets.
"""
# List all combinations of sockets that match by type
possible_connections = [
(out_sock, in_sock)
for out_sock, in_sock in itertools.product(from_sockets, to_sockets)
if not in_sock.sender and (Any in in_sock.types or out_sock.types == in_sock.types)
(sender_sock, receiver_sock)
for sender_sock, receiver_sock in itertools.product(sender_sockets, receiver_sockets)
if _types_are_compatible(sender_sock.type, receiver_sock.type)
]

# No connections seem to be possible
if not possible_connections:
connections_status_str = _connections_status(
from_node=from_node, from_sockets=from_sockets, to_node=to_node, to_sockets=to_sockets
sender_node=sender_node,
sender_sockets=sender_sockets,
receiver_node=receiver_node,
receiver_sockets=receiver_sockets,
)

# Both sockets were specified: explain why the types don't match
if len(sender_sockets) == len(receiver_sockets) and len(sender_sockets) == 1:
raise PipelineConnectError(
f"Cannot connect '{sender_node}.{sender_sockets[0].name}' with '{receiver_node}.{receiver_sockets[0].name}': "
f"their declared input and output types do not match.\n{connections_status_str}"
)

# Not both sockets were specified: explain there's no possible match on any pair
connections_status_str = _connections_status(
sender_node=sender_node,
sender_sockets=sender_sockets,
receiver_node=receiver_node,
receiver_sockets=receiver_sockets,
)
raise PipelineConnectError(
f"Cannot connect '{from_node}' with '{to_node}': "
f"Cannot connect '{sender_node}' with '{receiver_node}': "
f"no matching connections available.\n{connections_status_str}"
)

Expand All @@ -56,40 +113,46 @@ def find_unambiguous_connection(
# TODO allow for multiple connections at once if there is no ambiguity?
# TODO give priority to sockets that have no default values?
connections_status_str = _connections_status(
from_node=from_node, from_sockets=from_sockets, to_node=to_node, to_sockets=to_sockets
sender_node=sender_node,
sender_sockets=sender_sockets,
receiver_node=receiver_node,
receiver_sockets=receiver_sockets,
)
raise PipelineConnectError(
f"Cannot connect '{from_node}' with '{to_node}': more than one connection is possible "
f"Cannot connect '{sender_node}' with '{receiver_node}': more than one connection is possible "
"between these components. Please specify the connection name, like: "
f"pipeline.connect('{from_node}.{possible_connections[0][0].name}', "
f"'{to_node}.{possible_connections[0][1].name}').\n{connections_status_str}"
f"pipeline.connect('{sender_node}.{possible_connections[0][0].name}', "
f"'{receiver_node}.{possible_connections[0][1].name}').\n{connections_status_str}"
)

return possible_connections[0]


def _connections_status(from_node: str, to_node: str, from_sockets: List[OutputSocket], to_sockets: List[InputSocket]):
def _connections_status(
sender_node: str, receiver_node: str, sender_sockets: List[OutputSocket], receiver_sockets: List[InputSocket]
):
"""
Lists the status of the sockets, for error messages.
"""
from_sockets_entries = []
for from_socket in from_sockets:
socket_types = ", ".join([_get_socket_type_desc(t) for t in from_socket.types])
from_sockets_entries.append(f" - {from_socket.name} ({socket_types})")
from_sockets_list = "\n".join(from_sockets_entries)

to_sockets_entries = []
for to_socket in to_sockets:
socket_types = ", ".join([_get_socket_type_desc(t) for t in to_socket.types])
to_sockets_entries.append(
f" - {to_socket.name} ({socket_types}), {'sent by '+to_socket.sender if to_socket.sender else 'available'}"
sender_sockets_entries = []
for sender_socket in sender_sockets:
socket_types = get_socket_type_desc(sender_socket.type)
sender_sockets_entries.append(f" - {sender_socket.name} ({socket_types})")
sender_sockets_list = "\n".join(sender_sockets_entries)

receiver_sockets_entries = []
for receiver_socket in receiver_sockets:
socket_types = get_socket_type_desc(receiver_socket.type)
receiver_sockets_entries.append(
f" - {receiver_socket.name} ({socket_types}), "
f"{'sent by '+receiver_socket.sender if receiver_socket.sender else 'available'}"
)
to_sockets_list = "\n".join(to_sockets_entries)
receiver_sockets_list = "\n".join(receiver_sockets_entries)

return f"'{from_node}':\n{from_sockets_list}\n'{to_node}':\n{to_sockets_list}"
return f"'{sender_node}':\n{sender_sockets_list}\n'{receiver_node}':\n{receiver_sockets_list}"


def _get_socket_type_desc(type_):
def get_socket_type_desc(type_):
"""
Assembles a readable representation of the type of a connection. Can handle primitive types, classes, and
arbitrarily nested structures of types from the typing module.
Expand Down Expand Up @@ -126,5 +189,5 @@ def _get_socket_type_desc(type_):
else:
type_name = type_.__name__

subtypes = ", ".join([_get_socket_type_desc(subtype) for subtype in args if subtype is not type(None)])
subtypes = ", ".join([get_socket_type_desc(subtype) for subtype in args if subtype is not type(None)])
return f"{type_name}[{subtypes}]"
31 changes: 11 additions & 20 deletions canals/pipeline/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from canals.draw import draw, convert_for_debug, RenderingEngines
from canals.pipeline.sockets import InputSocket, OutputSocket, find_input_sockets, find_output_sockets
from canals.pipeline.validation import validate_pipeline_input
from canals.pipeline.connections import parse_connection_name, find_unambiguous_connection
from canals.pipeline.connections import parse_connection_name, find_unambiguous_connection, get_socket_type_desc


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -176,7 +176,7 @@ def connect(self, connect_from: str, connect_to: str) -> None:
f"'{from_node}.{from_socket_name} does not exist. "
f"Output connections of {from_node} are: "
+ ", ".join(
[f"{name} (type {[t.__name__ for t in socket.types]})" for name, socket in from_sockets.items()]
[f"{name} (type {get_socket_type_desc(socket.type)})" for name, socket in from_sockets.items()]
)
)
if to_socket_name:
Expand All @@ -186,35 +186,26 @@ def connect(self, connect_from: str, connect_to: str) -> None:
f"'{to_node}.{to_socket_name} does not exist. "
f"Input connections of {to_node} are: "
+ ", ".join(
[f"{name} (type {[t.__name__ for t in socket.types]})" for name, socket in to_sockets.items()]
[f"{name} (type {get_socket_type_desc(socket.type)})" for name, socket in to_sockets.items()]
)
)

# If either one of the two sockets is not specified, look for an unambiguous connection
# Look for an unambiguous connection among the possible ones.
# Note that if there is more than one possible connection but two sockets match by name, they're paired.
if not to_socket_name or not from_socket_name:
from_sockets = [from_socket] if from_socket_name else from_sockets.values()
to_sockets = [to_socket] if to_socket_name else to_sockets.values()
from_socket, to_socket = find_unambiguous_connection(
from_node=from_node, from_sockets=from_sockets, to_node=to_node, to_sockets=to_sockets
)
from_sockets = [from_socket] if from_socket_name else list(from_sockets.values())
to_sockets = [to_socket] if to_socket_name else list(to_sockets.values())
from_socket, to_socket = find_unambiguous_connection(
sender_node=from_node, sender_sockets=from_sockets, receiver_node=to_node, receiver_sockets=to_sockets
)

# Connect the components on these sockets
self._direct_connect(from_node=from_node, from_socket=from_socket, to_node=to_node, to_socket=to_socket)

def _direct_connect(self, from_node: str, from_socket: OutputSocket, to_node: str, to_socket: InputSocket) -> None:
"""
Directly connect socket to socket.
Directly connect socket to socket. This method does not type-check the connections: use 'Pipeline.connect()'
instead (which uses 'find_unambiguous_connection()' to validate types).
"""
# Verify that receiving socket can accept the output types it will receive
if Any not in to_socket.types and not from_socket.types & to_socket.types:
raise PipelineConnectError(
f"Cannot connect '{from_node}.{from_socket.name}' with '{to_node}.{to_socket.name}': "
f"their declared input and output types do not match.\n"
f" - {from_node}.{from_socket.name}: {[t.__name__ for t in from_socket.types]}\n"
f" - {to_node}.{to_socket.name}: {[t.__name__ for t in to_socket.types]}\n"
)

# Make sure the receiving socket isn't already connected - sending sockets can be connected as many times as needed,
# so they don't need this check
if to_socket.sender:
Expand Down
18 changes: 6 additions & 12 deletions canals/pipeline/sockets.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0
from typing import Union, Optional, Dict, Set, Any, get_origin, get_args
from typing import Optional, Dict

import logging
from dataclasses import dataclass, fields, Field
from dataclasses import dataclass, fields


logger = logging.getLogger(__name__)
Expand All @@ -13,31 +13,25 @@
@dataclass
class OutputSocket:
name: str
types: Set[type]
type: type


@dataclass
class InputSocket:
name: str
types: Set[type]
type: type
sender: Optional[str] = None


def find_input_sockets(component) -> Dict[str, InputSocket]:
"""
Find a component's input sockets.
"""
return {f.name: InputSocket(name=f.name, types=_get_types(f)) for f in fields(component.__canals_input__)}
return {f.name: InputSocket(name=f.name, type=f.type) for f in fields(component.__canals_input__)}


def find_output_sockets(component) -> Dict[str, OutputSocket]:
"""
Find a component's output sockets.
"""
return {f.name: OutputSocket(name=f.name, types=_get_types(f)) for f in fields(component.__canals_output__)}


def _get_types(field: Field) -> Set[Any]:
if get_origin(field.type) is Union:
return {t for t in get_args(field.type) if t is not type(None)}
return {field.type}
return {f.name: OutputSocket(name=f.name, type=f.type) for f in fields(component.__canals_output__)}
Loading

0 comments on commit 781af48

Please sign in to comment.