Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Enhance Pipeline.draw() to show image directly in Jupyter notebook #6961

Merged
merged 5 commits into from
Feb 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,17 +1,50 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0
import logging
import base64
import logging

import requests
import networkx # type:ignore
import requests

from haystack.core.errors import PipelineDrawingError
from haystack.core.pipeline.descriptions import find_pipeline_inputs, find_pipeline_outputs
from haystack.core.type_utils import _type_name

logger = logging.getLogger(__name__)


def _prepare_for_drawing(graph: networkx.MultiDiGraph) -> networkx.MultiDiGraph:
    """
    Augment `graph` for rendering: label every edge and attach fake
    "input"/"output" nodes showing where data enters and leaves the pipeline.
    """
    # Attach a human-readable label to each edge, marking optional connections.
    for sender, receiver, key, data in graph.edges(keys=True, data=True):
        suffix = "" if data["mandatory"] else " (opt.)"
        data["label"] = f"{data['from_socket'].name} -> {data['to_socket'].name}{suffix}"
        graph.add_edge(sender, receiver, key=key, **data)

    # Fake "input" node: connect it to every mandatory socket that has no sender.
    graph.add_node("input")
    for node, in_sockets in find_pipeline_inputs(graph).items():
        for in_socket in in_sockets:
            if not in_socket.senders and in_socket.is_mandatory:
                # A socket with no sender may receive its value directly when
                # running the Pipeline. We can't know that for sure, so in
                # doubt we draw it as receiving input directly.
                graph.add_edge("input", node, label=in_socket.name, conn_type=_type_name(in_socket.type))

    # Fake "output" node: connect every pipeline output socket to it.
    graph.add_node("output")
    for node, out_sockets in find_pipeline_outputs(graph).items():
        for out_socket in out_sockets:
            graph.add_edge(node, "output", label=out_socket.name, conn_type=_type_name(out_socket.type))

    return graph


ARROWTAIL_MANDATORY = "--"
ARROWTAIL_OPTIONAL = "-."
ARROWHEAD_MANDATORY = "-->"
Expand All @@ -31,6 +64,8 @@ def _to_mermaid_image(graph: networkx.MultiDiGraph):
"""
Renders a pipeline using Mermaid (hosted version at 'https://mermaid.ink'). Requires Internet access.
"""
# Copy the graph to avoid modifying the original
graph = _prepare_for_drawing(graph.copy())
graph_styled = _to_mermaid_text(graph=graph)

graphbytes = graph_styled.encode("ascii")
Expand Down Expand Up @@ -63,6 +98,8 @@ def _to_mermaid_text(graph: networkx.MultiDiGraph) -> str:
Converts a Networkx graph into Mermaid syntax. The output of this function can be used in the documentation
with `mermaid` codeblocks and it will be automatically rendered.
"""
# Copy the graph to avoid modifying the original
graph = _prepare_for_drawing(graph.copy())
sockets = {
comp: "".join(
[
Expand Down
3 changes: 0 additions & 3 deletions haystack/core/pipeline/draw/__init__.py

This file was deleted.

100 changes: 0 additions & 100 deletions haystack/core/pipeline/draw/draw.py

This file was deleted.

41 changes: 0 additions & 41 deletions haystack/core/pipeline/draw/graphviz.py

This file was deleted.

45 changes: 29 additions & 16 deletions haystack/core/pipeline/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,19 @@
import networkx # type:ignore

from haystack.core.component import Component, InputSocket, OutputSocket, component
from haystack.core.errors import PipelineConnectError, PipelineError, PipelineRuntimeError, PipelineValidationError
from haystack.core.pipeline.descriptions import find_pipeline_inputs, find_pipeline_outputs
from haystack.core.pipeline.draw.draw import RenderingEngines, _draw
from haystack.core.errors import (
PipelineConnectError,
PipelineDrawingError,
PipelineError,
PipelineRuntimeError,
PipelineValidationError,
)
from haystack.core.serialization import component_from_dict, component_to_dict
from haystack.core.type_utils import _type_name, _types_are_compatible
from haystack.utils import is_in_jupyter

from .descriptions import find_pipeline_inputs, find_pipeline_outputs
from .draw import _to_mermaid_image

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -441,24 +449,29 @@ def outputs(self) -> Dict[str, Dict[str, Any]]:
}
return outputs

def show(self) -> None:
    """
    If running in a Jupyter notebook, display an image representing this `Pipeline` inline.

    Raises:
        PipelineDrawingError: when called outside a Jupyter notebook.
    """
    if not is_in_jupyter():
        msg = "This method is only supported in Jupyter notebooks. Use Pipeline.draw() to save an image locally."
        raise PipelineDrawingError(msg)

    # Imported lazily: IPython is only guaranteed to be present inside a notebook.
    from IPython.display import Image, display

    image_data = _to_mermaid_image(self.graph)
    display(Image(image_data))

def draw(self, path: Path) -> None:
    """
    Save an image representing this `Pipeline` to `path`.

    Args:
        path: destination file for the rendered image.
    """
    # The rendering helper copies the graph before editing it, so the
    # original graph used for running the pipeline is left untouched.
    image_data = _to_mermaid_image(self.graph)
    Path(path).write_bytes(image_data)

def warm_up(self):
"""
Expand Down
24 changes: 19 additions & 5 deletions haystack/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,19 @@
from haystack.utils.expit import expit
from haystack.utils.requests_utils import request_with_retry
from haystack.utils.filters import document_matches_filter
from haystack.utils.device import ComponentDevice, DeviceType, Device, DeviceMap
from haystack.utils.auth import Secret, deserialize_secrets_inplace
from .auth import Secret, deserialize_secrets_inplace
from .device import ComponentDevice, Device, DeviceMap, DeviceType
from .expit import expit
from .filters import document_matches_filter
from .jupyter import is_in_jupyter
from .requests_utils import request_with_retry

