Skip to content

Commit

Permalink
Merge pull request #23 from deepset-ai/component-io
Browse files Browse the repository at this point in the history
Rework how component I/O is defined
  • Loading branch information
silvanocerza committed Jun 21, 2023
2 parents 4e413c7 + a44b965 commit 8ac5461
Show file tree
Hide file tree
Showing 34 changed files with 1,060 additions and 861 deletions.
1 change: 0 additions & 1 deletion canals/component/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,3 @@
# SPDX-License-Identifier: Apache-2.0
from canals.component.component import component, ComponentError
from canals.component.decorators import save_init_params
from canals.component.input_output import ComponentInput, ComponentOutput, VariadicComponentInput
99 changes: 54 additions & 45 deletions canals/component/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,9 @@
import logging
import inspect

from dataclasses import is_dataclass

from canals.errors import ComponentError
from canals.component.decorators import save_init_params, init_defaults
from canals.component.input_output import Connection, _input, _output


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -187,11 +186,13 @@ def run(self, data: <Input if defined, otherwise untyped>) -> <Output if defined
# Its value is set to the desired component name: normally it is the class name, but it can technically be customized.
class_.__canals_component__ = class_.__name__

# Check that inputs respects all constraints
_check_input(class_)
# Find input and output properties
(input_, output) = _find_input_output(class_)

# Check that outputs respects all constraints
_check_output(class_)
# Save the input and output properties so it's easier to find them when running the Component since we won't
# need to search the exact property name each time
class_.__canals_input__ = input_
class_.__canals_output__ = output

# Check that the run method respects all constraints
_check_run_signature(class_)
Expand All @@ -206,39 +207,55 @@ def run(self, data: <Input if defined, otherwise untyped>) -> <Output if defined
return class_


def _check_input(class_):
"""
Check that the component's input respects all constraints
"""
if not hasattr(class_, "Input") and not hasattr(class_, "input_type"):
raise ComponentError(
"Components must either have an Input dataclass or a 'input_type' property that returns such dataclass"
)
if hasattr(class_, "Input"):
if not is_dataclass(class_.Input):
raise ComponentError(f"{class_.__name__}.Input must be a dataclass")
if not hasattr(class_.Input, "_component_input"):
raise ComponentError(f"{class_.__name__}.Input must inherit from ComponentInput")
if (
hasattr(class_.Input, "_variadic_component_input")
and len(inspect.signature(class_.Input.__init__).parameters) != 2
):
raise ComponentError("Variadic inputs can contain only one variadic positional parameter.")


def _check_output(class_):
# We do this to have some namespacing and also to make it clear that the methods decorated with
# @component.input and @component.output must have their class decorated as @component.
setattr(component, "input", _input)
setattr(component, "output", _output)


def _find_input_output(class_):
"""
Check that the component's output respects all constraints
Finds the input and the output definitions for class_ and returns them.
There must be only a single definition of input and output for class_, if either
none or more than one are found raise ConnectionError.
"""
if not hasattr(class_, "Output") and not hasattr(class_, "output_type"):
raise ComponentError(
"Components must either have an Output dataclass or a 'output_type' property that returns such dataclass"
)
if hasattr(class_, "Output"):
if not is_dataclass(class_.Output):
raise ComponentError(f"{class_.__name__}.Output must be a dataclass")
if not hasattr(class_.Output, "_component_output"):
raise ComponentError(f"{class_.__name__}.Output must inherit from ComponentOutput")
inputs_found = []
outputs_found = []

# Get all properties of class_
properties = inspect.getmembers(class_, predicate=lambda m: isinstance(m, property))
for _, prop in properties:
if not hasattr(prop, "fget") and not hasattr(prop.fget, "__canals_connection__"):
continue

# Field __canals_connection__ is set by _input and _output decorators
if prop.fget.__canals_connection__ in [Connection.INPUT, Connection.INPUT_VARIADIC]:
inputs_found.append(prop)
elif prop.fget.__canals_connection__ == Connection.OUTPUT:
outputs_found.append(prop)

if (in_len := len(inputs_found)) != 1:
# Raise if we don't find only a single input definition
if in_len == 0:
raise ComponentError(
f"No input definition found in Component {class_.__name__}. "
"Create a method that returns a dataclass defining the input and "
"decorate it with @component.input() to fix the error."
)
raise ComponentError(f"Multiple input definitions found for Component {class_.__name__}.")

if (in_len := len(outputs_found)) != 1:
# Raise if we don't find only a single output definition
if in_len == 0:
raise ComponentError(
f"No output definition found in Component {class_.__name__}. "
"Create a method that returns a dataclass defining the output and "
"decorate it with @component.output() to fix the error."
)
raise ComponentError(f"Multiple output definitions found for Component {class_.__name__}.")

return (inputs_found[0], outputs_found[0])


def _check_run_signature(class_):
Expand All @@ -257,11 +274,3 @@ def _check_run_signature(class_):
# The input param must be called data
if not "data" in run_signature.parameters:
raise ComponentError("run() must accept a parameter called 'data'.")

