Skip to content

Commit

Permalink
fix: add option to not override results by Shaper (#4231)
Browse files Browse the repository at this point in the history
* add  option to shaper and support answers

* remove publish restrictions on outputs

* support list
  • Loading branch information
tstadel authored and vblagoje committed Feb 22, 2023
1 parent ed29a29 commit 27ece02
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 2 deletions.
29 changes: 27 additions & 2 deletions haystack/nodes/other/shaper.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,7 @@ def __init__(
outputs: List[str],
inputs: Optional[Dict[str, Union[List[str], str]]] = None,
params: Optional[Dict[str, Any]] = None,
publish_outputs: Union[bool, List[str]] = True,
):
"""
Initializes the Shaper component.
Expand Down Expand Up @@ -319,14 +320,38 @@ def __init__(
You can use params to provide fallback values for arguments of `run` that you're not sure exist.
So if you need `query` to exist, you can provide a fallback value in the params, which will be used only if `query`
is not passed to this node by the pipeline.
:param outputs: THe key to store the outputs in the invocation context. The length of the outputs must match
:param outputs: The key to store the outputs in the invocation context. The length of the outputs must match
the number of outputs produced by the function invoked.
:param publish_outputs: Controls whether to publish the outputs to the pipeline's output.
Set `True` (default value) to publishes all outputs or `False` to publish None.
E.g. if `outputs = ["documents"]` result for `publish_outputs = True` looks like
```python
{
"invocation_context": {
"documents": [...]
},
"documents": [...]
}
```
For `publish_outputs = False` result looks like
```python
{
"invocation_context": {
"documents": [...]
},
}
```
If you want to have finer-grained control, pass a list of the outputs you want to publish.
"""
super().__init__()
self.function = REGISTERED_FUNCTIONS[func]
self.outputs = outputs
self.inputs = inputs or {}
self.params = params or {}
if isinstance(publish_outputs, bool):
self.publish_outputs = self.outputs if publish_outputs else []
else:
self.publish_outputs = publish_outputs

def run( # type: ignore
self,
Expand Down Expand Up @@ -404,7 +429,7 @@ def run( # type: ignore
results = {}
for output_key, output_value in zip(self.outputs, output_values):
invocation_context[output_key] = output_value
if output_key in ["query", "file_paths", "labels", "documents", "meta"]:
if output_key in self.publish_outputs:
results[output_key] = output_value
results["invocation_context"] = invocation_context

Expand Down
54 changes: 54 additions & 0 deletions test/nodes/test_shaper.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@

import haystack
from haystack import Pipeline, Document, Answer
from haystack.document_stores.memory import InMemoryDocumentStore
from haystack.nodes.other.shaper import Shaper
from haystack.nodes.retriever.sparse import BM25Retriever


@pytest.fixture
Expand Down Expand Up @@ -340,6 +342,37 @@ def test_join_documents():
documents=[Document(content="first"), Document(content="second"), Document(content="third")]
)
assert results["invocation_context"]["documents"] == [Document(content="first | second | third")]
assert results["documents"] == [Document(content="first | second | third")]


def test_join_documents_without_publish_outputs():
shaper = Shaper(
func="join_documents",
inputs={"documents": "documents"},
params={"delimiter": " | "},
outputs=["documents"],
publish_outputs=False,
)
results, _ = shaper.run(
documents=[Document(content="first"), Document(content="second"), Document(content="third")]
)
assert results["invocation_context"]["documents"] == [Document(content="first | second | third")]
assert "documents" not in results


def test_join_documents_with_publish_outputs_as_list():
shaper = Shaper(
func="join_documents",
inputs={"documents": "documents"},
params={"delimiter": " | "},
outputs=["documents"],
publish_outputs=["documents"],
)
results, _ = shaper.run(
documents=[Document(content="first"), Document(content="second"), Document(content="third")]
)
assert results["invocation_context"]["documents"] == [Document(content="first | second | third")]
assert results["documents"] == [Document(content="first | second | third")]


def test_join_documents_default_delimiter():
Expand Down Expand Up @@ -457,6 +490,11 @@ def test_strings_to_answers_yaml(tmp_path):
Answer(answer="b", type="generative"),
Answer(answer="c", type="generative"),
]
assert result["answers"] == [
Answer(answer="a", type="generative"),
Answer(answer="b", type="generative"),
Answer(answer="c", type="generative"),
]


#
Expand Down Expand Up @@ -1116,3 +1154,19 @@ def test_join_query_and_documents_convert_into_documents_yaml(tmp_path):
assert result["invocation_context"]["query_and_docs"]
assert len(result["invocation_context"]["query_and_docs"]) == 4
assert isinstance(result["invocation_context"]["query_and_docs"][0], Document)


def test_shaper_publishes_unknown_arg_does_not_break_pipeline():
documents = [Document(content="test query")]
shaper = Shaper(func="rename", inputs={"value": "query"}, outputs=["unknown_by_retriever"], publish_outputs=True)
document_store = InMemoryDocumentStore(use_bm25=True)
document_store.write_documents(documents)
retriever = BM25Retriever(document_store=document_store)
pipeline = Pipeline()
pipeline.add_node(component=shaper, name="shaper", inputs=["Query"])
pipeline.add_node(component=retriever, name="retriever", inputs=["shaper"])

result = pipeline.run(query="test query")
assert result["invocation_context"]["unknown_by_retriever"] == "test query"
assert result["unknown_by_retriever"] == "test query"
assert len(result["documents"]) == 1

0 comments on commit 27ece02

Please sign in to comment.