Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add BranchJoiner and deprecate Multiplexer #7765

Merged
merged 7 commits into from
May 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/pydoc/config/joiners_api.yml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
loaders:
- type: haystack_pydoc_tools.loaders.CustomPythonLoader
search_path: [../../../haystack/components/joiners]
modules: ["document_joiner"]
modules: ["document_joiner", "branch"]
ignore_when_discovered: ["__init__"]
processors:
- type: filter
Expand Down
5 changes: 3 additions & 2 deletions haystack/components/joiners/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#
# SPDX-License-Identifier: Apache-2.0

from haystack.components.joiners.document_joiner import DocumentJoiner
from .branch import BranchJoiner
from .document_joiner import DocumentJoiner

__all__ = ["DocumentJoiner"]
__all__ = ["DocumentJoiner", "BranchJoiner"]
141 changes: 141 additions & 0 deletions haystack/components/joiners/branch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0

from typing import Any, Dict, Type

from haystack import component, default_from_dict, default_to_dict, logging
from haystack.core.component.types import Variadic
from haystack.utils import deserialize_type, serialize_type

logger = logging.getLogger(__name__)


@component(is_greedy=True)
class BranchJoiner:
"""
A component to join different branches of a pipeline into one single output.

`BranchJoiner` receives multiple data connections of the same type from other components and passes the first
value coming to its single output, possibly distributing it to various other components.
masci marked this conversation as resolved.
Show resolved Hide resolved

`BranchJoiner` is fundamental to close loops in a pipeline, where the two branches it joins are the ones
coming from the previous component and one coming back from a loop. For example, `BranchJoiner` could be used
to send data to a component evaluating errors. `BranchJoiner` would receive two connections, one to get the
original data and another one to get modified data in case there was an error. In both cases, `BranchJoiner`
would send (or re-send in case of a loop) data to the component evaluating errors. See "Usage example" below.

Another use case with a need for `BranchJoiner` is to reconcile multiple branches coming out of a decision
or Classifier component. For example, in a RAG pipeline, there might be a "query language classifier" component
sending the query to different retrievers, selecting one specifically according to the detected language. After the
retrieval step the pipeline would ideally continue with a `PromptBuilder`, and since we don't know in advance the
language of the query, all the retrievers should be ideally connected to the single `PromptBuilder`. Since the
`PromptBuilder` won't accept more than one connection in input, we would connect all the retrievers to a
`BranchJoiner` component and reconcile them in a single output that can be connected to the `PromptBuilder`
downstream.

Usage example:

```python
import json
from typing import List

from haystack import Pipeline
from haystack.components.converters import OutputAdapter
from haystack.components.generators.chat import OpenAIChatGenerator
from haystack.components.joiners import BranchJoiner
from haystack.components.validators import JsonSchemaValidator
from haystack.dataclasses import ChatMessage

person_schema = {
"type": "object",
"properties": {
"first_name": {"type": "string", "pattern": "^[A-Z][a-z]+$"},
"last_name": {"type": "string", "pattern": "^[A-Z][a-z]+$"},
"nationality": {"type": "string", "enum": ["Italian", "Portuguese", "American"]},
},
"required": ["first_name", "last_name", "nationality"]
}

# Initialize a pipeline
pipe = Pipeline()

# Add components to the pipeline
pipe.add_component('joiner', BranchJoiner(List[ChatMessage]))
pipe.add_component('fc_llm', OpenAIChatGenerator(model="gpt-3.5-turbo-0125"))
pipe.add_component('validator', JsonSchemaValidator(json_schema=person_schema))
pipe.add_component('adapter', OutputAdapter("{{chat_message}}", List[ChatMessage])),
# And connect them
pipe.connect("adapter", "joiner")
pipe.connect("joiner", "fc_llm")
pipe.connect("fc_llm.replies", "validator.messages")
pipe.connect("validator.validation_error", "joiner")

result = pipe.run(data={"fc_llm": {"generation_kwargs": {"response_format": {"type": "json_object"}}},
"adapter": {"chat_message": [ChatMessage.from_user("Create json object from Peter Parker")]}})

print(json.loads(result["validator"]["validated"][0].content))


>> {'first_name': 'Peter', 'last_name': 'Parker', 'nationality': 'American', 'name': 'Spider-Man', 'occupation':
>> 'Superhero', 'age': 23, 'location': 'New York City'}
```

Note that `BranchJoiner` can manage only one data type at a time. In this case, `BranchJoiner` is created for passing
`List[ChatMessage]`. This determines the type of data that `BranchJoiner` will receive from the upstream connected
components and also the type of data that `BranchJoiner` will send through its output.

In the code example, `BranchJoiner` receives a looped back `List[ChatMessage]` from the `JsonSchemaValidator` and
sends it down to the `OpenAIChatGenerator` for re-generation. We can have multiple loopback connections in the
pipeline. In this instance, the downstream component is only one (the `OpenAIChatGenerator`), but the pipeline might
have more than one downstream component.
"""

