Skip to content

Commit

Permalink
fix: Fix run order of variadic greedy components in Pipeline.run() (#…
Browse files Browse the repository at this point in the history
…7258)

* Fix run order of variadic greedy components in Pipeline.run()

* Add release notes
  • Loading branch information
silvanocerza committed Mar 1, 2024
1 parent 3077a08 commit 72d776c
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 1 deletion.
11 changes: 11 additions & 0 deletions haystack/core/pipeline/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -860,6 +860,17 @@ def run(self, word: str):
last_inputs[receiver_component_name][edge_data["to_socket"].name] = value

pair = (receiver_component_name, self.graph.nodes[receiver_component_name]["instance"])
is_greedy = pair[1].__haystack_is_greedy__
is_variadic = edge_data["to_socket"].is_variadic
if is_variadic and is_greedy:
# If the receiver is greedy, we can run it right away.
# First we remove it from the lists it's in if it's there or we risk running it multiple times.
if pair in to_run:
to_run.remove(pair)
if pair in waiting_for_input:
waiting_for_input.remove(pair)
to_run.append(pair)

if pair not in waiting_for_input and pair not in to_run:
to_run.append(pair)

Expand Down
6 changes: 6 additions & 0 deletions releasenotes/notes/run-greedy-fix-6d4559126e7739ce.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
---
fixes:
- |
Fix `Pipeline.run()` mistakenly running a Component before it should.
This can happen when a greedy variadic Component must be executed before a
Component with default inputs.
46 changes: 45 additions & 1 deletion test/core/pipeline/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,20 @@
#
# SPDX-License-Identifier: Apache-2.0
import logging
from typing import Optional
from typing import List, Optional
from unittest.mock import patch

import pytest

from haystack import Document
from haystack.components.builders import PromptBuilder
from haystack.components.others import Multiplexer
from haystack.components.retrievers.in_memory import InMemoryBM25Retriever
from haystack.core.component import component
from haystack.core.component.types import InputSocket, OutputSocket
from haystack.core.errors import PipelineDrawingError, PipelineError, PipelineMaxLoops, PipelineRuntimeError
from haystack.core.pipeline import Pipeline, PredefinedPipeline
from haystack.document_stores.in_memory import InMemoryDocumentStore
from haystack.testing.factory import component_class
from haystack.testing.sample_components import AddFixedValue, Double

Expand All @@ -27,6 +32,45 @@ def run(self, input_: str):
return {"value": input_}


def test_run_with_greedy_variadic_after_component_with_default_input_simple(spying_tracer):
"""
This test verifies that `Pipeline.run()` executes the components in the correct order when
there's a greedy Component with variadic input right before a Component with at least one default input.
We use the `spying_tracer` fixture to simplify the code to verify the order of execution.
This creates some coupling between this test and how we trace the Pipeline execution.
A worthy tradeoff in my opinion, we will notice right away if we change either the run logic or
the tracing logic.
"""
document_store = InMemoryDocumentStore()
document_store.write_documents([Document(content="This is a simple document")])

pipeline = Pipeline()
template = "Given this documents: {{ documents|join(', ', attribute='content') }} Answer this question: {{ query }}"
pipeline.add_component("retriever", InMemoryBM25Retriever(document_store=document_store))
pipeline.add_component("prompt_builder", PromptBuilder(template=template))
pipeline.add_component("multiplexer", Multiplexer(List[Document]))

pipeline.connect("retriever", "multiplexer")
pipeline.connect("multiplexer", "prompt_builder.documents")
res = pipeline.run({"query": "This is my question"})

assert res == {
"prompt_builder": {
"prompt": "Given this documents: This is a simple document Answer this question: This is my question"
}
}

assert len(spying_tracer.spans) == 4
assert spying_tracer.spans[0].operation_name == "haystack.pipeline.run"
assert spying_tracer.spans[1].operation_name == "haystack.component.run"
assert spying_tracer.spans[1].tags["haystack.component.name"] == "retriever"
assert spying_tracer.spans[2].operation_name == "haystack.component.run"
assert spying_tracer.spans[2].tags["haystack.component.name"] == "multiplexer"
assert spying_tracer.spans[3].operation_name == "haystack.component.run"
assert spying_tracer.spans[3].tags["haystack.component.name"] == "prompt_builder"


def test_pipeline_resolution_simple_input():
@component
class Hello:
Expand Down

0 comments on commit 72d776c

Please sign in to comment.