Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Enhance Pipeline.draw() to show image directly in Jupyter notebook #6961

Merged
merged 5 commits into from
Feb 9, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Enhance Pipeline.draw() to show image directly in Jupyter notebook
  • Loading branch information
silvanocerza committed Feb 8, 2024
commit 1c64b80f321c80ca73cab025c78db652119e2f51
Original file line number Diff line number Diff line change
@@ -1,17 +1,85 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0
import logging
import base64
import logging
from pathlib import Path
from typing import Optional

import requests
import networkx # type:ignore
import requests

from haystack.core.errors import PipelineDrawingError
from haystack.core.pipeline.descriptions import find_pipeline_inputs, find_pipeline_outputs
from haystack.core.type_utils import _type_name

logger = logging.getLogger(__name__)


def _draw(graph: networkx.MultiDiGraph, path: Optional[Path] = None) -> None:
    """
    Draw a pipeline graph using Mermaid and save it to a file.
    If on a Jupyter notebook, it will also display the image inline.

    The graph is passed through `_prepare_for_drawing`, which modifies it in place,
    so callers that need the original untouched should pass a copy.

    :param graph: the pipeline graph to draw.
    :param path: where to save the image. Optional when running in a Jupyter notebook,
        required otherwise.
    :raises ValueError: if `path` is not given and the code is not running in a notebook.
    """
    # Detect the environment first. Only the detection is guarded, so that errors
    # raised later while rendering or displaying are not silently swallowed.
    in_notebook = False
    try:
        from IPython.core.getipython import get_ipython
        from IPython.display import Image, display

        # Outside of an IPython session `get_ipython()` returns None, so the
        # attribute access raises AttributeError and we stay "not in a notebook".
        in_notebook = "IPKernelApp" in get_ipython().config
    except (ImportError, AttributeError):
        pass

    if not in_notebook and not path:
        # We're not in a notebook and no path is given, the user must have forgotten
        # to specify the path. Fail before doing any rendering work.
        msg = "No path specified to save the image to."
        raise ValueError(msg)

    image_data = _to_mermaid_image(_prepare_for_drawing(graph))

    if in_notebook:
        # We're in a notebook, let's display the image
        display(Image(image_data))

    if path:
        # A path was given: save the image whether or not it has also been
        # displayed in a notebook.
        Path(path).write_bytes(image_data)
silvanocerza marked this conversation as resolved.
Show resolved Hide resolved


def _prepare_for_drawing(graph: networkx.MultiDiGraph) -> networkx.MultiDiGraph:
    """
    Decorate the pipeline graph so it can be rendered.

    Every existing edge gets a `label` attribute built from its sockets, and two
    fake nodes, `input` and `output`, are added and connected to the components
    that take pipeline inputs or produce pipeline outputs.

    The graph is modified in place and the same object is returned.
    """
    # Label the edges as "<sender socket> -> <receiver socket>", flagging optional ones
    for sender, receiver, edge_key, attrs in graph.edges(keys=True, data=True):
        sender_socket_name = attrs["from_socket"].name
        receiver_socket_name = attrs["to_socket"].name
        optional_marker = "" if attrs["mandatory"] else " (opt.)"
        attrs["label"] = f"{sender_socket_name} -> {receiver_socket_name}{optional_marker}"
        graph.add_edge(sender, receiver, key=edge_key, **attrs)

    # Fake node standing for the inputs given to the pipeline at run time
    graph.add_node("input")
    for component_name, input_sockets in find_pipeline_inputs(graph).items():
        for socket in input_sockets:
            if socket.senders or not socket.is_mandatory:
                continue
            # A mandatory socket without senders could be one that receives its input
            # directly when running the Pipeline. We can't know that for sure, so in
            # doubt it is drawn as receiving input directly.
            graph.add_edge("input", component_name, label=socket.name, conn_type=_type_name(socket.type))

    # Fake node standing for the outputs returned by the pipeline
    graph.add_node("output")
    for component_name, output_sockets in find_pipeline_outputs(graph).items():
        for socket in output_sockets:
            graph.add_edge(component_name, "output", label=socket.name, conn_type=_type_name(socket.type))

    return graph


