Skip to content

Commit

Permalink
Refactor public Treescope API and set up packaging.
Browse files Browse the repository at this point in the history
- Moves top-level API functions to be defined in treescope._internal.api
  and re-exported in treescope, instead of being in two places.
- Renames `treescope.renderer` to `treescope.renderers`.
- Sets up Python packaging and test runners.
- Renames custom autovisualization tag type.
- Removes support for labels above or below integer digitboxes, which do not
  render correctly.
- Simplifies global configuration by removing "interactive context" stack.

PiperOrigin-RevId: 653641608
Change-Id: I1668950281954ce85e262a867a9ecb1f9b573459
  • Loading branch information
danieldjohnson committed Jul 18, 2024
1 parent adf28ff commit 13eede1
Show file tree
Hide file tree
Showing 32 changed files with 433 additions and 319 deletions.
32 changes: 32 additions & 0 deletions .github/workflows/unittests.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
name: Unittests

# Allow to trigger the workflow manually (e.g. when deps changes)
on: [push, workflow_dispatch]

jobs:
unittest-job:
runs-on: ubuntu-latest
timeout-minutes: 30

concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true

steps:
- uses: actions/checkout@v3

# Install deps
- uses: actions/setup-python@v4
with:
python-version: 3.10.14
# Uncomment to cache of pip dependencies (if tests too slow)
# cache: pip
# cache-dependency-path: '**/pyproject.toml'

- run: pip --version
- run: pip install -e .[dev,extras]
- run: pip freeze

# Run tests
- name: Run tests
run: python run_tests.py
77 changes: 77 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
[project]
# Project metadata. Available keys are documented at:
# https://packaging.python.org/en/latest/specifications/declaring-project-metadata
name = "treescope"
description = "Treescope: An interactive HTML pretty-printer for ML research in IPython notebooks."
readme = "README.md"
requires-python = ">=3.10"
license = {file = "LICENSE"}
authors = [{name = "The Treescope Authors", email="[email protected]"}]
classifiers = [
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3 :: Only",
"License :: OSI Approved :: Apache Software License",
"Intended Audience :: Science/Research",
]
keywords = []

# Pip dependencies of the project.
dependencies = [
"numpy>=1.25.2"
]

# This is set automatically by flit using `treescope.__version__`
dynamic = ["version"]

[project.urls]
homepage = "https://github.com/google-deepmind/treescope"
repository = "https://github.com/google-deepmind/treescope"

[project.optional-dependencies]
# Dependencies required for running tests.
test = [
"absl-py>=1.4.0",
"jax>=0.4.23",
"pytest>=8.2.2",
"torch>=2.0.0",
]
# Extra dependencies for some notebook demos.
notebook = [
"ipython",
"palettable",
"jax>=0.4.23",
]
# Development deps (linting, formating,...)
# Installed through `pip install .[dev]`
dev = [
"pylint>=2.6.0",
"pyink>=24.3.0",
"ipython",
"jupyter",
]
# Requirements for building documentation.
docs = [
"ipython",
"sphinx>=6.0.0,<7.3.0",
"sphinx-book-theme>=1.0.1",
"sphinxcontrib-katex",
"ipython>=8.8.0",
"myst-nb>=1.0.0",
"myst-parser>=3.0.1",
"matplotlib>=3.5.0",
"sphinx-collections>=0.0.1",
"sphinx_contributors",
"sphinx-hoverxref",
"jax[cpu]>=0.4.23",
]

[tool.pyink]
# Formatting configuration to follow Google style-guide
line-length = 80
unstable = true
pyink-indentation = 2
pyink-use-majority-quotes = true

[build-system]
requires = ["flit_core >=3.8,<4"]
build-backend = "flit_core.buildapi"
19 changes: 19 additions & 0 deletions run_tests.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# Copyright 2024 The Treescope Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Entry point executable to run all tests."""

import subprocess

if __name__ == "__main__":
subprocess.check_call(["python", "-m", "pytest"])
17 changes: 7 additions & 10 deletions tests/ndarray_adapters_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,7 @@
import jax.numpy as jnp
import numpy as np
import torch
from treescope import array_autovisualizer
from treescope import arrayviz
from treescope import autovisualize
from treescope import default_renderer
import treescope
from treescope import ndarray_adapters
from treescope import type_registries

Expand Down Expand Up @@ -182,24 +179,24 @@ def test_array_rendering_without_error(self, array_type, dtype):
raise ValueError(f"Unsupported array_type: {array_type}")

with self.subTest("explicit_unmasked"):
res = arrayviz.render_array(array)
res = treescope.render_array(array)
self.assertTrue(hasattr(res, "_repr_html_"))

