Skip to content

Commit

Permalink
feat: Enhance Pipeline.draw() to show image directly in Jupyter not…
Browse files Browse the repository at this point in the history
…ebook (#6961)

* Enhance Pipeline.draw() to show image directly in Jupyter notebook

* Add util method to check if we're in a Jupyter notebook

* Split Pipeline.draw() in two methods

* Update tests

* Update releasenotes
  • Loading branch information
silvanocerza committed Feb 9, 2024
1 parent d2497d5 commit a7f36fd
Show file tree
Hide file tree
Showing 16 changed files with 195 additions and 252 deletions.
Original file line number Diff line number Diff line change
@@ -1,17 +1,50 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0
import logging
import base64
import logging

import requests
import networkx # type:ignore
import requests

from haystack.core.errors import PipelineDrawingError
from haystack.core.pipeline.descriptions import find_pipeline_inputs, find_pipeline_outputs
from haystack.core.type_utils import _type_name

logger = logging.getLogger(__name__)


def _prepare_for_drawing(graph: networkx.MultiDiGraph) -> networkx.MultiDiGraph:
    """
    Decorate `graph` in place with what is needed to draw it, and return it.

    Every edge gets a `label` attribute built from the names of its `from_socket` and `to_socket`,
    with ` (opt.)` appended when the edge is not `mandatory`. Two fake nodes, `input` and `output`,
    are then added and connected to the components reported by `find_pipeline_inputs` and
    `find_pipeline_outputs`.

    The function is idempotent: a graph that was already prepared is returned untouched.
    Both `_to_mermaid_image` and `_to_mermaid_text` prepare the graph they receive, and the former
    hands its prepared graph to the latter. Without this guard the second pass would reach the fake
    edges, which carry no `from_socket`/`to_socket`/`mandatory` data, and fail with a `KeyError`.

    The graph is modified in place, so callers that want to keep the original untouched must pass
    a copy. The returned object is the same `graph` that was passed in.
    """
    # Graph-level attribute used to remember that this graph was already prepared.
    # It is stored on the graph itself so that it travels with `graph.copy()`.
    prepared_marker = "_prepared_for_drawing"
    if graph.graph.get(prepared_marker):
        return graph

    # Label the edges. `data` is the edge's own attribute dict, so assigning to it is enough:
    # there is no need to add the edge again while iterating over the edges.
    for _, _, data in graph.edges(data=True):
        optional_suffix = "" if data["mandatory"] else " (opt.)"
        data["label"] = f"{data['from_socket'].name} -> {data['to_socket'].name}{optional_suffix}"

    # Add inputs fake node
    graph.add_node("input")
    for node, in_sockets in find_pipeline_inputs(graph).items():
        for in_socket in in_sockets:
            if not in_socket.senders and in_socket.is_mandatory:
                # If this socket has no sender it could be a socket that receives input
                # directly when running the Pipeline. We can't know that for sure, in doubt
                # we draw it as receiving input directly.
                graph.add_edge("input", node, label=in_socket.name, conn_type=_type_name(in_socket.type))

    # Add outputs fake node
    graph.add_node("output")
    for node, out_sockets in find_pipeline_outputs(graph).items():
        for out_socket in out_sockets:
            graph.add_edge(node, "output", label=out_socket.name, conn_type=_type_name(out_socket.type))

    graph.graph[prepared_marker] = True
    return graph


ARROWTAIL_MANDATORY = "--"
ARROWTAIL_OPTIONAL = "-."
ARROWHEAD_MANDATORY = "-->"
Expand All @@ -31,6 +64,8 @@ def _to_mermaid_image(graph: networkx.MultiDiGraph):
"""
Renders a pipeline using Mermaid (hosted version at 'https://mermaid.ink'). Requires Internet access.
"""
# Copy the graph to avoid modifying the original
graph = _prepare_for_drawing(graph.copy())
graph_styled = _to_mermaid_text(graph=graph)

graphbytes = graph_styled.encode("ascii")
Expand Down Expand Up @@ -63,6 +98,8 @@ def _to_mermaid_text(graph: networkx.MultiDiGraph) -> str:
Converts a Networkx graph into Mermaid syntax. The output of this function can be used in the documentation
with `mermaid` codeblocks and it will be automatically rendered.
"""
# Copy the graph to avoid modifying the original
graph = _prepare_for_drawing(graph.copy())
sockets = {
comp: "".join(
[
Expand Down
3 changes: 0 additions & 3 deletions haystack/core/pipeline/draw/__init__.py

This file was deleted.

100 changes: 0 additions & 100 deletions haystack/core/pipeline/draw/draw.py

This file was deleted.

41 changes: 0 additions & 41 deletions haystack/core/pipeline/draw/graphviz.py

This file was deleted.

45 changes: 29 additions & 16 deletions haystack/core/pipeline/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,19 @@
import networkx # type:ignore

from haystack.core.component import Component, InputSocket, OutputSocket, component
from haystack.core.errors import PipelineConnectError, PipelineError, PipelineRuntimeError, PipelineValidationError
from haystack.core.pipeline.descriptions import find_pipeline_inputs, find_pipeline_outputs
from haystack.core.pipeline.draw.draw import RenderingEngines, _draw
from haystack.core.errors import (
PipelineConnectError,
PipelineDrawingError,
PipelineError,
PipelineRuntimeError,
PipelineValidationError,
)
from haystack.core.serialization import component_from_dict, component_to_dict
from haystack.core.type_utils import _type_name, _types_are_compatible
from haystack.utils import is_in_jupyter

from .descriptions import find_pipeline_inputs, find_pipeline_outputs
from .draw import _to_mermaid_image

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -441,24 +449,29 @@ def outputs(self) -> Dict[str, Dict[str, Any]]:
}
return outputs

def draw(self, path: Path, engine: RenderingEngines = "mermaid-image") -> None:
def show(self) -> None:
"""
Draws the pipeline. Requires either `graphviz` as a system dependency, or an internet connection for Mermaid.
Run `pip install graphviz` or `pip install mermaid` to install missing dependencies.
If running in a Jupyter notebook, display an image representing this `Pipeline`.
Args:
path: where to save the diagram.
engine: which format to save the graph as. Accepts 'graphviz', 'mermaid-text', 'mermaid-image'.
Default is 'mermaid-image'.
"""
if is_in_jupyter():
from IPython.display import Image, display

Returns:
None
image_data = _to_mermaid_image(self.graph)

Raises:
ImportError: if `engine='graphviz'` and `pygraphviz` is not installed.
HTTPConnectionError: (and similar) if the internet connection is down or other connection issues.
display(Image(image_data))
else:
msg = "This method is only supported in Jupyter notebooks. Use Pipeline.draw() to save an image locally."
raise PipelineDrawingError(msg)

def draw(self, path: Path) -> None:
"""
Save an image representing this `Pipeline` to `path`.
"""
_draw(graph=networkx.MultiDiGraph(self.graph), path=path, engine=engine)
# The graph is edited a bit before being drawn. `_to_mermaid_image` works on a copy of it,
# so the original graph that is used for running the pipeline is not modified.
image_data = _to_mermaid_image(self.graph)
Path(path).write_bytes(image_data)

def warm_up(self):
"""
Expand Down
24 changes: 19 additions & 5 deletions haystack/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,19 @@
from haystack.utils.expit import expit
from haystack.utils.requests_utils import request_with_retry
from haystack.utils.filters import document_matches_filter
from haystack.utils.device import ComponentDevice, DeviceType, Device, DeviceMap
from haystack.utils.auth import Secret, deserialize_secrets_inplace
from .auth import Secret, deserialize_secrets_inplace
from .device import ComponentDevice, Device, DeviceMap, DeviceType
from .expit import expit
from .filters import document_matches_filter
from .jupyter import is_in_jupyter
from .requests_utils import request_with_retry

__all__ = [
"Secret",
"deserialize_secrets_inplace",
"ComponentDevice",
"Device",
"DeviceMap",
"DeviceType",
"expit",
"document_matches_filter",
"is_in_jupyter",
"request_with_retry",
]
25 changes: 25 additions & 0 deletions haystack/utils/jupyter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0


def is_in_jupyter() -> bool:
    """
    Tell whether the code is running inside a Jupyter or Google Colab environment.

    Inspired by:
    https://github.com/explosion/spaCy/blob/e1249d3722765aaca56f538e830add7014d20e2a/spacy/util.py#L1079

    Returns True if in Jupyter or Google Colab, False otherwise.
    """
    try:
        # `get_ipython` is deliberately not imported: it is expected to be already available
        # inside notebooks, and a NameError tells us we are somewhere else.
        shell_class = get_ipython().__class__  # type: ignore[name-defined]
    except NameError:
        # Probably standard Python interpreter
        return False

    runs_in_jupyter = shell_class.__name__ == "ZMQInteractiveShell"  # Jupyter notebook or qtconsole
    runs_in_colab = shell_class.__module__ == "google.colab._shell"  # Colab notebook
    return runs_in_jupyter or runs_in_colab
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
---
enhancements:
- |
Add new `Pipeline.show()` method to generate an image inline if run in a Jupyter notebook.
If called outside a notebook it will raise a `PipelineDrawingError`.
`Pipeline.draw()` has also been simplified and the `engine` argument has been removed.
Now all images will be generated using Mermaid.
16 changes: 1 addition & 15 deletions test/core/conftest.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,9 @@
from pathlib import Path
from unittest.mock import MagicMock, patch

import pytest

from unittest.mock import patch, MagicMock


@pytest.fixture
def test_files():
    """Return the path of the `test_files` directory located next to this module."""
    this_directory = Path(__file__).parent
    return this_directory / "test_files"


@pytest.fixture(autouse=True)
def mock_mermaid_request(test_files):
    """
    Prevents real requests to https://mermaid.ink/

    While the fixture is active, `requests.get` as looked up through
    `haystack.core.pipeline.draw.mermaid` is patched to return a response with status code 200
    whose content is the PNG stored in `test_files / "mermaid_mock" / "test_response.png"`.
    """
    with patch("haystack.core.pipeline.draw.mermaid.requests.get") as mock_get:
        mock_response = MagicMock()
        mock_response.status_code = 200
        # `read_bytes` opens and closes the file itself, so no file handle is left open
        # the way a bare `open(...).read()` would leave it.
        mock_response.content = (test_files / "mermaid_mock" / "test_response.png").read_bytes()
        mock_get.return_value = mock_response
        yield
6 changes: 2 additions & 4 deletions test/core/pipeline/test_default_value.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0
import logging
from pathlib import Path

from haystack.core.component import component
from haystack.core.pipeline import Pipeline

import logging

logging.basicConfig(level=logging.DEBUG)


Expand All @@ -18,10 +17,9 @@ def run(self, a: int, b: int = 2):
return {"c": a + b}


def test_pipeline(tmp_path):
def test_pipeline():
pipeline = Pipeline()
pipeline.add_component("with_defaults", WithDefault())
pipeline.draw(tmp_path / "default_value.png")

# Pass all the inputs
results = pipeline.run({"with_defaults": {"a": 40, "b": 30}})
Expand Down
Loading

0 comments on commit a7f36fd

Please sign in to comment.