Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Enhance Pipeline.draw() to show image directly in Jupyter notebook #6961

Merged
merged 5 commits into from
Feb 9, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Split Pipeline.draw() in two methods
  • Loading branch information
silvanocerza committed Feb 9, 2024
commit b98b4e7199b53d91cf0aa7167c49145994da9f74
39 changes: 4 additions & 35 deletions haystack/core/pipeline/draw.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@
# SPDX-License-Identifier: Apache-2.0
import base64
import logging
from pathlib import Path
from typing import Optional

import networkx # type:ignore
import requests
Expand All @@ -16,39 +14,6 @@
logger = logging.getLogger(__name__)


def _draw(graph: networkx.MultiDiGraph, path: Optional[Path] = None) -> None:
    """
    Draw a pipeline graph using Mermaid and optionally save it to a file.

    If running in a Jupyter notebook, the image is also displayed inline.

    :param graph: The pipeline graph to draw.
    :param path: Where to save the image. May be omitted only when running in a
        Jupyter notebook, in which case the image is just displayed inline.
    :raises ValueError: If `path` is not given and the code is not running in a
        Jupyter notebook.
    """
    # Work out whether we're in a notebook. Only the optional IPython import is
    # guarded, so unrelated errors raised while displaying are not silenced.
    in_notebook = False
    try:
        from IPython.core.getipython import get_ipython
        from IPython.display import Image, display
    except ImportError:
        # IPython is not installed, we can't be in a notebook.
        pass
    else:
        shell = get_ipython()
        # get_ipython() returns None when not running inside an IPython session.
        in_notebook = shell is not None and "IPKernelApp" in getattr(shell, "config", {})

    if not in_notebook and not path:
        # We're not in a notebook and no path is given, the user must have forgotten
        # to specify the path. Fail before rendering, which needs to reach the
        # Mermaid rendering service.
        msg = "No path specified to save the image to."
        raise ValueError(msg)

    image_data = _to_mermaid_image(_prepare_for_drawing(graph))

    if in_notebook:
        # We're in a notebook, let's display the image
        display(Image(image_data))

    if path:
        # A path has been given: save the image, whether or not it has also been
        # displayed in a notebook.
        Path(path).write_bytes(image_data)


def _prepare_for_drawing(graph: networkx.MultiDiGraph) -> networkx.MultiDiGraph:
"""
Add some extra nodes to show the inputs and outputs of the pipeline.
Expand Down Expand Up @@ -99,6 +64,8 @@ def _to_mermaid_image(graph: networkx.MultiDiGraph):
"""
Renders a pipeline using Mermaid (hosted version at 'https://mermaid.ink'). Requires Internet access.
"""
# Copy the graph to avoid modifying the original
graph = _prepare_for_drawing(graph.copy())
graph_styled = _to_mermaid_text(graph=graph)

graphbytes = graph_styled.encode("ascii")
Expand Down Expand Up @@ -131,6 +98,8 @@ def _to_mermaid_text(graph: networkx.MultiDiGraph) -> str:
Converts a Networkx graph into Mermaid syntax. The output of this function can be used in the documentation
with `mermaid` codeblocks and it will be automatically rendered.
"""
# Copy the graph to avoid modifying the original
graph = _prepare_for_drawing(graph.copy())
sockets = {
comp: "".join(
[
Expand Down
35 changes: 27 additions & 8 deletions haystack/core/pipeline/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,19 @@
import networkx # type:ignore

from haystack.core.component import Component, InputSocket, OutputSocket, component
from haystack.core.errors import PipelineConnectError, PipelineError, PipelineRuntimeError, PipelineValidationError
from haystack.core.errors import (
PipelineConnectError,
PipelineDrawingError,
PipelineError,
PipelineRuntimeError,
PipelineValidationError,
)
from haystack.core.serialization import component_from_dict, component_to_dict
from haystack.core.type_utils import _type_name, _types_are_compatible
from haystack.utils import is_in_jupyter

from .descriptions import find_pipeline_inputs, find_pipeline_outputs
from .draw import _draw
from .draw import _to_mermaid_image

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -442,17 +449,29 @@ def outputs(self) -> Dict[str, Dict[str, Any]]:
}
return outputs

def draw(self, path: Optional[Path] = None) -> None:
def show(self) -> None:
"""
Save a Pipeline image to `path`.
If `path` is `None` the image will be displayed inline in the Jupyter notebook.
If `path` is `None` and the code is not running in a Jupyter notebook, an error will be raised.
If running in a Jupyter notebook, display an image representing this `Pipeline`.

If `path` is given it will always be saved to file, whether it's a notebook or not.
"""
if is_in_jupyter():
from IPython.display import Image, display

image_data = _to_mermaid_image(self.graph)

display(Image(image_data))
else:
msg = "This method is only supported in Jupyter notebooks. Use Pipeline.draw() to save an image locally."
raise PipelineDrawingError(msg)

def draw(self, path: Path) -> None:
"""
Save an image representing this `Pipeline` to `path`.
"""
# Before drawing we edit a bit the graph, to avoid modifying the original that is
# used for running the pipeline we copy it.
_draw(graph=self.graph.copy(), path=path)
image_data = _to_mermaid_image(self.graph)
Path(path).write_bytes(image_data)

def warm_up(self):
"""
Expand Down