with self.subTest("explicit_masked"):
res = arrayviz.render_array(array, valid_mask=array > 100)
res = treescope.render_array(array, valid_mask=array > 100)
self.assertTrue(hasattr(res, "_repr_html_"))

with self.subTest("explicit_masked_truncated"):
res = arrayviz.render_array(
res = treescope.render_array(
array, valid_mask=array > 100, truncate=True, maximum_size=100
)
self.assertTrue(hasattr(res, "_repr_html_"))

with self.subTest("automatic"):
with autovisualize.active_autovisualizer.set_scoped(
array_autovisualizer.ArrayAutovisualizer()
with treescope.active_autovisualizer.set_scoped(
treescope.ArrayAutovisualizer()
):
res = default_renderer.render_to_html(
res = treescope.render_to_html(
array, ignore_exceptions=False, compressed=False
)
self.assertIsInstance(res, str)
Expand Down
55 changes: 27 additions & 28 deletions tests/renderer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,7 @@
import jax.numpy as jnp
import numpy as np
import torch
from treescope import autovisualize
from treescope import default_renderer
import treescope
from treescope import handlers
from treescope import layout_algorithms
from treescope import lowering
Expand All @@ -49,7 +48,7 @@ def _repr_html_(self):
class TreescopeRendererTest(parameterized.TestCase):

def test_renderer_interface(self):
renderer = default_renderer.active_renderer.get()
renderer = treescope.active_renderer.get()

rendering = renderer.to_text({"key": "value"})
self.assertEqual(rendering, "{'key': 'value'}")
Expand All @@ -63,10 +62,10 @@ def test_renderer_interface(self):
)

def test_high_level_interface(self):
rendering = default_renderer.render_to_text({"key": "value"})
rendering = treescope.render_to_text({"key": "value"})
self.assertEqual(rendering, "{'key': 'value'}")

rendering = default_renderer.render_to_html({"key": "value"})
rendering = treescope.render_to_html({"key": "value"})
self.assertIsInstance(rendering, str)

def test_error_recovery(self):
Expand All @@ -82,7 +81,7 @@ def hook_that_crashes(node, path, node_renderer):
raise RuntimeError("hook error!")
return NotImplemented

renderer = default_renderer.active_renderer.get().extended_with(
renderer = treescope.active_renderer.get().extended_with(
handlers=[handler_that_crashes], wrapper_hooks=[hook_that_crashes]
)

Expand Down Expand Up @@ -394,18 +393,18 @@ def hook_that_crashes(node, path, node_renderer):
),
dict(
testcase_name="well_known_function",
target=default_renderer.render_to_text,
target=treescope.render_to_text,
expected_collapsed="render_to_text",
expected_roundtrip_collapsed=(
"treescope.default_renderer.render_to_text"
"treescope.render_to_text"
),
),
dict(
testcase_name="well_known_type",
target=autovisualize.IPythonVisualization,
target=treescope.IPythonVisualization,
expected_collapsed="IPythonVisualization",
expected_roundtrip_collapsed=(
"treescope.autovisualize.IPythonVisualization"
"treescope.IPythonVisualization"
),
),
dict(
Expand Down Expand Up @@ -539,7 +538,7 @@ def test_object_rendering(
assert target is None
target = target_builder()

renderer = default_renderer.active_renderer.get()
renderer = treescope.active_renderer.get()
# Render it to IR.
rendering = rendering_parts.build_full_line_with_annotations(
renderer.to_foldable_representation(target)
Expand Down Expand Up @@ -591,7 +590,7 @@ def inner_fn(y):

closure = outer_fn(100)

renderer = default_renderer.active_renderer.get()
renderer = treescope.active_renderer.get()
# Enable closure rendering (currently disabled by default)
renderer = renderer.extended_with(
handlers=[
Expand Down Expand Up @@ -624,7 +623,7 @@ def inner_fn(y):

def test_fallback_repr_pytree_node(self):
target = [fixture_lib.UnknownPytreeNode(1234, 5678)]
renderer = default_renderer.active_renderer.get()
renderer = treescope.active_renderer.get()
rendering = rendering_parts.build_full_line_with_annotations(
renderer.to_foldable_representation(target)
)
Expand Down Expand Up @@ -654,7 +653,7 @@ def test_fallback_repr_pytree_node(self):

def test_fallback_repr_one_line(self):
target = [fixture_lib.UnknownObjectWithOneLineRepr()]
renderer = default_renderer.active_renderer.get()
renderer = treescope.active_renderer.get()
rendering = rendering_parts.build_full_line_with_annotations(
renderer.to_foldable_representation(target)
)
Expand All @@ -674,7 +673,7 @@ def test_fallback_repr_one_line(self):

def test_fallback_repr_multiline_idiomatic(self):
target = [fixture_lib.UnknownObjectWithMultiLineRepr()]
renderer = default_renderer.active_renderer.get()
renderer = treescope.active_renderer.get()
rendering = rendering_parts.build_full_line_with_annotations(
renderer.to_foldable_representation(target)
)
Expand All @@ -697,7 +696,7 @@ def test_fallback_repr_multiline_idiomatic(self):

def test_fallback_repr_multiline_unidiomatic(self):
target = [fixture_lib.UnknownObjectWithBadMultiLineRepr()]
renderer = default_renderer.active_renderer.get()
renderer = treescope.active_renderer.get()
rendering = rendering_parts.build_full_line_with_annotations(
renderer.to_foldable_representation(target)
)
Expand All @@ -721,7 +720,7 @@ def test_fallback_repr_multiline_unidiomatic(self):

def test_fallback_repr_basic(self):
target = [fixture_lib.UnknownObjectWithBuiltinRepr()]
renderer = default_renderer.active_renderer.get()
renderer = treescope.active_renderer.get()
rendering = rendering_parts.build_full_line_with_annotations(
renderer.to_foldable_representation(target)
)
Expand All @@ -742,7 +741,7 @@ def test_fallback_repr_basic(self):
def test_shared_values(self):
shared = ["bar"]
target = [shared, shared, {"foo": shared}]
renderer = default_renderer.active_renderer.get()
renderer = treescope.active_renderer.get()
rendering = rendering_parts.build_full_line_with_annotations(
renderer.to_foldable_representation(target)
)
Expand Down Expand Up @@ -775,39 +774,39 @@ def test_autovisualizer(self):

def autovisualizer_for_test(node, path):
if isinstance(node, str):
return autovisualize.CustomTreescopeVisualization(
return treescope.VisualizationFromTreescopePart(
rendering_parts.RenderableAndLineAnnotations(
rendering_parts.text("(visualiation for foo goes here)"),
rendering_parts.text(" # annotation for vis for foo"),
),
)
elif path == "[4]":
return autovisualize.IPythonVisualization(
return treescope.IPythonVisualization(
CustomReprHTMLObject("(html rendering)"),
replace=True,
)
elif path == "[5]":
return autovisualize.IPythonVisualization(
return treescope.IPythonVisualization(
CustomReprHTMLObject("(html rendering)"),
replace=False,
)
elif path == "[6]":
return autovisualize.ChildAutovisualizer(inner_autovisualizer)
return treescope.ChildAutovisualizer(inner_autovisualizer)

def inner_autovisualizer(node, path):
del path
if node == 6:
return autovisualize.CustomTreescopeVisualization(
return treescope.VisualizationFromTreescopePart(
rendering_parts.RenderableAndLineAnnotations(
rendering_parts.text("(child visualiation of 6 goes here)"),
rendering_parts.text(" # annotation for vis for 6"),
),
)

with autovisualize.active_autovisualizer.set_scoped(
with treescope.active_autovisualizer.set_scoped(
autovisualizer_for_test
):
renderer = default_renderer.active_renderer.get()
renderer = treescope.active_renderer.get()
rendering = rendering_parts.build_full_line_with_annotations(
renderer.to_foldable_representation(target)
)
Expand Down Expand Up @@ -862,7 +861,7 @@ def inner_autovisualizer(node, path):
)

def test_balanced_layout(self):
renderer = default_renderer.active_renderer.get()
renderer = treescope.active_renderer.get()
some_nested_object = fixture_lib.DataclassWithOneChild([
["foo"] * 4,
["12345678901234567890"] * 5,
Expand Down Expand Up @@ -974,7 +973,7 @@ def render_and_expand(**kwargs):
)

def test_balanced_layout_after_manual_expansion(self):
renderer = default_renderer.active_renderer.get()
renderer = treescope.active_renderer.get()
some_nested_object = [
fixture_lib.DataclassWithOneChild(
[["foo"] * 4, (["baz"] * 5, ["qux"] * 5)]
Expand Down Expand Up @@ -1029,7 +1028,7 @@ def test_balanced_layout_after_manual_expansion(self):
)

def test_balanced_layout_relaxes_height_constraint_once(self):
renderer = default_renderer.active_renderer.get()
renderer = treescope.active_renderer.get()
some_nested_object = [
fixture_lib.DataclassWithOneChild(
[fixture_lib.DataclassWithOneChild(["abcdefghik"] * 20)]
Expand Down
Loading

0 comments on commit 13eede1

Please sign in to comment.