Skip to content

Commit

Permalink
fix: deepcopy the inputs of components (#6987)
Browse files Browse the repository at this point in the history
* deepcopy inputs

* reno

* simplify test
  • Loading branch information
anakin87 committed Feb 16, 2024
1 parent b645c16 commit 3f85a63
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 1 deletion.
8 changes: 7 additions & 1 deletion haystack/core/pipeline/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import itertools
import logging
from collections import defaultdict
from copy import copy
from copy import copy, deepcopy
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Mapping, Optional, Set, TextIO, Tuple, Type, TypeVar, Union
Expand Down Expand Up @@ -686,6 +686,12 @@ def run(self, word: str):
# never received by this method. It's handled by the `run()` method of the `Pipeline` class
# defined in `haystack/pipeline.py`.
# As of now we're ok with this, but we'll need to merge those two classes at some point.

# deepcopying the inputs prevents the Pipeline run logic from being altered unexpectedly
# when the same input reference is passed to multiple components.
for component_name, component_inputs in data.items():
data[component_name] = {k: deepcopy(v) for k, v in component_inputs.items()}

for component_name, component_inputs in data.items():
if component_name not in self.graph.nodes:
# This is not a component name, it must be the name of one or more input sockets.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
fixes:
- |
Previously, when the same input reference was passed to different components, the Pipeline run logic
behaved unexpectedly. This has been fixed by deep-copying the inputs before passing them to the components.
83 changes: 83 additions & 0 deletions test/core/pipeline/test_same_input_different_components.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
from typing import List, Optional

from haystack import Pipeline
from haystack.dataclasses import ChatMessage
from haystack import component


class MethodTracker:
    """Callable wrapper that records how often a method is called and with which arguments.

    State updated on every call:
      - ``call_count``: number of invocations so far (starts at 0).
      - ``called_with``: ``(args, kwargs)`` of the most recent invocation,
        or ``None`` as long as the wrapper has never been called.
    """

    def __init__(self, method):
        self.method = method
        self.call_count = 0
        self.called_with = None

    def __call__(self, *args, **kwargs):
        # Bookkeeping happens before delegating, so a call is still recorded
        # when the wrapped method raises.
        self.call_count += 1
        self.called_with = (args, kwargs)
        return self.method(*args, **kwargs)


@component
class PassThroughPromptBuilder:
    """Pass-through component: returns the received messages unchanged under the ``prompt`` key."""

    @component.output_types(prompt=List[ChatMessage])
    def run(self, prompt_source: List[ChatMessage]):
        output = {"prompt": prompt_source}
        return output


@component
class MessageMerger:
    """Component that joins the content of the incoming messages into one newline-separated string."""

    @component.output_types(merged_message=str)
    def run(self, messages: List[ChatMessage], metadata: Optional[dict] = None):
        """
        Merge the ``content`` of every message in ``messages``.

        :param messages: Messages whose ``content`` values are joined with ``"\\n"``.
        :param metadata: Accepted as an input but not used when building the output.
        :returns: A dict with the joined text under ``merged_message``.
        """
        return {"merged_message": "\n".join(message.content for message in messages)}


@component
class FakeGenerator:
    """Fake generator component that always answers with the same assistant message."""

    @component.output_types(replies=List[ChatMessage])
    def run(self, messages: List[ChatMessage]):
        # The incoming messages are ignored; the reply is fixed.
        reply = ChatMessage.from_assistant("Fake message")
        return {"replies": [reply]}


def test_same_input_different_components():
    """
    Check that handing the same input reference to several components
    does not disturb the Pipeline run logic: each tracked component must
    run exactly once and receive the expected metadata.
    """
    mm1 = MessageMerger()
    mm2 = MessageMerger()

    # Replace each merger's `run` with a tracker so its calls can be inspected after the run.
    mm1_tracked_run = MethodTracker(mm1.run)
    mm2_tracked_run = MethodTracker(mm2.run)
    mm1.run = mm1_tracked_run
    mm2.run = mm2_tracked_run

    pipe = Pipeline()
    components = {
        "prompt_builder": PassThroughPromptBuilder(),
        "llm": FakeGenerator(),
        "mm1": mm1,
        "mm2": mm2,
    }
    for name, instance in components.items():
        pipe.add_component(name, instance)

    connections = (
        ("prompt_builder.prompt", "llm.messages"),
        ("prompt_builder.prompt", "mm1"),
        ("llm.replies", "mm2"),
    )
    for sender, receiver in connections:
        pipe.connect(sender, receiver)

    messages = [
        ChatMessage.from_system("Always respond in English even if some input data is in other languages."),
        ChatMessage.from_user("Tell me about Berlin"),
    ]
    # The very same dict object is deliberately handed to both mergers.
    params = {"metadata": {"metadata_key": "metadata_value", "meta2": "value2"}}
    run_data = {"prompt_builder": {"prompt_source": messages}, "mm1": params, "mm2": params}

    pipe.run(data=run_data)

    for tracked_run in (mm1_tracked_run, mm2_tracked_run):
        assert tracked_run.call_count == 1
        assert tracked_run.called_with[1]["metadata"] == params["metadata"]

0 comments on commit 3f85a63

Please sign in to comment.