__all__ = [
"Secret",
"deserialize_secrets_inplace",
"ComponentDevice",
"Device",
"DeviceMap",
"DeviceType",
"expit",
"document_matches_filter",
"is_in_jupyter",
"request_with_retry",
]
25 changes: 25 additions & 0 deletions haystack/utils/jupyter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0


def is_in_jupyter() -> bool:
    """
    Utility function to easily check if we are in a Jupyter or Google Colab environment.

    Inspired by:
    https://github.com/explosion/spaCy/blob/e1249d3722765aaca56f538e830add7014d20e2a/spacy/util.py#L1079

    Returns True if in Jupyter or Google Colab, False otherwise
    """
    try:
        # `get_ipython` is injected into the global namespace by IPython-based
        # shells, so we can look it up without importing anything.
        shell_class = get_ipython().__class__  # type: ignore[name-defined]
    except NameError:
        # Probably a standard Python interpreter: `get_ipython` doesn't exist.
        return False
    if shell_class.__name__ == "ZMQInteractiveShell":
        return True  # Jupyter notebook or qtconsole
    return shell_class.__module__ == "google.colab._shell"  # Colab notebook
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
---
enhancements:
- |
Added a new `Pipeline.show()` method to display the pipeline image inline when run in a Jupyter notebook.
If called outside a notebook, it will raise a `PipelineDrawingError`.
`Pipeline.draw()` has also been simplified and the `engine` argument has been removed.
Now all images will be generated using Mermaid.
16 changes: 1 addition & 15 deletions test/core/conftest.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,9 @@
from pathlib import Path
from unittest.mock import MagicMock, patch

import pytest

from unittest.mock import patch, MagicMock


@pytest.fixture
def test_files():
    """Return the path of the directory holding the test fixture files."""
    return Path(__file__).parent.joinpath("test_files")


@pytest.fixture(autouse=True)
def mock_mermaid_request(test_files):
"""
Prevents real requests to https://mermaid.ink/
"""
with patch("haystack.core.pipeline.draw.mermaid.requests.get") as mock_get:
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.content = open(test_files / "mermaid_mock" / "test_response.png", "rb").read()
mock_get.return_value = mock_response
yield
6 changes: 2 additions & 4 deletions test/core/pipeline/test_default_value.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0
import logging
from pathlib import Path

from haystack.core.component import component
from haystack.core.pipeline import Pipeline

import logging

logging.basicConfig(level=logging.DEBUG)


Expand All @@ -18,10 +17,9 @@ def run(self, a: int, b: int = 2):
return {"c": a + b}


def test_pipeline(tmp_path):
def test_pipeline():
pipeline = Pipeline()
pipeline.add_component("with_defaults", WithDefault())
pipeline.draw(tmp_path / "default_value.png")

# Pass all the inputs
results = pipeline.run({"with_defaults": {"a": 40, "b": 30}})
Expand Down
Loading