-
Notifications
You must be signed in to change notification settings - Fork 1.7k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
fix: deepcopy the inputs of components (#6987)
* deepcopy inputs * reno * simplify test
- Loading branch information
Showing
3 changed files
with
95 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
5 changes: 5 additions & 0 deletions
5
releasenotes/notes/pipeline-same-input-ref-different-components-68d74cb17b35f8db.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
--- | ||
fixes: | ||
- | | ||
Previously, when the same input reference was passed to different components, the Pipeline run logic | ||
behaved unexpectedly. This has been fixed by deep-copying the inputs before passing them to the components. |
83 changes: 83 additions & 0 deletions
83
test/core/pipeline/test_same_input_different_components.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
from typing import List | ||
|
||
from haystack import Pipeline | ||
from haystack.dataclasses import ChatMessage | ||
from haystack import component | ||
|
||
|
||
class MethodTracker:
    """
    Callable wrapper that records how a method is invoked.

    The wrapper forwards every call to the wrapped callable unchanged and
    returns (or raises) whatever that callable does.

    Attributes set by this class:
      - ``method``: the wrapped callable.
      - ``call_count``: number of times the wrapper has been called.
      - ``called_with``: ``(args, kwargs)`` of the most recent call, or
        ``None`` if the wrapper has never been called.
      - ``calls``: ``(args, kwargs)`` of every call, oldest first.
    """

    def __init__(self, method):
        self.method = method
        self.call_count = 0
        self.called_with = None
        # Full history, so arguments of earlier calls are not lost when the
        # wrapped method is invoked more than once.
        self.calls = []

    def __call__(self, *args, **kwargs):
        # Record before delegating so the call is tracked even if the
        # wrapped method raises.
        self.call_count += 1
        self.called_with = (args, kwargs)
        self.calls.append(self.called_with)
        return self.method(*args, **kwargs)
|
||
|
||
@component
class PassThroughPromptBuilder:
    """Pass-through component: emits the messages it receives, unchanged."""

    @component.output_types(prompt=List[ChatMessage])
    def run(self, prompt_source: List[ChatMessage]):
        """Return ``prompt_source`` itself under the ``prompt`` output key."""
        output = {"prompt": prompt_source}
        return output
|
||
|
||
@component
class MessageMerger:
    """Component that joins the contents of its input messages into one string."""

    @component.output_types(merged_message=str)
    def run(self, messages: List[ChatMessage], metadata: dict = None):
        """
        Join the ``content`` of every message with newlines.

        ``metadata`` is accepted but not used when building the result.
        """
        contents = (message.content for message in messages)
        merged = "\n".join(contents)
        return {"merged_message": merged}
|
||
|
||
@component
class FakeGenerator:
    """Fake generator component: replies with the same fixed message on every run."""

    @component.output_types(replies=List[ChatMessage])
    def run(self, messages: List[ChatMessage]):
        """Ignore ``messages`` and return a single fixed assistant reply."""
        canned_reply = ChatMessage.from_assistant("Fake message")
        return {"replies": [canned_reply]}
|
||
|
||
def test_same_input_different_components():
    """
    Check that handing the same input reference to several components
    does not disturb the Pipeline run logic: each merger must run exactly
    once and receive the expected metadata.
    """
    prompt_builder = PassThroughPromptBuilder()
    llm = FakeGenerator()
    mm1 = MessageMerger()
    mm2 = MessageMerger()

    # Swap each merger's run method for a tracking wrapper so that calls and
    # their arguments can be inspected after the pipeline has run.
    mm1_tracked_run = MethodTracker(mm1.run)
    mm1.run = mm1_tracked_run
    mm2_tracked_run = MethodTracker(mm2.run)
    mm2.run = mm2_tracked_run

    pipe = Pipeline()
    named_components = (
        ("prompt_builder", prompt_builder),
        ("llm", llm),
        ("mm1", mm1),
        ("mm2", mm2),
    )
    for component_name, instance in named_components:
        pipe.add_component(component_name, instance)

    # The prompt builder output feeds both the generator and the first merger.
    connections = (
        ("prompt_builder.prompt", "llm.messages"),
        ("prompt_builder.prompt", "mm1"),
        ("llm.replies", "mm2"),
    )
    for sender, receiver in connections:
        pipe.connect(sender, receiver)

    messages = [
        ChatMessage.from_system("Always respond in English even if some input data is in other languages."),
        ChatMessage.from_user("Tell me about Berlin"),
    ]
    # The very same dict object is deliberately passed to both mergers.
    params = {"metadata": {"metadata_key": "metadata_value", "meta2": "value2"}}

    run_data = {
        "prompt_builder": {"prompt_source": messages},
        "mm1": params,
        "mm2": params,
    }
    pipe.run(data=run_data)

    for tracked_run in (mm1_tracked_run, mm2_tracked_run):
        assert tracked_run.call_count == 1
        _, received_kwargs = tracked_run.called_with
        assert received_kwargs["metadata"] == params["metadata"]