-
Notifications
You must be signed in to change notification settings - Fork 1.7k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: Enhance
Pipeline.draw()
to show image directly in Jupyter not…
…ebook (#6961) * Enhance Pipeline.draw() to show image directly in Jupyter notebook * Add util method to check if we're in a Jupyter notebook * Split Pipeline.draw() in two methods * Update tests * Update releasenotes
- Loading branch information
1 parent
d2497d5
commit a7f36fd
Showing
16 changed files
with
195 additions
and
252 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,17 +1,50 @@ | ||
# SPDX-FileCopyrightText: 2022-present deepset GmbH <[email protected]> | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 | ||
import logging | ||
import base64 | ||
import logging | ||
|
||
import requests | ||
import networkx # type:ignore | ||
import requests | ||
|
||
from haystack.core.errors import PipelineDrawingError | ||
from haystack.core.pipeline.descriptions import find_pipeline_inputs, find_pipeline_outputs | ||
from haystack.core.type_utils import _type_name | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
def _prepare_for_drawing(graph: networkx.MultiDiGraph) -> networkx.MultiDiGraph: | ||
""" | ||
Add some extra nodes to show the inputs and outputs of the pipeline. | ||
Also adds labels to edges. | ||
""" | ||
# Label the edges | ||
for inp, outp, key, data in graph.edges(keys=True, data=True): | ||
data[ | ||
"label" | ||
] = f"{data['from_socket'].name} -> {data['to_socket'].name}{' (opt.)' if not data['mandatory'] else ''}" | ||
graph.add_edge(inp, outp, key=key, **data) | ||
|
||
# Add inputs fake node | ||
graph.add_node("input") | ||
for node, in_sockets in find_pipeline_inputs(graph).items(): | ||
for in_socket in in_sockets: | ||
if not in_socket.senders and in_socket.is_mandatory: | ||
# If this socket has no sender it could be a socket that receives input | ||
# directly when running the Pipeline. We can't know that for sure, in doubt | ||
# we draw it as receiving input directly. | ||
graph.add_edge("input", node, label=in_socket.name, conn_type=_type_name(in_socket.type)) | ||
|
||
# Add outputs fake node | ||
graph.add_node("output") | ||
for node, out_sockets in find_pipeline_outputs(graph).items(): | ||
for out_socket in out_sockets: | ||
graph.add_edge(node, "output", label=out_socket.name, conn_type=_type_name(out_socket.type)) | ||
|
||
return graph | ||
|
||
|
||
ARROWTAIL_MANDATORY = "--" | ||
ARROWTAIL_OPTIONAL = "-." | ||
ARROWHEAD_MANDATORY = "-->" | ||
|
@@ -31,6 +64,8 @@ def _to_mermaid_image(graph: networkx.MultiDiGraph): | |
""" | ||
Renders a pipeline using Mermaid (hosted version at 'https://mermaid.ink'). Requires Internet access. | ||
""" | ||
# Copy the graph to avoid modifying the original | ||
graph = _prepare_for_drawing(graph.copy()) | ||
graph_styled = _to_mermaid_text(graph=graph) | ||
|
||
graphbytes = graph_styled.encode("ascii") | ||
|
@@ -63,6 +98,8 @@ def _to_mermaid_text(graph: networkx.MultiDiGraph) -> str: | |
Converts a Networkx graph into Mermaid syntax. The output of this function can be used in the documentation | ||
with `mermaid` codeblocks and it will be automatically rendered. | ||
""" | ||
# Copy the graph to avoid modifying the original | ||
graph = _prepare_for_drawing(graph.copy()) | ||
sockets = { | ||
comp: "".join( | ||
[ | ||
|
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,19 @@ | ||
from haystack.utils.expit import expit | ||
from haystack.utils.requests_utils import request_with_retry | ||
from haystack.utils.filters import document_matches_filter | ||
from haystack.utils.device import ComponentDevice, DeviceType, Device, DeviceMap | ||
from haystack.utils.auth import Secret, deserialize_secrets_inplace | ||
from .auth import Secret, deserialize_secrets_inplace | ||
from .device import ComponentDevice, Device, DeviceMap, DeviceType | ||
from .expit import expit | ||
from .filters import document_matches_filter | ||
from .jupyter import is_in_jupyter | ||
from .requests_utils import request_with_retry | ||
|
||
__all__ = [ | ||
"Secret", | ||
"deserialize_secrets_inplace", | ||
"ComponentDevice", | ||
"Device", | ||
"DeviceMap", | ||
"DeviceType", | ||
"expit", | ||
"document_matches_filter", | ||
"is_in_jupyter", | ||
"request_with_retry", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
# SPDX-FileCopyrightText: 2022-present deepset GmbH <[email protected]> | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
|
||
def is_in_jupyter() -> bool: | ||
""" | ||
Utility function to easily check if we are in a Jupyter or Google Colab environment. | ||
Inspired by: | ||
https://github.com/explosion/spaCy/blob/e1249d3722765aaca56f538e830add7014d20e2a/spacy/util.py#L1079 | ||
Returns True if in Jupyter or Google Colab, False otherwise | ||
""" | ||
# | ||
# | ||
try: | ||
# We don't need to import `get_ipython` as it's always present in Jupyter notebooks | ||
if get_ipython().__class__.__name__ == "ZMQInteractiveShell": # type: ignore[name-defined] | ||
return True # Jupyter notebook or qtconsole | ||
if get_ipython().__class__.__module__ == "google.colab._shell": # type: ignore[name-defined] | ||
return True # Colab notebook | ||
except NameError: | ||
pass # Probably standard Python interpreter | ||
return False |
7 changes: 7 additions & 0 deletions
7
releasenotes/notes/enhance-pipeline-draw-5fe3131db71f6f54.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
--- | ||
enhancements: | ||
- | | ||
Add new `Pipeline.show()` method to generated image inline if run in a Jupyter notebook. | ||
If called outside a notebook it will raise a `PipelineDrawingError`. | ||
`Pipeline.draw()` has also been simplified and the `engine` argument has been removed. | ||
Now all images will be generated using Mermaid. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,23 +1,9 @@ | ||
from pathlib import Path | ||
from unittest.mock import MagicMock, patch | ||
|
||
import pytest | ||
|
||
from unittest.mock import patch, MagicMock | ||
|
||
|
||
@pytest.fixture | ||
def test_files(): | ||
return Path(__file__).parent / "test_files" | ||
|
||
|
||
@pytest.fixture(autouse=True) | ||
def mock_mermaid_request(test_files): | ||
""" | ||
Prevents real requests to https://mermaid.ink/ | ||
""" | ||
with patch("haystack.core.pipeline.draw.mermaid.requests.get") as mock_get: | ||
mock_response = MagicMock() | ||
mock_response.status_code = 200 | ||
mock_response.content = open(test_files / "mermaid_mock" / "test_response.png", "rb").read() | ||
mock_get.return_value = mock_response | ||
yield |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,13 +1,12 @@ | ||
# SPDX-FileCopyrightText: 2022-present deepset GmbH <[email protected]> | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 | ||
import logging | ||
from pathlib import Path | ||
|
||
from haystack.core.component import component | ||
from haystack.core.pipeline import Pipeline | ||
|
||
import logging | ||
|
||
logging.basicConfig(level=logging.DEBUG) | ||
|
||
|
||
|
@@ -18,10 +17,9 @@ def run(self, a: int, b: int = 2): | |
return {"c": a + b} | ||
|
||
|
||
def test_pipeline(tmp_path): | ||
def test_pipeline(): | ||
pipeline = Pipeline() | ||
pipeline.add_component("with_defaults", WithDefault()) | ||
pipeline.draw(tmp_path / "default_value.png") | ||
|
||
# Pass all the inputs | ||
results = pipeline.run({"with_defaults": {"a": 40, "b": 30}}) | ||
|
Oops, something went wrong.