# Either give a self.input_type function or type 'data' with the Input dataclass
if not hasattr(class_, "input_type") and run_signature.parameters["data"].annotation != class_.Input:
raise ComponentError(f"'data' must be typed and the type must be {class_.__name__}.Input.")

# Check for the return types
if not hasattr(class_, "output_type") and run_signature.return_annotation == inspect.Parameter.empty:
raise ComponentError(f"{class_.__name__}.run() must declare the type of its return value.")
164 changes: 118 additions & 46 deletions canals/component/input_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,78 +2,150 @@
#
# SPDX-License-Identifier: Apache-2.0
import logging
import inspect
from dataclasses import fields
from enum import Enum
from dataclasses import fields, is_dataclass, dataclass, asdict, MISSING

from canals.errors import ComponentError

logger = logging.getLogger(__name__)


class BaseIODataclass: # pylint: disable=too-few-public-methods
def _make_fields_optional(class_: type):
"""
Base class for input and output classes of components.
Takes a dataclass definition and modifies its __init__ so that all fields have
a default value set.
If a field has a default factory use it to set the default value.
If a field has neither a default factory or value default to None.
"""
defaults = []
for field in fields(class_):
default = field.default
if field.default is MISSING and field.default_factory is MISSING:
default = None
elif field.default is MISSING and field.default_factory is not MISSING:
default = field.default_factory()
defaults.append(default)
# mypy complains we're accessing __init__ on an instance but it's not in reality.
# class_ is a class definition and not an instance of it, so we're good.
# Also only I/O dataclasses are meant to be passed to this function making it a bit safer.
class_.__init__.__defaults__ = tuple(defaults) # type: ignore


def _make_comparable(class_: type):
"""
Overwrites the existing __eq__ method of class_ with a custom one.
This is meant to be used only in I/O dataclasses, it takes into account
whether the fields are marked as comparable or not.
def names(self):
"""
Returns the name of all the fields of this dataclass.
"""
return [field.name for field in fields(self)]

This is necessary since the automatically created __eq__ method in dataclasses
also verifies the type of the class. That causes it to fail if the I/O dataclass
is returned by a function.
class Optionalize(type):
"""
Makes all the fields of the dataclass optional by setting None as their default value.
In here we don't compare the types of self and other but only their fields.
"""

def __call__(cls, *args, **kwargs):
obj = cls.__new__(cls, *args, **kwargs)
obj.__init__.__func__.__defaults__ = tuple(None for _ in inspect.signature(obj.__init__).parameters)
obj.__init__(*args, **kwargs)
return obj
def comparator(self, other) -> bool:
if not is_dataclass(other):
return False

fields_ = [f.name for f in fields(self) if f.compare]
other_fields = [f.name for f in fields(other) if f.compare]
if not len(fields_) == len(other_fields):
return False

class Variadic(type):
"""
Adds the proper checks to a variadic input dataclass and packs the args into a list for the `__init__` call.
"""
self_dict, other_dict = asdict(self), asdict(other)
for field in fields_:
if not self_dict[field] == other_dict[field]:
return False

def __call__(cls, *args, **kwargs):
if kwargs:
raise ValueError(f"{cls.__name__} accepts only an unnamed list of positional parameters.")
return True

obj = cls.__new__(cls, *args)
setattr(class_, "__eq__", comparator)

if len(inspect.signature(obj.__init__).parameters) != 1:
raise ValueError(f"{cls.__name__} accepts only one variadic positional parameter.")

obj.__init__(list(args))
return obj
class Connection(Enum):
INPUT = 1
OUTPUT = 2
INPUT_VARIADIC = 3


class ComponentInput(BaseIODataclass, metaclass=Optionalize): # pylint: disable=too-few-public-methods
def _input(input_function=None, variadic: bool = False):
"""
Represents the input of a component.
Decorator to mark a method that returns a dataclass defining a Component's input.
The decorated function becomes a property.
:param variadic: Set it to true to mark the dataclass returned by input_function as variadic,
additional checks are done in this case, defaults to False
"""

# dataclasses are uncooperative (don't call `super()`), so we need this flag to check for inheritance
_component_input = True
def decorator(function):
def wrapper(self):
class_ = function(self)
# If the user didn't explicitly declare the returned class
# as dataclass we do it out of convenience
if not is_dataclass(class_):
class_ = dataclass(class_)

_make_comparable(class_)
_make_fields_optional(class_)

class VariadicComponentInput(BaseIODataclass, metaclass=Variadic): # pylint: disable=too-few-public-methods
"""
Represents the input of a variadic component.
"""
if variadic and len(fields(class_)) > 1:
raise ComponentError(f"Variadic input dataclass {class_.__name__} must have only one field")

if variadic:
# Ugly hack to make variadic input work
init = class_.__init__
class_.__init__ = lambda self, *args: init(self, list(args))

return class_

# VariadicComponentInput can't inherit from ComponentInput due to metaclasses clashes
# dataclasses are uncooperative (don't call `super()`), so we need this flag to check for inheritance
_component_input = True
_variadic_component_input = True
# Magic field to ease some further checks, we set it in the wrapper
# function so we access it like this <class>.<function>.fget.__canals_connection__
wrapper.__canals_connection__ = Connection.INPUT_VARIADIC if variadic else Connection.INPUT

# If we don't set the documentation explicitly the user wouldn't be able to access
# since we make wrapper a property and not the original function.
# This is not essential but a really nice to have.
return property(fget=wrapper, doc=function.__doc__)

class ComponentOutput(BaseIODataclass, metaclass=Optionalize): # pylint: disable=too-few-public-methods
# Check if we're called as @_input or @_input()
if input_function:
# Called with parens
return decorator(input_function)

# Called without parens
return decorator


def _output(output_function=None):
"""
Represents the output of a component.
Decorator to mark a method that returns a dataclass defining a Component's output.
The decorated function becomes a property.
"""

# dataclasses are uncooperative (don't call `super()`), so we need this flag to check for inheritance
_component_output = True
def decorator(function):
def wrapper(self):
class_ = function(self)
if not is_dataclass(class_):
class_ = dataclass(class_)
_make_comparable(class_)
return class_

# Magic field to ease some further checks, we set it in the wrapper
# function so we access it like this <class>.<function>.fget.__canals_connection__
wrapper.__canals_connection__ = Connection.OUTPUT

# If we don't set the documentation explicitly the user wouldn't be able to access
# since we make wrapper a property and not the original function.
# This is not essential but a really nice to have.
return property(fget=wrapper, doc=function.__doc__)

# Check if we're called as @_output or @_output()
if output_function:
# Called with parens
return decorator(output_function)

# Called without parens
return decorator
5 changes: 2 additions & 3 deletions canals/pipeline/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
import networkx

from canals.errors import PipelineConnectError, PipelineMaxLoops, PipelineRuntimeError, PipelineValidationError
from canals.component.input_output import ComponentInput
from canals.draw import draw, convert_for_debug, RenderingEngines
from canals.pipeline.sockets import InputSocket, OutputSocket, find_input_sockets, find_output_sockets
from canals.pipeline.validation import validate_pipeline_input
Expand Down Expand Up @@ -278,7 +277,7 @@ def warm_up(self):
logger.info("Warming up component %s...", node)
self.graph.nodes[node]["instance"].warm_up()

def run(self, data: Dict[str, ComponentInput], debug: bool = False) -> Dict[str, Any]:
def run(self, data: Dict[str, Any], debug: bool = False) -> Dict[str, Any]:
"""
Runs the pipeline.
Expand Down Expand Up @@ -568,7 +567,7 @@ def _run_component(self, name: str, inputs: Dict[str, Any]) -> Dict[str, Any]:
logger.info("* Running %s (visits: %s)", name, self.graph.nodes[name]["visits"])
logger.debug(" '%s' inputs: %s", name, inputs)