def __init__(self, type_: Type):
"""
Create a `BranchJoiner` component.

:param type_: The type of data that the `BranchJoiner` will receive from the upstream connected components and
distribute to the downstream connected components.
"""
self.type_ = type_
# type_'s type can't be determined statically
component.set_input_types(self, value=Variadic[type_]) # type: ignore
component.set_output_types(self, value=type_)

def to_dict(self):
"""
Serializes the component to a dictionary.

:returns:
Dictionary with serialized data.
"""
return default_to_dict(self, type_=serialize_type(self.type_))

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "BranchJoiner":
"""
Deserializes the component from a dictionary.

:param data:
Dictionary to deserialize from.
:returns:
Deserialized component.
"""
data["init_parameters"]["type_"] = deserialize_type(data["init_parameters"]["type_"])
return default_from_dict(cls, data)

def run(self, **kwargs):
"""
The run method of the `BranchJoiner` component.

Multiplexes the input data from the upstream connected components and distributes it to the downstream connected
components.

:param **kwargs: The input data. Must be of the type declared in `__init__`.
:return: A dictionary with the following keys:
- `value`: The input data.
"""
if (inputs_count := len(kwargs["value"])) != 1:
raise ValueError(f"BranchJoiner expects only one input, but {inputs_count} were received.")
return {"value": kwargs["value"][0]}
5 changes: 5 additions & 0 deletions haystack/components/others/multiplexer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# SPDX-License-Identifier: Apache-2.0

import sys
import warnings
from typing import Any, Dict

from haystack import component, default_from_dict, default_to_dict, logging
Expand Down Expand Up @@ -103,6 +104,10 @@ def __init__(self, type_: TypeAlias):
:param type_: The type of data that the `Multiplexer` will receive from the upstream connected components and
distribute to the downstream connected components.
"""
warnings.warn(
"`Multiplexer` is deprecated and will be removed in Haystack 2.4.0. Use `joiners.BranchJoiner` instead.",
bilgeyucel marked this conversation as resolved.
Show resolved Hide resolved
DeprecationWarning,
)
self.type_ = type_
component.set_input_types(self, value=Variadic[type_])
component.set_output_types(self, value=type_)
Expand Down
14 changes: 14 additions & 0 deletions releasenotes/notes/add-branch-joiner-037298459ca74077.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
---
highlights: >
The `Multiplexer` component proved to be hard to explain and to understand. After reviewing its use cases, the documentation
was rewritten and the component was renamed to `BranchJoiner` to better explain its functionalities.
upgrade:
- |
`BranchJoiner` has the very same interface as `Multiplexer`. To upgrade your code, just rename any occurrence
of `Multiplexer` to `BranchJoiner` and ajdust the imports accordingly.
features:
- |
Add `BranchJoiner` to eventually replace `Multiplexer`
deprecations:
- |
`Mulitplexer` is now deprecated.
35 changes: 35 additions & 0 deletions test/components/joiners/test_branch_joiner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0
import pytest

from haystack.components.joiners import BranchJoiner


class TestBranchJoiner:
def test_one_value(self):
joiner = BranchJoiner(int)
output = joiner.run(value=[2])
assert output == {"value": 2}

def test_one_value_of_wrong_type(self):
# BranchJoiner does not type check the input
joiner = BranchJoiner(int)
output = joiner.run(value=["hello"])
assert output == {"value": "hello"}

def test_one_value_of_none_type(self):
# BranchJoiner does not type check the input
joiner = BranchJoiner(int)
output = joiner.run(value=[None])
assert output == {"value": None}

def test_more_values_of_expected_type(self):
joiner = BranchJoiner(int)
with pytest.raises(ValueError, match="BranchJoiner expects only one input, but 3 were received."):
joiner.run(value=[2, 3, 4])

def test_no_values(self):
joiner = BranchJoiner(int)
with pytest.raises(ValueError, match="BranchJoiner expects only one input, but 0 were received."):
joiner.run(value=[])