ARROWTAIL_MANDATORY = "--"
ARROWTAIL_OPTIONAL = "-."
ARROWHEAD_MANDATORY = "-->"
Expand Down
3 changes: 0 additions & 3 deletions haystack/core/pipeline/draw/__init__.py

This file was deleted.

100 changes: 0 additions & 100 deletions haystack/core/pipeline/draw/draw.py

This file was deleted.

41 changes: 0 additions & 41 deletions haystack/core/pipeline/draw/graphviz.py

This file was deleted.

28 changes: 11 additions & 17 deletions haystack/core/pipeline/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,12 @@

from haystack.core.component import Component, InputSocket, OutputSocket, component
from haystack.core.errors import PipelineConnectError, PipelineError, PipelineRuntimeError, PipelineValidationError
from haystack.core.pipeline.descriptions import find_pipeline_inputs, find_pipeline_outputs
from haystack.core.pipeline.draw.draw import RenderingEngines, _draw
from haystack.core.serialization import component_from_dict, component_to_dict
from haystack.core.type_utils import _type_name, _types_are_compatible

from .descriptions import find_pipeline_inputs, find_pipeline_outputs
from .draw import _draw

logger = logging.getLogger(__name__)

# We use a generic type to annotate the return value of classmethods,
Expand Down Expand Up @@ -441,24 +442,17 @@ def outputs(self) -> Dict[str, Dict[str, Any]]:
}
return outputs

def draw(self, path: Path, engine: RenderingEngines = "mermaid-image") -> None:
def draw(self, path: Optional[Path] = None) -> None:
"""
Draws the pipeline. Requires either `graphviz` as a system dependency, or an internet connection for Mermaid.
Run `pip install graphviz` or `pip install mermaid` to install missing dependencies.

Args:
path: where to save the diagram.
engine: which format to save the graph as. Accepts 'graphviz', 'mermaid-text', 'mermaid-image'.
Default is 'mermaid-image'.
Save a Pipeline image to `path`.
If `path` is `None` the image will be displayed inline in the Jupyter notebook.
If `path` is `None` and the code is not running in a Jupyter notebook, an error will be raised.

Returns:
None

Raises:
ImportError: if `engine='graphviz'` and `pygraphviz` is not installed.
HTTPConnectionError: (and similar) if the internet connection is down or other connection issues.
If `path` is given it will always be saved to file, whether it's a notebook or not.
"""
_draw(graph=networkx.MultiDiGraph(self.graph), path=path, engine=engine)
# Drawing adds extra nodes and edge labels to the graph, so we work on a copy
# to avoid modifying the original graph that is used for running the pipeline.
_draw(graph=self.graph.copy(), path=path)

def warm_up(self):
"""
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
---
enhancements:
- |
`Pipeline.draw()` will now show the generated image inline if run in a Jupyter notebook.
In a notebook it can also be called without a path, in which case the image is only displayed.
Outside a notebook, calling it without a path raises a `ValueError`.
5 changes: 2 additions & 3 deletions test/core/conftest.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
from pathlib import Path
from unittest.mock import MagicMock, patch

import pytest

from unittest.mock import patch, MagicMock


@pytest.fixture
def test_files():
Expand All @@ -15,7 +14,7 @@ def mock_mermaid_request(test_files):
"""
Prevents real requests to https://mermaid.ink/
"""
with patch("haystack.core.pipeline.draw.mermaid.requests.get") as mock_get:
with patch("haystack.core.pipeline.draw.requests.get") as mock_get:
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.content = open(test_files / "mermaid_mock" / "test_response.png", "rb").read()
Expand Down
Loading