Skip to content

Commit

Permalink
Merge pull request #38 from deepset-ai/component-protocol
Browse files Browse the repository at this point in the history
Create Component Protocol
  • Loading branch information
silvanocerza committed Jun 30, 2023
2 parents 8d8dc57 + 4677c39 commit f48d5d8
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 9 deletions.
2 changes: 1 addition & 1 deletion canals/component/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0
from canals.component.component import component, ComponentError
from canals.component.component import component, Component, ComponentError
from canals.component.decorators import save_init_params
41 changes: 38 additions & 3 deletions canals/component/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# SPDX-License-Identifier: Apache-2.0
import logging
import inspect
from typing import Union, List, get_origin, get_args
from typing import Protocol, Union, List, Any, get_origin, get_args

from dataclasses import fields, Field

Expand All @@ -15,6 +15,43 @@
logger = logging.getLogger(__name__)


# We ignore too-few-public-methods Pylint error as this is only meant to be
# the definition of the Component interface.
# A concrete Component will have more than method in any case.
class Component(Protocol): # pylint: disable=too-few-public-methods
"""
Abstract interface of a Component.
This is only used by type checking tools.
If you want to create a new Component use the @component decorator.
"""

def run(self, data: Any) -> Any:
"""
Takes the Component input and returns its output.
Input and output dataclasses types must be defined in separate methods
decorated with @component.input and @component.output respectively.
We use Any both as data and return types since dataclasses don't have a specific type.
"""

@property
def __canals_input__(self) -> type:
pass

@property
def __canals_output__(self) -> type:
pass

@property
def __canals_optional_inputs__(self) -> List[str]:
pass

@property
def __canals_mandatory_inputs__(self) -> List[str]:
pass


class _Component:
"""
Marks a class as a component. Any class decorated with `@component` can be used by a Pipeline.
Expand Down Expand Up @@ -324,7 +361,6 @@ def _is_optional(field: Field) -> bool:
return get_origin(field.type) is Union and type(None) in get_args(field.type)


# TODO: Remember to set the self type to Component when we create its Protocol
def _optional_inputs(self) -> List[str]:
"""
Return all field names of self that have an Optional type.
Expand All @@ -333,7 +369,6 @@ def _optional_inputs(self) -> List[str]:
return [f.name for f in fields(self.__canals_input__) if _is_optional(f)]


# TODO: Remember to set the self type to Component when we create its Protocol
def _mandatory_inputs(self) -> List[str]:
"""
Return all field names of self that don't have an Optional type.
Expand Down
5 changes: 3 additions & 2 deletions canals/pipeline/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import networkx

from canals.component import Component
from canals.errors import PipelineConnectError, PipelineMaxLoops, PipelineRuntimeError, PipelineValidationError
from canals.draw import draw, convert_for_debug, RenderingEngines
from canals.pipeline.sockets import InputSocket, OutputSocket, find_input_sockets, find_output_sockets
Expand Down Expand Up @@ -87,7 +88,7 @@ def _comparable_nodes_list(self, graph: networkx.MultiDiGraph) -> List[Dict[str,
nodes.sort()
return nodes

def add_component(self, name: str, instance: Any) -> None:
def add_component(self, name: str, instance: Component) -> None:
"""
Create a component for the given component. Components are not connected to anything by default:
use `Pipeline.connect()` to connect components together.
Expand Down Expand Up @@ -226,7 +227,7 @@ def _direct_connect(self, from_node: str, from_socket: OutputSocket, to_node: st
# Stores the name of the node that will send its output to this socket
to_socket.sender = from_node

def get_component(self, name: str) -> object:
def get_component(self, name: str) -> Component:
"""
Returns an instance of a component.
Expand Down
6 changes: 3 additions & 3 deletions canals/pipeline/save_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import logging
from pathlib import Path

from canals.component import component
from canals.component import component, Component
from canals.pipeline.pipeline import Pipeline
from canals.errors import PipelineUnmarshalError

Expand Down Expand Up @@ -70,7 +70,7 @@ def marshal_pipelines(pipelines: Dict[str, Pipeline]) -> Dict[str, Any]:
schema: Dict[str, Any] = {}

# Summarize pipeline configuration
components: List[Tuple[str, str, object]] = []
components: List[Tuple[str, str, Component]] = []
schema["pipelines"] = {}
for pipeline_name, pipeline in pipelines.items():
pipeline_repr: Dict[str, Any] = {}
Expand Down Expand Up @@ -122,7 +122,7 @@ def unmarshal_pipelines(schema: Dict[str, Any]) -> Dict[str, Pipeline]: # pylin
"""
pipelines = {}
component_instances: Dict[str, object] = {}
component_instances: Dict[str, Component] = {}
for pipeline_name, pipeline_schema in schema["pipelines"].items():
# Create the Pipeline object
pipe_args = {"metadata": pipeline_schema.get("metadata", None)}
Expand Down

0 comments on commit f48d5d8

Please sign in to comment.