Skip to content

Commit

Permalink
[doc][data] auto-gen GroupedData api (#46925)
Browse files Browse the repository at this point in the history
Use `_autogen` to auto-generate the `GroupedData` API documentation, so that
we no longer need to enumerate the list of public APIs by hand. Note that in
the new look, the constructor is unfolded by default. This is to be
consistent with how we document classes elsewhere (e.g.
https://docs.ray.io/en/latest/data/api/dataset.html).

The Read the Docs build is currently broken on master, which is unrelated to
this change.

Before:

<img width="1542" alt="Screenshot 2024-08-01 at 3 22 50 PM"
src="https://github.com/user-attachments/assets/a323e897-dd0b-4051-8174-1a7aa8ea9d0c">



After:

<img width="1633" alt="Screenshot 2024-08-01 at 3 22 40 PM"
src="https://github.com/user-attachments/assets/1270a0dd-0239-49e0-ae1e-46d0c591f03f">

Test:
- CI

---------

Signed-off-by: can <[email protected]>
  • Loading branch information
can-anyscale authored Aug 7, 2024
1 parent 85eaffd commit 23453a2
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 44 deletions.
6 changes: 4 additions & 2 deletions doc/source/_templates/autosummary/class_v2.rst
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
.. currentmodule:: {{ module }}

{% if name | has_public_constructor(module) %}
{{ name }}
{{ '-' * name | length }}

.. currentmodule:: {{ module }}

.. autoclass:: {{ objname }}
{% endif %}

{% block methods %}
{% if methods %}
Expand Down
6 changes: 6 additions & 0 deletions doc/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,6 +406,11 @@ def filter_out_undoc_class_members(member_name, class_name, module_name):
return ""


def has_public_constructor(class_name, module_name):
    """Tell whether the named class passes the ``_is_public_api`` check.

    Registered as a template filter so that a class page can decide whether
    to render the class heading and ``autoclass`` directive.

    Args:
        class_name: Name of the class attribute to look up on the module.
        module_name: Dotted path of the module to import.

    Returns:
        Whatever ``_is_public_api`` returns for the resolved class object.
    """
    module = import_module(module_name)
    target_class = getattr(module, class_name)
    return _is_public_api(target_class)


def get_api_groups(method_names, class_name, module_name):
api_groups = set()
cls = getattr(import_module(module_name), class_name)
Expand Down Expand Up @@ -443,6 +448,7 @@ def _is_api_group(obj, group):
FILTERS["filter_out_undoc_class_members"] = filter_out_undoc_class_members
FILTERS["get_api_groups"] = get_api_groups
FILTERS["select_api_group"] = select_api_group
FILTERS["has_public_constructor"] = has_public_constructor


def add_custom_assets(
Expand Down
2 changes: 2 additions & 0 deletions doc/source/data/api/_autogen.rst
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,5 @@
DataIterator
Dataset
Schema
grouped_data.GroupedData
aggregate.AggregateFn
43 changes: 2 additions & 41 deletions doc/source/data/api/grouped_data.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,44 +8,5 @@ GroupedData API
GroupedData objects are returned by groupby call:
:meth:`Dataset.groupby() <ray.data.Dataset.groupby>`.

Constructor
-----------

.. autosummary::
:nosignatures:
:toctree: doc/

grouped_data.GroupedData

Computations / Descriptive Stats
--------------------------------

.. autosummary::
:nosignatures:
:toctree: doc/

grouped_data.GroupedData.count
grouped_data.GroupedData.sum
grouped_data.GroupedData.min
grouped_data.GroupedData.max
grouped_data.GroupedData.mean
grouped_data.GroupedData.std

Function Application
--------------------

.. autosummary::
:nosignatures:
:toctree: doc/

grouped_data.GroupedData.aggregate
grouped_data.GroupedData.map_groups

Aggregate Function
------------------

.. autosummary::
:nosignatures:
:toctree: doc/

aggregate.AggregateFn
.. include:: ray.data.grouped_data.GroupedData.rst
.. include:: ray.data.aggregate.AggregateFn.rst
12 changes: 11 additions & 1 deletion python/ray/data/grouped_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
from ray.data.dataset import DataBatch, Dataset
from ray.util.annotations import PublicAPI

CDS_API_GROUP = "Computations or Descriptive Stats"
FA_API_GROUP = "Function Application"


class _MultiColumnSortedKey:
"""Represents a tuple of group keys with a ``__lt__`` method
Expand All @@ -32,7 +35,6 @@ def __repr__(self) -> str:
return "T" + self.data.__repr__()


@PublicAPI
class GroupedData:
"""Represents a grouped dataset created by calling ``Dataset.groupby()``.
Expand All @@ -57,6 +59,7 @@ def __repr__(self) -> str:
f"{self.__class__.__name__}(dataset={self._dataset}, " f"key={self._key!r})"
)

@PublicAPI(api_group=FA_API_GROUP)
def aggregate(self, *aggs: AggregateFn) -> Dataset:
"""Implements an accumulator-based aggregation.
Expand Down Expand Up @@ -102,6 +105,7 @@ def _aggregate_on(
)
return self.aggregate(*aggs)

@PublicAPI(api_group=FA_API_GROUP)
def map_groups(
self,
fn: UserDefinedFunction[DataBatch, DataBatch],
Expand Down Expand Up @@ -272,6 +276,7 @@ def wrapped_fn(batch, *args, **kwargs):
**ray_remote_args,
)

@PublicAPI(api_group=CDS_API_GROUP)
def count(self) -> Dataset:
"""Compute count aggregation.
Expand All @@ -288,6 +293,7 @@ def count(self) -> Dataset:
"""
return self.aggregate(Count())

@PublicAPI(api_group=CDS_API_GROUP)
def sum(
self, on: Union[str, List[str]] = None, ignore_nulls: bool = True
) -> Dataset:
Expand Down Expand Up @@ -331,6 +337,7 @@ def sum(
"""
return self._aggregate_on(Sum, on, ignore_nulls)

@PublicAPI(api_group=CDS_API_GROUP)
def min(
self, on: Union[str, List[str]] = None, ignore_nulls: bool = True
) -> Dataset:
Expand Down Expand Up @@ -369,6 +376,7 @@ def min(
"""
return self._aggregate_on(Min, on, ignore_nulls)

@PublicAPI(api_group=CDS_API_GROUP)
def max(
self, on: Union[str, List[str]] = None, ignore_nulls: bool = True
) -> Dataset:
Expand Down Expand Up @@ -407,6 +415,7 @@ def max(
"""
return self._aggregate_on(Max, on, ignore_nulls)

@PublicAPI(api_group=CDS_API_GROUP)
def mean(
self, on: Union[str, List[str]] = None, ignore_nulls: bool = True
) -> Dataset:
Expand Down Expand Up @@ -445,6 +454,7 @@ def mean(
"""
return self._aggregate_on(Mean, on, ignore_nulls)

@PublicAPI(api_group=CDS_API_GROUP)
def std(
self,
on: Union[str, List[str]] = None,
Expand Down

0 comments on commit 23453a2

Please sign in to comment.