From 74849ded913c88279f0859c8b104c5509ac9bc99 Mon Sep 17 00:00:00 2001 From: SangBin Cho Date: Tue, 26 Jul 2022 09:13:22 -0700 Subject: [PATCH 1/2] completed. --- dashboard/modules/state/state_head.py | 3 - dashboard/state_aggregator.py | 9 ++ python/ray/_private/ray_logging.py | 2 - python/ray/_private/test_utils.py | 6 +- python/ray/experimental/state/api.py | 138 +++++++++++------- python/ray/experimental/state/common.py | 108 ++++++++++---- python/ray/experimental/state/state_cli.py | 56 +++++-- .../ray/experimental/state/state_manager.py | 1 - python/ray/scripts/scripts.py | 2 +- python/ray/tests/test_state_api.py | 118 ++++++++++++--- 10 files changed, 322 insertions(+), 121 deletions(-) diff --git a/dashboard/modules/state/state_head.py b/dashboard/modules/state/state_head.py index fc2a60d862123..700824ae159c2 100644 --- a/dashboard/modules/state/state_head.py +++ b/dashboard/modules/state/state_head.py @@ -321,7 +321,6 @@ async def list_logs(self, req: aiohttp.web.Request) -> aiohttp.web.Response: node_ip = req.query.get("node_ip", None) timeout = int(req.query.get("timeout", DEFAULT_RPC_TIMEOUT)) - # TODO(sang): Do input validation from the middleware instead. if not node_id and not node_ip: return self._reply( success=False, @@ -354,8 +353,6 @@ async def list_logs(self, req: aiohttp.web.Request) -> aiohttp.web.Response: @routes.get("/api/v0/logs/{media_type}") @RateLimitedModule.enforce_max_concurrent_calls async def get_logs(self, req: aiohttp.web.Request): - # TODO(sang): We need a better error handling for streaming - # when we refactor the server framework. options = GetLogOptions( timeout=int(req.query.get("timeout", DEFAULT_RPC_TIMEOUT)), node_id=req.query.get("node_id", None), diff --git a/dashboard/state_aggregator.py b/dashboard/state_aggregator.py index 72712e0bf1da7..d3e659eff4c15 100644 --- a/dashboard/state_aggregator.py +++ b/dashboard/state_aggregator.py @@ -613,6 +613,9 @@ async def summarize_tasks(self, option: SummaryApiOptions) -> SummaryApiResponse partial_failure_warning=result.partial_failure_warning, warnings=result.warnings, num_after_truncation=result.num_after_truncation, + # Currently, there's no filtering support for summary, + # so we don't calculate this separately. + num_filtered=len(result.result), ) async def summarize_actors(self, option: SummaryApiOptions) -> SummaryApiResponse: @@ -633,6 +636,9 @@ async def summarize_actors(self, option: SummaryApiOptions) -> SummaryApiRespons partial_failure_warning=result.partial_failure_warning, warnings=result.warnings, num_after_truncation=result.num_after_truncation, + # Currently, there's no filtering support for summary, + # so we don't calculate this separately. + num_filtered=len(result.result), ) async def summarize_objects(self, option: SummaryApiOptions) -> SummaryApiResponse: @@ -653,6 +659,9 @@ async def summarize_objects(self, option: SummaryApiOptions) -> SummaryApiRespon partial_failure_warning=result.partial_failure_warning, warnings=result.warnings, num_after_truncation=result.num_after_truncation, + # Currently, there's no filtering support for summary, + # so we don't calculate this separately. + num_filtered=len(result.result), ) def _message_to_dict( diff --git a/python/ray/_private/ray_logging.py b/python/ray/_private/ray_logging.py index cdf805c7462c2..9907b04f5d2b2 100644 --- a/python/ray/_private/ray_logging.py +++ b/python/ray/_private/ray_logging.py @@ -287,8 +287,6 @@ def setup_and_get_worker_interceptor_logger( is_for_stdout=is_for_stdout, ) logger.addHandler(handler) - # TODO(sang): Add 0 or 1 to decide whether - # or not logs are streamed to drivers. handler.setFormatter(logging.Formatter("%(message)s")) # Avoid messages are propagated to parent loggers. logger.propagate = False diff --git a/python/ray/_private/test_utils.py b/python/ray/_private/test_utils.py index 1179b0b78c4e0..25789fa1af407 100644 --- a/python/ray/_private/test_utils.py +++ b/python/ray/_private/test_utils.py @@ -36,6 +36,8 @@ from ray.scripts.scripts import main as ray_main from ray.util.queue import Empty, Queue, _QueueActor +logger = logging.getLogger(__name__) + try: from prometheus_client.parser import text_string_to_metric_families except (ImportError, ModuleNotFoundError): @@ -355,8 +357,8 @@ def wait_for_condition( try: if condition_predictor(**kwargs): return - except Exception as ex: - last_ex = ex + except Exception: + last_ex = ray._private.utils.format_error_message(traceback.format_exc()) time.sleep(retry_interval_ms / 1000.0) message = "The condition wasn't met before the timeout expired." if last_ex is not None: diff --git a/python/ray/experimental/state/api.py b/python/ray/experimental/state/api.py index d175cfb3d510c..51af77f7f38e6 100644 --- a/python/ray/experimental/state/api.py +++ b/python/ray/experimental/state/api.py @@ -311,61 +311,77 @@ def get( assert len(result) == 1 return result[0] - def _print_api_warning(self, resource: StateResource, api_response: dict): + def _print_api_warning( + self, + resource: StateResource, + api_response: dict, + warn_data_source_not_available: bool = True, + warn_data_truncation: bool = True, + warn_limit: bool = True, + warn_server_side_warnings: bool = True, + ): """Print the API warnings. - We print warnings for users: - 1. when some data sources are not available - 2. when results were truncated at the data source - 3. when results were limited - 4. when callsites not enabled for listing objects - Args: resource: Resource names, i.e. 'jobs', 'actors', 'nodes', see `StateResource` for details. api_response: The dictionarified `ListApiResponse` or `SummaryApiResponse`. + warn_data_source_not_available: Warn when some data sources + are not available. + warn_data_truncation: Warn when results were truncated at + the data source. + warn_limit: Warn when results were limited. + warn_server_side_warnings: Warn when the server side generates warnings + (E.g., when callsites not enabled for listing objects) """ # Print warnings if anything was given. - warning_msgs = api_response.get("partial_failure_warning", None) - if warning_msgs: - warnings.warn(warning_msgs) - - # Print warnings if data is truncated at the data source. - num_after_truncation = api_response["num_after_truncation"] - total = api_response["total"] - if total > num_after_truncation: - # NOTE(rickyyx): For now, there's not much users could do (neither can we), - # with hard truncation. Unless we allow users to set a higher - # `RAY_MAX_LIMIT_FROM_DATA_SOURCE`, the data will always be truncated at the - # data source. - warnings.warn( - ( - f"{num_after_truncation} ({total} total) {resource.value} " - "are retrieved from the data source. " - f"{total - num_after_truncation} entries have been truncated. " - f"Max of {num_after_truncation} entries are retrieved from data " - "source to prevent over-sized payloads." - ), - ) + if warn_data_source_not_available: + warning_msgs = api_response.get("partial_failure_warning", None) + if warning_msgs: + warnings.warn(warning_msgs) + + if warn_data_truncation: + # Print warnings if data is truncated at the data source. + num_after_truncation = api_response["num_after_truncation"] + total = api_response["total"] + if total > num_after_truncation: + # NOTE(rickyyx): For now, there's not much users + # could do (neither can we), with hard truncation. + # Unless we allow users to set a higher + # `RAY_MAX_LIMIT_FROM_DATA_SOURCE`, the data will + # always be truncated at the data source. + warnings.warn( + ( + "The returned data may contain incomplete result. " + f"{num_after_truncation} ({total} total from the cluster) " + f"{resource.value} are retrieved from the data source. " + f"{total - num_after_truncation} entries have been truncated. " + f"Max of {num_after_truncation} entries are retrieved " + "from data source to prevent over-sized payloads." + ), + ) - # Print warnings if return data is limited at the API server due to - # limit enforced at the server side - num_filtered = api_response["num_filtered"] - data = api_response["result"] - if num_filtered > len(data): - warnings.warn( - ( - f"{len(data)}/{num_filtered} {resource.value} returned. " - "Use `--filter` to reduce the amount of data to return or " - "setting a higher limit with `--limit` to see all data. " - ), - ) + if warn_limit: + # Print warnings if return data is limited at the API server due to + # limit enforced at the server side + num_filtered = api_response["num_filtered"] + data = api_response["result"] + if num_filtered > len(data): + warnings.warn( + ( + f"Limit last {len(data)} entries " + f"(Total {num_filtered}). Use `--filter` to reduce " + "the amount of data to return or " + "setting a higher limit with `--limit` to see all data. " + ), + ) - # Print the additional warnings. - warnings_to_print = api_response.get("warnings", []) - if warnings_to_print: - for warning_to_print in warnings_to_print: - warnings.warn(warning_to_print) + if warn_server_side_warnings: + # Print the additional warnings. + warnings_to_print = api_response.get("warnings", []) + if warnings_to_print: + for warning_to_print in warnings_to_print: + warnings.warn(warning_to_print) def _raise_on_missing_output(self, resource: StateResource, api_response: dict): """Raise an exception when the API resopnse contains a missing output. @@ -380,15 +396,30 @@ def _raise_on_missing_output(self, resource: StateResource, api_response: dict): see `StateResource` for details. api_response: The dictionarified `ListApiResponse` or `SummaryApiResponse`. """ + # Raise an exception if there are partial failures that cause missing output. warning_msgs = api_response.get("partial_failure_warning", None) - # TODO(sang) raise an exception on truncation after - # https://github.com/ray-project/ray/pull/26801. if warning_msgs: raise RayStateApiException( - f"Failed to retrieve all {resource.value} from the cluster. " - f"It can happen when some of {resource.value} information is not " - "reachable or the returned data is truncated because it is too large. " - "To allow having missing output, set `raise_on_missing_output=False`. " + f"Failed to retrieve all {resource.value} from the cluster because" + "they are not reachable due to query failures to the data sources. " + "To avoid raising an exception and allow having missing output, " + "set `raise_on_missing_output=False`. " + ) + # Raise an exception is there is data truncation that cause missing output. + total = api_response["total"] + num_after_truncation = api_response["num_after_truncation"] + + if total != num_after_truncation: + raise RayStateApiException( + f"Failed to retrieve all {resource.value} from the cluster because " + "they are not reachable due to data truncation. It happens " + "when the returned data is too large " + # When the data is truncated, the truncation + # threshold == num_after_truncation. We cannot set this to env + # var because the CLI side might not have the correct env var. + f"(> {num_after_truncation}) " + "To avoid raising an exception and allow having missing output, " + "set `raise_on_missing_output=False`. " ) def list( @@ -474,8 +505,9 @@ def summary( ) if raise_on_missing_output: self._raise_on_missing_output(resource, summary_api_response) - # TODO(sang): Add warning after - # # https://github.com/ray-project/ray/pull/26801 is merged. + if _explain: + # There's no limit applied to summary, so we shouldn't warn. + self._print_api_warning(resource, summary_api_response, warn_limit=False) return summary_api_response["result"]["node_id_to_summary"] diff --git a/python/ray/experimental/state/common.py b/python/ray/experimental/state/common.py index cd21fdea2c2cd..dc5b87f309d4a 100644 --- a/python/ray/experimental/state/common.py +++ b/python/ray/experimental/state/common.py @@ -174,13 +174,31 @@ class State(StateSchema): """ @classmethod - def columns(cls) -> Set[str]: - """Return a list of all columns.""" - cols = set() + def list_columns(cls) -> List[str]: + """Return a list of columns. + + The order of columns is defined by the order + of attributes from the dataclass. + + E.g., + @dataclass + class A(StateSchema): + a: str + b: str + c: str + + -> ["a", "b", "c"] + """ + cols = [] for f in fields(cls): - cols.add(f.name) + cols.append(f.name) return cols + @classmethod + def columns(cls) -> Set[str]: + """Return a set of all columns.""" + return set(cls.list_columns()) + @classmethod def filterable_columns(cls) -> Set[str]: """Return a list of filterable columns""" @@ -302,6 +320,8 @@ class ActorState(StateSchema): #: The id of the actor. actor_id: str = state_column(filterable=True) + #: The class name of the actor. + class_name: str = state_column(filterable=True) #: The state of the actor. #: #: - DEPENDENCIES_UNREADY: Actor is waiting for dependency to be ready. @@ -318,8 +338,6 @@ class ActorState(StateSchema): #: but means the actor was dead more than once. #: - DEAD: The actor is permanatly dead. state: TypeActorStatus = state_column(filterable=True) - #: The class name of the actor. - class_name: str = state_column(filterable=True) #: The name of the actor given by the `name` argument. name: Optional[str] = state_column(filterable=True) #: The pid of the actor. 0 if it is not created yet. @@ -340,6 +358,8 @@ class PlacementGroupState(StateSchema): #: The id of the placement group. placement_group_id: str = state_column(filterable=True) + #: The name of the placement group if it is given by the name argument. + name: str = state_column(filterable=True) #: The state of the placement group. #: #: - PENDING: The placement group creation is pending scheduling. @@ -351,8 +371,6 @@ class PlacementGroupState(StateSchema): #: - RESCHEDULING: The placement group is rescheduling because some of #: bundles are dead because they were on dead nodes. state: TypePlacementGroupStatus = state_column(filterable=True) - #: The name of the placement group if it is given by the name argument. - name: str = state_column(filterable=True) #: The bundle specification of the placement group. bundles: dict = state_column(filterable=False, detail=True) #: True if the placement group is detached. False otherwise. @@ -383,6 +401,13 @@ class NodeState(StateSchema): class JobState(JobInfo, StateSchema): """The state of the job that's submitted by Ray's Job APIs""" + @classmethod + def list_columns(cls) -> List[str]: + cols = ["job_id"] + for f in fields(cls): + cols.append(f.name) + return cols + @classmethod def filterable_columns(cls) -> Set[str]: return {"status", "entrypoint", "error_type"} @@ -474,14 +499,8 @@ class ObjectState(StateSchema): #: The id of the object. object_id: str = state_column(filterable=True) - #: The pid of the owner. - pid: int = state_column(filterable=True) - #: The ip address of the owner. - ip: str = state_column(filterable=True) #: The size of the object in mb. object_size: int = state_column(filterable=True) - #: The callsite of the object. - call_site: str = state_column(filterable=True) #: The status of the task that creates the object. #: #: - NIL: We don't have a status for this task because we are not the owner or the @@ -499,14 +518,6 @@ class ObjectState(StateSchema): #: to the remote worker + queueing time from the execution side. #: - RUNNING: The task that is running. task_status: TypeTaskStatus = state_column(filterable=True) - #: The worker type that creates the object. - #: - #: - WORKER: The regular Ray worker process that executes tasks or - #: instantiates an actor. - #: - DRIVER: The driver (Python script that calls `ray.init`). - #: - SPILL_WORKER: The worker that spills objects. - #: - RESTORE_WORKER: The worker that restores objects. - type: TypeWorkerType = state_column(filterable=True) #: The reference type of the object. #: See :ref:`Debugging with Ray Memory ` for more details. #: @@ -522,6 +533,20 @@ class ObjectState(StateSchema): #: `a = ray.put(1)` -> `b = ray.put([a])`. a is serialized within a list. #: - UNKNOWN_STATUS: The object ref status is unkonwn. reference_type: TypeReferenceType = state_column(filterable=True) + #: The callsite of the object. + call_site: str = state_column(filterable=True) + #: The worker type that creates the object. + #: + #: - WORKER: The regular Ray worker process that executes tasks or + #: instantiates an actor. + #: - DRIVER: The driver (Python script that calls `ray.init`). + #: - SPILL_WORKER: The worker that spills objects. + #: - RESTORE_WORKER: The worker that restores objects. + type: TypeWorkerType = state_column(filterable=True) + #: The pid of the owner. + pid: int = state_column(filterable=True) + #: The ip address of the owner. + ip: str = state_column(filterable=True) @dataclass(init=True) @@ -532,11 +557,11 @@ class RuntimeEnvState(StateSchema): runtime_env: str = state_column(filterable=True) #: Whether or not the runtime env creation has succeeded. success: bool = state_column(filterable=True) - #: The node id of this runtime environment. - node_id: str = state_column(filterable=True) #: The latency of creating the runtime environment. #: Available if the runtime env is successfully created. creation_time_ms: Optional[float] = state_column(filterable=False) + #: The node id of this runtime environment. + node_id: str = state_column(filterable=True) #: The number of actors and tasks that use this runtime environment. ref_cnt: int = state_column(detail=True, filterable=False) #: The error message if the runtime environment creation has failed. @@ -616,7 +641,7 @@ class ListApiResponse: ObjectState, RuntimeEnvState, ] - ] = None + ] # List API can have a partial failure if queries to # all sources fail. For example, getting object states # require to ping all raylets, and it is possible some of @@ -627,6 +652,13 @@ class ListApiResponse: # A list of warnings to print. warnings: Optional[List[str]] = None + def __post_init__(self): + assert self.total is not None + assert self.num_after_truncation is not None + assert self.num_filtered is not None + assert self.result is not None + assert isinstance(self.result, list) + """ Summary API schema @@ -662,7 +694,6 @@ class TaskSummaries: def to_summary(cls, *, tasks: List[Dict]): # NOTE: The argument tasks contains a list of dictionary # that have the same k/v as TaskState. - # TODO(sang): Refactor this to use real dataclass. summary = {} total_tasks = 0 total_actor_tasks = 0 @@ -719,7 +750,6 @@ class ActorSummaries: def to_summary(cls, *, actors: List[Dict]): # NOTE: The argument tasks contains a list of dictionary # that have the same k/v as ActorState. - # TODO(sang): Refactor this to use real dataclass. summary = {} total_actors = 0 @@ -778,7 +808,6 @@ class ObjectSummaries: def to_summary(cls, *, objects: List[Dict]): # NOTE: The argument tasks contains a list of dictionary # that have the same k/v as ObjectState. - # TODO(sang): Refactor this to use real dataclass. summary = {} total_objects = 0 total_size_mb = 0 @@ -854,7 +883,30 @@ class SummaryApiResponse: # Carried over from ListApiResponse # Number of resources returned by data sources after truncation num_after_truncation: int + # Number of resources after filtering + num_filtered: int result: StateSummary = None partial_failure_warning: str = "" # A list of warnings to print. warnings: Optional[List[str]] = None + + +def resource_to_schema(resource: StateResource) -> StateSchema: + if resource == StateResource.ACTORS: + return ActorState + elif resource == StateResource.JOBS: + return JobState + elif resource == StateResource.NODES: + return NodeState + elif resource == StateResource.OBJECTS: + return ObjectState + elif resource == StateResource.PLACEMENT_GROUPS: + return PlacementGroupState + elif resource == StateResource.RUNTIME_ENVS: + return RuntimeEnvState + elif resource == StateResource.TASKS: + return TaskState + elif resource == StateResource.WORKERS: + return WorkerState + else: + assert False, "Unreachable" diff --git a/python/ray/experimental/state/state_cli.py b/python/ray/experimental/state/state_cli.py index 44adb8ebb01b7..9da6dddc14963 100644 --- a/python/ray/experimental/state/state_cli.py +++ b/python/ray/experimental/state/state_cli.py @@ -25,7 +25,9 @@ ListApiOptions, PredicateType, StateResource, + StateSchema, SupportedFilterType, + resource_to_schema, ) logger = logging.getLogger(__name__) @@ -134,16 +136,43 @@ def get_api_server_url() -> str: return api_server_url -def get_table_output(state_data: List) -> str: +def get_table_output(state_data: List, schema: StateSchema) -> str: + """Display the table output. + + The table headers are ordered as the order defined in the dataclass of + `StateSchema`. For example, + + @dataclass + class A(StateSchema): + a: str + b: str + c: str + + will create headers + A B C + ----- + + Args: + state_data: A list of state data. + schema: The schema for the corresponding resource. + + Returns: + The table formatted string. + """ time = datetime.now() header = "=" * 8 + f" List: {time} " + "=" * 8 headers = [] table = [] + cols = schema.list_columns() for data in state_data: for key, val in data.items(): if isinstance(val, dict): data[key] = yaml.dump(val, indent=2) - headers = sorted([key.upper() for key in data.keys()]) + keys = set(data.keys()) + headers = [] + for col in cols: + if col in keys: + headers.append(col.upper()) table.append([data[header.lower()] for header in headers]) return f""" {header} @@ -158,17 +187,19 @@ def get_table_output(state_data: List) -> str: def output_with_format( - state_data: List, format: AvailableFormat = AvailableFormat.DEFAULT + state_data: List, + *, + schema: Optional[StateSchema], + format: AvailableFormat = AvailableFormat.DEFAULT, ) -> str: - # Default is yaml. if format == AvailableFormat.DEFAULT: - return get_table_output(state_data) + return get_table_output(state_data, schema) if format == AvailableFormat.YAML: return yaml.dump(state_data, indent=4, explicit_start=True) elif format == AvailableFormat.JSON: return json.dumps(state_data) elif format == AvailableFormat.TABLE: - return get_table_output(state_data) + return get_table_output(state_data, schema) else: raise ValueError( f"Unexpected format: {format}. " @@ -269,20 +300,25 @@ def format_object_summary_output(state_data: Dict) -> str: def format_get_api_output( state_data: Optional[Dict], id: str, + *, + schema: StateSchema, format: AvailableFormat = AvailableFormat.DEFAULT, ) -> str: if not state_data or len(state_data) == 0: return f"Resource with id={id} not found in the cluster." - return output_with_format(state_data, format) + return output_with_format(state_data, schema=schema, format=format) def format_list_api_output( - state_data: List[Dict], *, format: AvailableFormat = AvailableFormat.DEFAULT + state_data: List[Dict], + *, + schema: StateSchema, + format: AvailableFormat = AvailableFormat.DEFAULT, ) -> str: if len(state_data) == 0: return "No resource in the cluster" - return output_with_format(state_data, format) + return output_with_format(state_data, schema=schema, format=format) def _should_explain(format: AvailableFormat) -> bool: @@ -398,6 +434,7 @@ def get( format_get_api_output( state_data=data, id=id, + schema=resource_to_schema(resource), format=AvailableFormat.YAML, ) ) @@ -541,6 +578,7 @@ def list( print( format_list_api_output( state_data=data, + schema=resource_to_schema(resource), format=format, ) ) diff --git a/python/ray/experimental/state/state_manager.py b/python/ray/experimental/state/state_manager.py index d4b2e6af15104..b95553719e558 100644 --- a/python/ray/experimental/state/state_manager.py +++ b/python/ray/experimental/state/state_manager.py @@ -67,7 +67,6 @@ async def api_with_network_error_handler(*args, **kwargs): or there's a slow network issue causing timeout. Otherwise, the raw network exceptions (e.g., gRPC) will be raised. """ - # TODO(sang): Add a retry policy. try: return await func(*args, **kwargs) except grpc.aio.AioRpcError as e: diff --git a/python/ray/scripts/scripts.py b/python/ray/scripts/scripts.py index 14496bfe3de9c..741117375e1ff 100644 --- a/python/ray/scripts/scripts.py +++ b/python/ray/scripts/scripts.py @@ -2126,7 +2126,7 @@ def ray_logs( print(f"Node ID: {node_id}") elif node_ip: print(f"Node IP: {node_ip}") - print(output_with_format(logs, format=AvailableFormat.YAML)) + print(output_with_format(logs, schema=None, format=AvailableFormat.YAML)) # If there's an unique match, print the log file. if match_unique: diff --git a/python/ray/tests/test_state_api.py b/python/ray/tests/test_state_api.py index a231991cfde5b..925d97ba73b45 100644 --- a/python/ray/tests/test_state_api.py +++ b/python/ray/tests/test_state_api.py @@ -89,6 +89,7 @@ AvailableFormat, format_list_api_output, _parse_filter, + summary_state_cli_group, ) from ray.experimental.state.state_cli import get as cli_get from ray.experimental.state.state_cli import list as cli_list @@ -1995,23 +1996,38 @@ async def test_cli_format_print(state_api_manager): result = result.result # If the format is not yaml, it will raise an exception. yaml.load( - format_list_api_output(result, format=AvailableFormat.YAML), + format_list_api_output(result, schema=ActorState, format=AvailableFormat.YAML), Loader=yaml.FullLoader, ) # If the format is not json, it will raise an exception. - json.loads(format_list_api_output(result, format=AvailableFormat.JSON)) + json.loads( + format_list_api_output(result, schema=ActorState, format=AvailableFormat.JSON) + ) # Test a table formatting. - output = format_list_api_output(result, format=AvailableFormat.TABLE) + output = format_list_api_output( + result, schema=ActorState, format=AvailableFormat.TABLE + ) assert "Table:" in output assert "Stats:" in output with pytest.raises(ValueError): - format_list_api_output(result, format="random_format") + format_list_api_output(result, schema=ActorState, format="random_format") # Verify the default format. - output = format_list_api_output(result) + output = format_list_api_output(result, schema=ActorState) assert "Table:" in output assert "Stats:" in output + # Verify the ordering is equal to it is defined in `StateSchema` class. + # Index 8 contains headers + headers = output.split("\n")[8] + cols = ActorState.list_columns() + headers = list(filter(lambda item: item != "", headers.strip().split(" "))) + + for i in range(len(headers)): + header = headers[i].upper() + col = cols[i].upper() + assert header == col + def test_filter(shutdown_only): ray.init() @@ -2132,11 +2148,10 @@ def test_data_truncate(shutdown_only, monkeypatch): runner = CliRunner() with pytest.warns(UserWarning) as record: result = runner.invoke(cli_list, ["placement-groups"]) - # result = list_placement_groups() assert ( - f"{max_limit_data_source} ({max_limit_data_source + 1} total) " - "placement_groups are retrieved from the data source. " - "1 entries have been truncated." in record[0].message.args[0] + f"{max_limit_data_source} ({max_limit_data_source + 1} total " + "from the cluster) placement_groups are retrieved from the " + "data source. 1 entries have been truncated." in record[0].message.args[0] ) assert result.exit_code == 0 @@ -2525,14 +2540,77 @@ def verify(): try: list_tasks(_explain=True, timeout=3) except RayStateApiException as e: - assert "Failed to retrieve all tasks from the cluster." in str(e) + assert "Failed to retrieve all tasks from the cluster" in str(e) + assert "due to query failures to the data sources." in str(e) + else: + assert False + + try: + summarize_tasks(_explain=True, timeout=3) + except RayStateApiException as e: + assert "Failed to retrieve all tasks from the cluster" in str(e) + assert "due to query failures to the data sources." in str(e) + else: + assert False + + # Verify when raise_on_missing_output=False, it prints warnings. + with pytest.warns(None) as record: + list_tasks(raise_on_missing_output=False, _explain=True, timeout=3) + assert len(record) == 1 + + with pytest.warns(None) as record: + summarize_tasks(raise_on_missing_output=False, _explain=True, timeout=3) + assert len(record) == 1 + + # Verify when CLI is used, exceptions are not raised. + with pytest.warns(None) as record: + result = runner.invoke(cli_list, ["tasks", "--timeout=3"]) + assert len(record) == 1 + assert result.exit_code == 0 + + # Verify summary CLI also doesn't raise an exception. + with pytest.warns(None) as record: + result = runner.invoke(summary_state_cli_group, ["tasks", "--timeout=3"]) + assert result.exit_code == 0 + assert len(record) == 1 + return True + + wait_for_condition(verify) + + +def test_raise_on_missing_output_truncation(monkeypatch, shutdown_only): + with monkeypatch.context() as m: + # defer for 10s for the second node. + m.setenv( + "RAY_MAX_LIMIT_FROM_DATA_SOURCE", + "10", + ) + ray.init() + + @ray.remote + def task(): + time.sleep(300) + + tasks = [task.remote() for _ in range(15)] # noqa + + runner = CliRunner() + + # Verify + def verify(): + # Verify when raise_on_missing_output=True, it raises an exception. + try: + list_tasks(_explain=True, timeout=3) + except RayStateApiException as e: + assert "Failed to retrieve all tasks from the cluster" in str(e) + assert "(> 10)" in str(e) else: assert False try: summarize_tasks(_explain=True, timeout=3) except RayStateApiException as e: - assert "Failed to retrieve all tasks from the cluster." in str(e) + assert "Failed to retrieve all tasks from the cluster" in str(e) + assert "(> 10)" in str(e) else: assert False @@ -2541,11 +2619,9 @@ def verify(): list_tasks(raise_on_missing_output=False, _explain=True, timeout=3) assert len(record) == 1 - # TODO(sang): Add warning after https://github.com/ray-project/ray/pull/26801 - # is merged. - # with pytest.warns(None) as record: - # summarize_tasks(raise_on_missing_output=False, _explain=True, timeout=3) - # assert len(record) == 1 + with pytest.warns(None) as record: + summarize_tasks(raise_on_missing_output=False, _explain=True, timeout=3) + assert len(record) == 1 # Verify when CLI is used, exceptions are not raised. with pytest.warns(None) as record: @@ -2553,13 +2629,11 @@ def verify(): assert len(record) == 1 assert result.exit_code == 0 - # TODO(sang): Add warning after https://github.com/ray-project/ray/pull/26801 - # is merged. # Verify summary CLI also doesn't raise an exception. - # with pytest.warns(None) as record: - # result = runner.invoke(task_summary, ["--timeout=3"]) - # assert result.exit_code == 0 - # assert len(record) == 1 + with pytest.warns(None) as record: + result = runner.invoke(summary_state_cli_group, ["tasks", "--timeout=3"]) + assert result.exit_code == 0 + assert len(record) == 1 return True wait_for_condition(verify) From e9dc51a3b83d10061caeab10192e25f70c7a8371 Mon Sep 17 00:00:00 2001 From: SangBin Cho Date: Wed, 27 Jul 2022 01:42:19 -0700 Subject: [PATCH 2/2] done --- python/ray/experimental/state/common.py | 15 +-------------- 1 file changed, 1 insertion(+), 14 deletions(-) diff --git a/python/ray/experimental/state/common.py b/python/ray/experimental/state/common.py index dc5b87f309d4a..fb4431f7a1322 100644 --- a/python/ray/experimental/state/common.py +++ b/python/ray/experimental/state/common.py @@ -175,20 +175,7 @@ class State(StateSchema): @classmethod def list_columns(cls) -> List[str]: - """Return a list of columns. - - The order of columns is defined by the order - of attributes from the dataclass. - - E.g., - @dataclass - class A(StateSchema): - a: str - b: str - c: str - - -> ["a", "b", "c"] - """ + """Return a list of columns.""" cols = [] for f in fields(cls): cols.append(f.name)