input_class = instance.Input if hasattr(instance, "Input") else instance.input_type
input_class = instance.input

# If the node is variadic, unpack the input
if self.graph.nodes[name]["variadic_input"]:
Expand Down
9 changes: 3 additions & 6 deletions canals/pipeline/save_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,6 @@ def marshal_pipelines(pipelines: Dict[str, Pipeline]) -> Dict[str, Any]:
# Collect components
pipeline_repr["components"] = {}
for component_name in pipeline.graph.nodes:

# Check if we saved the same instance twice (or more times) and replace duplicates with references.
component_instance = pipeline.graph.nodes[component_name]["instance"]
for existing_component_pipeline, existing_component_name, existing_components in components:
Expand Down Expand Up @@ -136,7 +135,6 @@ def unmarshal_pipelines(schema: Dict[str, Any]) -> Dict[str, Pipeline]: # pylin
pipelines = {}
component_instances: Dict[str, object] = {}
for pipeline_name, pipeline_schema in schema["pipelines"].items():

# Create the Pipeline object
pipe_args = {"metadata": pipeline_schema.get("metadata", None)}
if "max_loops_allowed" in pipeline_schema.keys():
Expand Down Expand Up @@ -183,22 +181,21 @@ def _discover_dependencies(components: List[object]) -> List[str]:
return list({module.__name__.split(".")[0] for module in module_names if module is not None}) + ["canals"]


def _find_decorated_classes(modules_to_search: List[str], decorator: str = "__canals_component__") -> Dict[str, type]:
def _find_decorated_classes(modules_to_search: List[str], decorator: str = "__canals_component__") -> Dict[str, Any]:
"""
Finds all classes decorated with `@components` in all the modules listed in `modules_to_search`.
Returns a dictionary with the component class name and the component classes.
Note: can be used for other decorators as well by setting the `decorator` parameter.
"""
component_classes: Dict[str, type] = {}
component_classes: Dict[str, Any] = {}

# Collect all modules
for module in modules_to_search:

if not module in sys.modules:
raise ValueError(f"{module} is not imported.")

for name, entity in getmembers(sys.modules.get(module, None), ismodule):
for name, _ in getmembers(sys.modules.get(module, None), ismodule):
if f"{module}.{name}" in sys.modules:
modules_to_search.append(f"{module}.{name}")

Expand Down
Loading

0 comments on commit 8ac5461

Please sign in to comment.