[Dashboard] Optimize and backpressure actor_head.py (ray-project#29580)
Signed-off-by: SangBin Cho <[email protected]>

This optimizes the actor head's CPU usage and guarantees stable API responses from the dashboard when a large number of actor events is published to drivers. The script below was used for testing, and with it I could reproduce the same level of delay as many_nodes_actor_test (250 nodes + 10k actors).
rkooo567 committed Nov 11, 2022
1 parent 27201bb commit 9da53e3
Showing 8 changed files with 154 additions and 390 deletions.
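The core of the change is in dashboard/modules/actor/actor_head.py: instead of handling one pub/sub message per poll, the actor head now polls actor events in batches and then yields control back to the event loop for roughly as long as the batch took to process, so user-facing dashboard endpoints stay responsive under a flood of actor events. Below is a minimal sketch of that backpressure pattern; the subscriber and handler names are placeholders, and the real implementation is in the diff that follows.

```python
import asyncio
import time


async def consume_with_backpressure(subscriber, handle_event, batch_size=200):
    """Minimal sketch of the batched poll + yield pattern in actor_head.py.

    `subscriber` stands in for GcsAioActorSubscriber and `handle_event`
    for the per-actor update handler; both names are placeholders.
    """
    total_events = 0
    total_processing_s = 0.0
    while True:
        # Drain up to `batch_size` published actor events in one poll.
        batch = await subscriber.poll(batch_size=batch_size)
        start = time.monotonic()
        for actor_id, actor_table_data in batch:
            if actor_id is not None:
                handle_event(actor_id.hex(), actor_table_data)
        elapsed = time.monotonic() - start

        # Backpressure: yield for as long as we just worked, so this loop
        # takes at most ~half of the event loop's time and user-facing
        # HTTP handlers can reply to the frontend in between batches.
        await asyncio.sleep(elapsed)

        # Rough throughput bookkeeping, mirroring the debug metrics
        # added by this commit.
        total_events += len(batch)
        total_processing_s += elapsed
        if total_processing_s > 0:
            print(f"throughput: {total_events / total_processing_s:.1f} events/s")
```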
67 changes: 8 additions & 59 deletions dashboard/datacenter.py
@@ -3,9 +3,12 @@
import ray.dashboard.consts as dashboard_consts
import ray.dashboard.memory_utils as memory_utils

# TODO(fyrestone): Not import from dashboard module.
from ray.dashboard.modules.actor.actor_utils import actor_classname_from_task_spec
from ray.dashboard.utils import Dict, Signal, async_loop_forever
from ray.dashboard.utils import (
Dict,
Signal,
async_loop_forever,
MutableNotificationDict,
)

logger = logging.getLogger(__name__)

@@ -25,7 +28,7 @@ class DataSource:
node_physical_stats = Dict()
# {actor id hex(str): actor table data(dict of ActorTableData
# in gcs.proto)}
actors = Dict()
actors = MutableNotificationDict()
# {job id hex(str): job table data(dict of JobTableData in gcs.proto)}
jobs = Dict()
# {node id hex(str): dashboard agent [http port(int), grpc port(int)]}
@@ -39,11 +42,7 @@ class DataSource:
# {node id hex(str): worker list}
node_workers = Dict()
# {node id hex(str): {actor id hex(str): actor table data}}
node_actors = Dict()
# {job id hex(str): worker list}
job_workers = Dict()
# {job id hex(str): {actor id hex(str): actor table data}}
job_actors = Dict()
node_actors = MutableNotificationDict()
# {worker id(str): core worker stats}
core_worker_stats = Dict()
# {job id hex(str): {event id(str): event dict}}
@@ -81,20 +80,16 @@ async def purge():
@classmethod
@async_loop_forever(dashboard_consts.ORGANIZE_DATA_INTERVAL_SECONDS)
async def organize(cls):
job_workers = {}
node_workers = {}
core_worker_stats = {}
# await inside for loop, so we create a copy of keys().
for node_id in list(DataSource.nodes.keys()):
workers = await cls.get_node_workers(node_id)
for worker in workers:
job_id = worker["jobId"]
job_workers.setdefault(job_id, []).append(worker)
for stats in worker.get("coreWorkerStats", []):
worker_id = stats["workerId"]
core_worker_stats[worker_id] = stats
node_workers[node_id] = workers
DataSource.job_workers.reset(job_workers)
DataSource.node_workers.reset(node_workers)
DataSource.core_worker_stats.reset(core_worker_stats)

@@ -281,52 +276,6 @@ async def _get_actor(actor):
actor["processStats"] = actor_process_stats
return actor

@classmethod
async def get_actor_creation_tasks(cls):
# Collect infeasible tasks in worker nodes.
infeasible_tasks = sum(
(
list(node_stats.get("infeasibleTasks", []))
for node_stats in DataSource.node_stats.values()
),
[],
)
# Collect infeasible actor creation tasks in gcs.
infeasible_tasks.extend(
list(DataSource.gcs_scheduling_stats.get("infeasibleTasks", []))
)
new_infeasible_tasks = []
for task in infeasible_tasks:
task = dict(task)
task["actorClass"] = actor_classname_from_task_spec(task)
task["state"] = "INFEASIBLE"
new_infeasible_tasks.append(task)

# Collect pending tasks in worker nodes.
resource_pending_tasks = sum(
(
list(data.get("readyTasks", []))
for data in DataSource.node_stats.values()
),
[],
)
# Collect pending actor creation tasks in gcs.
resource_pending_tasks.extend(
list(DataSource.gcs_scheduling_stats.get("readyTasks", []))
)
new_resource_pending_tasks = []
for task in resource_pending_tasks:
task = dict(task)
task["actorClass"] = actor_classname_from_task_spec(task)
task["state"] = "PENDING_RESOURCES"
new_resource_pending_tasks.append(task)

results = {
task["actorCreationTaskSpec"]["actorId"]: task
for task in new_resource_pending_tasks + new_infeasible_tasks
}
return results

@classmethod
async def get_memory_table(
cls,
158 changes: 55 additions & 103 deletions dashboard/modules/actor/actor_head.py
@@ -1,37 +1,28 @@
import asyncio
from collections import deque
import logging
import os
import time

from collections import deque

import aiohttp.web

import ray._private.ray_constants as ray_constants
import ray._private.utils
import ray.dashboard.optional_utils as dashboard_optional_utils
import ray.dashboard.utils as dashboard_utils
from ray._private.gcs_pubsub import GcsAioActorSubscriber
from ray.core.generated import (
core_worker_pb2,
core_worker_pb2_grpc,
gcs_service_pb2,
gcs_service_pb2_grpc,
node_manager_pb2_grpc,
)
from ray.dashboard.datacenter import DataOrganizer, DataSource
from ray.dashboard.modules.actor import actor_consts, actor_utils
from ray.dashboard.modules.actor.actor_utils import actor_classname_from_func_descriptor
from ray.dashboard.datacenter import DataSource
from ray.dashboard.modules.actor import actor_consts
from ray.dashboard.optional_utils import rest_response

try:
from grpc import aio as aiogrpc
except ImportError:
from grpc.experimental import aio as aiogrpc

logger = logging.getLogger(__name__)
routes = dashboard_optional_utils.ClassMethodRouteTable

MAX_ACTORS_TO_CACHE = int(os.environ.get("RAY_DASHBOARD_MAX_ACTORS_TO_CACHE", 1000))
ACTOR_CLEANUP_FREQUENCY = 10 # seconds
ACTOR_CLEANUP_FREQUENCY = 1 # seconds


def actor_table_data_to_dict(message):
@@ -43,7 +34,6 @@ def actor_table_data_to_dict(message):
"jobId",
"workerId",
"rayletId",
"actorCreationDummyObjectId",
"callerId",
"taskId",
"parentTaskId",
@@ -64,45 +54,26 @@ def actor_table_data_to_dict(message):
"state",
"name",
"numRestarts",
"functionDescriptor",
"timestamp",
"numExecutedTasks",
"className",
}
light_message = {k: v for (k, v) in orig_message.items() if k in fields}
if "functionDescriptor" in light_message:
actor_class = actor_classname_from_func_descriptor(
light_message["functionDescriptor"]
)
light_message["actorClass"] = actor_class
light_message["actorClass"] = orig_message["className"]
return light_message


class ActorHead(dashboard_utils.DashboardHeadModule):
def __init__(self, dashboard_head):
super().__init__(dashboard_head)
self._stubs = {}
# ActorInfoGcsService
self._gcs_actor_info_stub = None
# A queue of dead actors in order of when they died
self.dead_actors_queue = deque()
DataSource.nodes.signal.append(self._update_stubs)

async def _update_stubs(self, change):
if change.old:
node_id, node_info = change.old
self._stubs.pop(node_id)
if change.new:
# TODO(fyrestone): Handle exceptions.
node_id, node_info = change.new
address = "{}:{}".format(
node_info["nodeManagerAddress"], int(node_info["nodeManagerPort"])
)
options = ray_constants.GLOBAL_GRPC_OPTIONS
channel = ray._private.utils.init_grpc_channel(
address, options, asynchronous=True
)
stub = node_manager_pb2_grpc.NodeManagerServiceStub(channel)
self._stubs[node_id] = stub
# -- Internal states --
self.total_published_events = 0
self.subscriber_queue_size = 0
self.accumulative_event_processing_s = 0

async def _update_actors(self):
# Get all actor info.
@@ -121,18 +92,14 @@ async def _update_actors(self):
# Update actors.
DataSource.actors.reset(actors)
# Update node actors and job actors.
job_actors = {}
node_actors = {}
for actor_id, actor_table_data in actors.items():
job_id = actor_table_data["jobId"]
node_id = actor_table_data["address"]["rayletId"]
job_actors.setdefault(job_id, {})[actor_id] = actor_table_data
# Update only when node_id is not Nil.
if node_id != actor_consts.NIL_NODE_ID:
node_actors.setdefault(node_id, {})[
actor_id
] = actor_table_data
DataSource.job_actors.reset(job_actors)
DataSource.node_actors.reset(node_actors)
logger.info("Received %d actor info from GCS.", len(actors))
break
@@ -153,26 +120,21 @@ def process_actor_data_from_pubsub(actor_id, actor_table_data):
# If actor is not new registered but updated, we only update
# states related fields.
if actor_table_data["state"] != "DEPENDENCIES_UNREADY":
actor_table_data_copy = dict(DataSource.actors[actor_id])
actors = DataSource.actors[actor_id]
for k in state_keys:
actor_table_data_copy[k] = actor_table_data[k]
actor_table_data = actor_table_data_copy
actors[k] = actor_table_data[k]
actor_table_data = actors
actor_id = actor_table_data["actorId"]
job_id = actor_table_data["jobId"]
node_id = actor_table_data["address"]["rayletId"]
if actor_table_data["state"] == "DEAD":
self.dead_actors_queue.append(actor_id)
# Update actors.
DataSource.actors[actor_id] = actor_table_data
# Update node actors (only when node_id is not Nil).
if node_id != actor_consts.NIL_NODE_ID:
node_actors = dict(DataSource.node_actors.get(node_id, {}))
node_actors = DataSource.node_actors.get(node_id, {})
node_actors[actor_id] = actor_table_data
DataSource.node_actors[node_id] = node_actors
# Update job actors.
job_actors = dict(DataSource.job_actors.get(job_id, {}))
job_actors[actor_id] = actor_table_data
DataSource.job_actors[job_id] = job_actors

# Receive actors from channel.
gcs_addr = self._dashboard_head.gcs_address
@@ -181,11 +143,32 @@ def process_actor_data_from_pubsub(actor_id, actor_table_data):

while True:
try:
actor_id, actor_table_data = await subscriber.poll()
if actor_id is not None:
# Convert to lower case hex ID.
actor_id = actor_id.hex()
process_actor_data_from_pubsub(actor_id, actor_table_data)
published = await subscriber.poll(batch_size=200)
start = time.monotonic()
for actor_id, actor_table_data in published:
if actor_id is not None:
# Convert to lower case hex ID.
actor_id = actor_id.hex()
process_actor_data_from_pubsub(actor_id, actor_table_data)

# Yield so that we can give time for
# user-facing APIs to reply to the frontend.
elapsed = time.monotonic() - start
await asyncio.sleep(elapsed)

# Update the internal states for debugging.
self.accumulative_event_processing_s += elapsed
self.total_published_events += len(published)
self.subscriber_queue_size = subscriber.queue_size
logger.debug(
f"Processing takes {elapsed}. Total process: " f"{len(published)}"
)
logger.debug(
"Processing throughput: "
f"{self.total_published_events / self.accumulative_event_processing_s}" # noqa
" / s"
)
logger.debug(f"queue size: {self.subscriber_queue_size}")
except Exception:
logger.exception("Error processing actor info from GCS.")

@@ -204,27 +187,25 @@ async def _cleanup_actors(self):
actor_id = self.dead_actors_queue.popleft()
if actor_id in DataSource.actors:
actor = DataSource.actors.pop(actor_id)
job_id = actor["jobId"]
del DataSource.job_actors[job_id][actor_id]
node_id = actor["address"].get("rayletId")
if node_id:
if node_id and node_id != actor_consts.NIL_NODE_ID:
del DataSource.node_actors[node_id][actor_id]
await asyncio.sleep(ACTOR_CLEANUP_FREQUENCY)
except Exception:
logger.exception("Error cleaning up actor info from GCS.")

@routes.get("/logical/actor_groups")
async def get_actor_groups(self, req) -> aiohttp.web.Response:
actors = await DataOrganizer.get_all_actors()
actor_creation_tasks = await DataOrganizer.get_actor_creation_tasks()
# actor_creation_tasks have some common interface with actors,
# and they get processed and shown in tandem in the logical view
# hence we merge them together before constructing actor groups.
actors.update(actor_creation_tasks)
actor_groups = actor_utils.construct_actor_groups(actors)
return rest_response(
success=True, message="Fetched actor groups.", actor_groups=actor_groups
)
def get_internal_states(self):
states = {
"total_published_events": self.total_published_events,
"total_dead_actors": len(self.dead_actors_queue),
"total_actors": len(DataSource.actors),
"queue_size": self.subscriber_queue_size,
}
if self.accumulative_event_processing_s > 0:
states["event_processing_per_s"] = (
self.total_published_events / self.accumulative_event_processing_s
)
return states

@routes.get("/logical/actors")
@dashboard_optional_utils.aiohttp_cache
@@ -233,35 +214,6 @@ async def get_all_actors(self, req) -> aiohttp.web.Response:
success=True, message="All actors fetched.", actors=DataSource.actors
)

@routes.get("/logical/kill_actor")
async def kill_actor(self, req) -> aiohttp.web.Response:
try:
actor_id = req.query["actorId"]
ip_address = req.query["ipAddress"]
port = req.query["port"]
except KeyError:
return rest_response(success=False, message="Bad Request")
try:
options = ray_constants.GLOBAL_GRPC_OPTIONS
channel = ray._private.utils.init_grpc_channel(
f"{ip_address}:{port}", options=options, asynchronous=True
)
stub = core_worker_pb2_grpc.CoreWorkerServiceStub(channel)

await stub.KillActor(
core_worker_pb2.KillActorRequest(
intended_actor_id=ray._private.utils.hex_to_binary(actor_id)
)
)

except aiogrpc.AioRpcError:
# This always throws an exception because the worker
# is killed and the channel is closed on the worker side
# before this handler, however it deletes the actor correctly.
pass

return rest_response(success=True, message=f"Killed actor with id {actor_id}")

async def run(self, server):
gcs_channel = self._dashboard_head.aiogrpc_gcs_channel
self._gcs_actor_info_stub = gcs_service_pb2_grpc.ActorInfoGcsServiceStub(
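On the cleanup side, the commit drops the per-job bookkeeping (job_workers, job_actors) and tightens dead-actor eviction: dead actors are queued in the order they die and evicted once the cache grows past MAX_ACTORS_TO_CACHE, with the cleanup loop now waking every 1 second instead of 10. A rough sketch of that bounded-cache eviction follows, using illustrative container names rather than the real DataSource dictionaries.

```python
import asyncio
from collections import deque

MAX_ACTORS_TO_CACHE = 1000   # overridable via RAY_DASHBOARD_MAX_ACTORS_TO_CACHE
ACTOR_CLEANUP_FREQUENCY = 1  # seconds (lowered from 10 in this commit)


async def cleanup_dead_actors(actors: dict, dead_actors_queue: deque):
    """Sketch of bounded dead-actor eviction in the dashboard's actor head.

    `actors` and `dead_actors_queue` are stand-ins for DataSource.actors
    and ActorHead.dead_actors_queue; the cap/frequency values mirror the
    constants in actor_head.py.
    """
    while True:
        # Evict the oldest dead actors until the cache fits under the cap.
        while len(actors) > MAX_ACTORS_TO_CACHE and dead_actors_queue:
            actor_id = dead_actors_queue.popleft()
            actors.pop(actor_id, None)
        await asyncio.sleep(ACTOR_CLEANUP_FREQUENCY)
```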