Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: add option to not override results by Shaper #4231

Merged
merged 4 commits into from
Feb 22, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
support list
  • Loading branch information
tstadel committed Feb 22, 2023
commit 5b3f5792bb97dd0f2c46206b803419eed91379e6
16 changes: 10 additions & 6 deletions haystack/nodes/other/shaper.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ def __init__(
outputs: List[str],
inputs: Optional[Dict[str, Union[List[str], str]]] = None,
params: Optional[Dict[str, Any]] = None,
publish_outputs: bool = True,
publish_outputs: Union[bool, List[str]] = True,
):
"""
Initializes the Shaper component.
Expand Down Expand Up @@ -320,9 +320,10 @@ def __init__(
You can use params to provide fallback values for arguments of `run` that you're not sure exist.
So if you need `query` to exist, you can provide a fallback value in the params, which will be used only if `query`
is not passed to this node by the pipeline.
:param outputs: THe key to store the outputs in the invocation context. The length of the outputs must match
:param outputs: The key to store the outputs in the invocation context. The length of the outputs must match
the number of outputs produced by the function invoked.
:param publish_outputs: Whether to publish the outputs to the pipeline's output. Defaults to True.
:param publish_outputs: Controls whether to publish the outputs to the pipeline's output.
Set `True` (default value) to publishes all outputs or `False` to publish None.
E.g. if `outputs = ["documents"]` result for `publish_outputs = True` looks like
```python
{
Expand All @@ -340,14 +341,17 @@ def __init__(
},
}
```
Note that only outputs ["query", "file_paths", "labels", "documents", "meta", "answers"] can be published.
If you want to have finer-grained control, pass a list of the outputs you want to publish.
"""
super().__init__()
self.function = REGISTERED_FUNCTIONS[func]
self.outputs = outputs
self.publish_outputs = publish_outputs
self.inputs = inputs or {}
self.params = params or {}
if isinstance(publish_outputs, bool):
self.publish_outputs = self.outputs if publish_outputs else []
else:
self.publish_outputs = publish_outputs

def run( # type: ignore
self,
Expand Down Expand Up @@ -425,7 +429,7 @@ def run( # type: ignore
results = {}
for output_key, output_value in zip(self.outputs, output_values):
invocation_context[output_key] = output_value
if self.publish_outputs:
if output_key in self.publish_outputs:
results[output_key] = output_value
results["invocation_context"] = invocation_context

Expand Down
15 changes: 15 additions & 0 deletions test/nodes/test_shaper.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,21 @@ def test_join_documents_without_publish_outputs():
assert "documents" not in results


def test_join_documents_with_publish_outputs_as_list():
shaper = Shaper(
func="join_documents",
inputs={"documents": "documents"},
params={"delimiter": " | "},
outputs=["documents"],
publish_outputs=["documents"],
)
results, _ = shaper.run(
documents=[Document(content="first"), Document(content="second"), Document(content="third")]
)
assert results["invocation_context"]["documents"] == [Document(content="first | second | third")]
assert results["documents"] == [Document(content="first | second | third")]


def test_join_documents_default_delimiter():
shaper = Shaper(func="join_documents", inputs={"documents": "documents"}, outputs=["documents"])
results, _ = shaper.run(
Expand Down