diff --git a/dashboard/__init__.py b/dashboard/__init__.py
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/dashboard/agent.py b/dashboard/agent.py
new file mode 100644
index 0000000000000..818770f45ca35
--- /dev/null
+++ b/dashboard/agent.py
@@ -0,0 +1,229 @@
+import argparse
+import asyncio
+import logging
+import logging.handlers
+import os
+import sys
+import traceback
+
+import aiohttp
+import aioredis
+from grpc.experimental import aio as aiogrpc
+
+import ray
+import ray.new_dashboard.consts as dashboard_consts
+import ray.new_dashboard.utils as dashboard_utils
+import ray.ray_constants as ray_constants
+import ray.services
+import ray.utils
+import psutil
+
+logger = logging.getLogger(__name__)
+
+aiogrpc.init_grpc_aio()
+
+
+class DashboardAgent(object):
+    def __init__(self,
+                 redis_address,
+                 redis_password=None,
+                 temp_dir=None,
+                 log_dir=None,
+                 node_manager_port=None,
+                 object_store_name=None,
+                 raylet_name=None):
+        """Initialize the DashboardAgent object."""
+        self._agent_cls_list = dashboard_utils.get_all_modules(
+            dashboard_utils.DashboardAgentModule)
+        ip, port = redis_address.split(":")
+        # Public attributes are accessible to all agent modules.
+        self.redis_address = (ip, int(port))
+        self.redis_password = redis_password
+        self.temp_dir = temp_dir
+        self.log_dir = log_dir
+        self.node_manager_port = node_manager_port
+        self.object_store_name = object_store_name
+        self.raylet_name = raylet_name
+        self.ip = ray.services.get_node_ip_address()
+        self.server = aiogrpc.server(options=(("grpc.so_reuseport", 0), ))
+        listen_address = "[::]:0"
+        logger.info("Dashboard agent listening at: %s", listen_address)
+        self.port = self.server.add_insecure_port(listen_address)
+        self.aioredis_client = None
+        self.aiogrpc_raylet_channel = aiogrpc.insecure_channel("{}:{}".format(
+            self.ip, self.node_manager_port))
+        self.http_session = aiohttp.ClientSession(
+            loop=asyncio.get_event_loop())
+
+    def _load_modules(self):
+        """Load dashboard agent modules."""
+        modules = []
+        for cls in self._agent_cls_list:
+            logger.info("Loading %s: %s",
+                        dashboard_utils.DashboardAgentModule.__name__, cls)
+            c = cls(self)
+            modules.append(c)
+        logger.info("Loaded %d modules.", len(modules))
+        return modules
+
+    async def run(self):
+        # Create an aioredis client for all modules.
+        self.aioredis_client = await aioredis.create_redis_pool(
+            address=self.redis_address, password=self.redis_password)
+
+        # Start a grpc asyncio server.
+        await self.server.start()
+
+        # Write the dashboard agent port to Redis.
+        await self.aioredis_client.set(
+            "{}{}".format(dashboard_consts.DASHBOARD_AGENT_PORT_PREFIX,
+                          self.ip), self.port)
+
+        async def _check_parent():
+            """Check if the raylet is dead."""
+            curr_proc = psutil.Process()
+            while True:
+                parent = curr_proc.parent()
+                if parent is None or parent.pid == 1:
+                    logger.error("Raylet is dead; the agent will exit "
+                                 "because it fate-shares with the raylet.")
+                    sys.exit(0)
+                await asyncio.sleep(
+                    dashboard_consts.
+                    DASHBOARD_AGENT_CHECK_PARENT_INTERVAL_SECONDS)
+
+        modules = self._load_modules()
+        await asyncio.gather(_check_parent(),
+                             *(m.run(self.server) for m in modules))
+        await self.server.wait_for_termination()
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Dashboard agent.")
+    parser.add_argument(
+        "--redis-address",
+        required=True,
+        type=str,
+        help="The address to use for Redis.")
+    parser.add_argument(
+        "--node-manager-port",
+        required=True,
+        type=int,
+        help="The port to use for starting the node manager.")
+    parser.add_argument(
+        "--object-store-name",
+        required=True,
+        type=str,
+        default=None,
+        help="The socket name of the plasma store.")
+    parser.add_argument(
+        "--raylet-name",
+        required=True,
+        type=str,
+        default=None,
+        help="The socket path of the raylet process.")
+    parser.add_argument(
+        "--redis-password",
+        required=False,
+        type=str,
+        default=None,
+        help="The password to use for Redis.")
+    parser.add_argument(
+        "--logging-level",
+        required=False,
+        type=lambda s: logging.getLevelName(s.upper()),
+        default=ray_constants.LOGGER_LEVEL,
+        choices=ray_constants.LOGGER_LEVEL_CHOICES,
+        help=ray_constants.LOGGER_LEVEL_HELP)
+    parser.add_argument(
+        "--logging-format",
+        required=False,
+        type=str,
+        default=ray_constants.LOGGER_FORMAT,
+        help=ray_constants.LOGGER_FORMAT_HELP)
+    parser.add_argument(
+        "--logging-filename",
+        required=False,
+        type=str,
+        default=dashboard_consts.DASHBOARD_AGENT_LOG_FILENAME,
+        help="Specify the name of the log file; logs go to stdout if this "
+        "is empty. Default is \"{}\".".format(
+            dashboard_consts.DASHBOARD_AGENT_LOG_FILENAME))
+    parser.add_argument(
+        "--logging-rotate-bytes",
+        required=False,
+        type=int,
+        default=dashboard_consts.LOGGING_ROTATE_BYTES,
+        help="Specify the max bytes before rotating the log file; "
+        "default is {} bytes.".format(dashboard_consts.LOGGING_ROTATE_BYTES))
+    parser.add_argument(
+        "--logging-rotate-backup-count",
+        required=False,
+        type=int,
+        default=dashboard_consts.LOGGING_ROTATE_BACKUP_COUNT,
+        help="Specify the backup count of rotated log files; default is {}.".
+        format(dashboard_consts.LOGGING_ROTATE_BACKUP_COUNT))
+    parser.add_argument(
+        "--log-dir",
+        required=False,
+        type=str,
+        default=None,
+        help="Specify the path of the log directory.")
+    parser.add_argument(
+        "--temp-dir",
+        required=False,
+        type=str,
+        default=None,
+        help="Specify the temporary directory used by Ray processes.")
+
+    args = parser.parse_args()
+    try:
+        if args.temp_dir:
+            temp_dir = "/" + args.temp_dir.strip("/")
+        else:
+            temp_dir = "/tmp/ray"
+        os.makedirs(temp_dir, exist_ok=True)
+
+        if args.log_dir:
+            log_dir = args.log_dir
+        else:
+            log_dir = os.path.join(temp_dir, "session_latest/logs")
+        os.makedirs(log_dir, exist_ok=True)
+
+        if args.logging_filename:
+            logging_handlers = [
+                logging.handlers.RotatingFileHandler(
+                    os.path.join(log_dir, args.logging_filename),
+                    maxBytes=args.logging_rotate_bytes,
+                    backupCount=args.logging_rotate_backup_count)
+            ]
+        else:
+            logging_handlers = None
+        logging.basicConfig(
+            level=args.logging_level,
+            format=args.logging_format,
+            handlers=logging_handlers)
+
+        agent = DashboardAgent(
+            args.redis_address,
+            redis_password=args.redis_password,
+            temp_dir=temp_dir,
+            log_dir=log_dir,
+            node_manager_port=args.node_manager_port,
+            object_store_name=args.object_store_name,
+            raylet_name=args.raylet_name)
+
+        loop = asyncio.get_event_loop()
+        loop.create_task(agent.run())
+        loop.run_forever()
+    except Exception as e:
+        # Something went wrong, so push an error to all drivers.
+        redis_client = ray.services.create_redis_client(
+            args.redis_address, password=args.redis_password)
+        traceback_str = ray.utils.format_error_message(traceback.format_exc())
+        message = ("The agent on node {} failed with the following "
+                   "error:\n{}".format(os.uname()[1], traceback_str))
+        ray.utils.push_error_to_driver_through_redis(
+            redis_client, ray_constants.DASHBOARD_AGENT_DIED_ERROR, message)
+        raise e
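Note: for orientation, here is a minimal sketch of what a per-node agent module looks like under the interfaces this file introduces (`DashboardAgentModule` subclasses are auto-discovered by `get_all_modules` and their `run()` coroutines are gathered alongside `_check_parent`). The module name and its behavior are hypothetical, not part of this change:

```python
# A minimal sketch, assuming the DashboardAgentModule interface defined
# in dashboard/utils.py below. HeartbeatAgentModule is a made-up example.
import asyncio

import ray.new_dashboard.utils as dashboard_utils


class HeartbeatAgentModule(dashboard_utils.DashboardAgentModule):
    """Discovered automatically because it subclasses DashboardAgentModule."""

    async def run(self, server):
        # `server` is the agent's shared grpc.experimental.aio server; a
        # real module could register its gRPC servicers on it here.
        while True:
            # Public agent attributes (ip, redis_address, log_dir, ...)
            # are available via self._dashboard_agent.
            await asyncio.sleep(30)
```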
diff --git a/dashboard/consts.py b/dashboard/consts.py
new file mode 100644
index 0000000000000..a03cb2a909d17
--- /dev/null
+++ b/dashboard/consts.py
@@ -0,0 +1,16 @@
+DASHBOARD_AGENT_PORT_PREFIX = "DASHBOARD_AGENT_PORT_PREFIX:"
+DASHBOARD_AGENT_LOG_FILENAME = "dashboard_agent.log"
+DASHBOARD_AGENT_CHECK_PARENT_INTERVAL_SECONDS = 2
+MAX_COUNT_OF_GCS_RPC_ERROR = 10
+UPDATE_NODES_INTERVAL_SECONDS = 5
+CONNECT_GCS_INTERVAL_SECONDS = 2
+PURGE_DATA_INTERVAL_SECONDS = 60 * 10
+REDIS_KEY_DASHBOARD = "dashboard"
+REDIS_KEY_GCS_SERVER_ADDRESS = "GcsServerAddress"
+REPORT_METRICS_TIMEOUT_SECONDS = 2
+REPORT_METRICS_INTERVAL_SECONDS = 10
+# Named signals
+SIGNAL_NODE_INFO_FETCHED = "node_info_fetched"
+# Default params for RotatingFileHandler
+LOGGING_ROTATE_BYTES = 100 * 1000  # maxBytes
+LOGGING_ROTATE_BACKUP_COUNT = 5  # backupCount
diff --git a/dashboard/dashboard.py b/dashboard/dashboard.py
new file mode 100644
index 0000000000000..10705427ff8cf
--- /dev/null
+++ b/dashboard/dashboard.py
@@ -0,0 +1,240 @@
+try:
+    import aiohttp.web
+except ImportError:
+    print("The dashboard requires aiohttp to run.")
+    import sys
+
+    sys.exit(1)
+
+import argparse
+import asyncio
+import errno
+import logging
+import logging.handlers
+import os
+import traceback
+import uuid
+
+import aioredis
+
+import ray
+import ray.new_dashboard.consts as dashboard_consts
+import ray.new_dashboard.head as dashboard_head
+import ray.new_dashboard.utils as dashboard_utils
+import ray.ray_constants as ray_constants
+import ray.services
+import ray.utils
+
+# Logger for this module. It should be configured at the entry point
+# into the program using Ray. Ray provides a default configuration at
+# entry/init points.
+logger = logging.getLogger(__name__)
+routes = dashboard_utils.ClassMethodRouteTable
+
+
+def setup_static_dir(app):
+    build_dir = os.path.join(
+        os.path.dirname(os.path.abspath(__file__)), "client/build")
+    module_name = os.path.basename(os.path.dirname(__file__))
+    if not os.path.isdir(build_dir):
+        raise OSError(
+            errno.ENOENT, "Dashboard build directory not found. If "
+            "installing from source, please follow the additional steps "
+            "required to build the dashboard "
+            "(cd python/ray/{}/client "
+            "&& npm install "
+            "&& npm ci "
+            "&& npm run build)".format(module_name), build_dir)
+
+    static_dir = os.path.join(build_dir, "static")
+    app.router.add_static("/static", static_dir, follow_symlinks=True)
+    return build_dir
+
+
+class Dashboard:
+    """A dashboard process for monitoring Ray nodes.
+
+    This dashboard is made up of a REST API that collates data published
+    by Reporter processes on nodes into a JSON structure, and a webserver
+    that polls said API for display purposes.
+
+    Args:
+        host (str): Host address of the dashboard aiohttp server.
+        port (int): Port number of the dashboard aiohttp server.
+        redis_address (str): GCS address of the Ray cluster.
+        temp_dir (str): The temporary directory used for log files and
+            information for this Ray session.
+        redis_password (str): Redis password to access the GCS.
+    """
+
+    def __init__(self,
+                 host,
+                 port,
+                 redis_address,
+                 temp_dir,
+                 redis_password=None):
+        self.host = host
+        self.port = port
+        self.temp_dir = temp_dir
+        self.dashboard_id = str(uuid.uuid4())
+        self.dashboard_head = dashboard_head.DashboardHead(
+            redis_address=redis_address, redis_password=redis_password)
+
+        self.app = aiohttp.web.Application()
+        self.app.add_routes(routes=routes.routes())
+
+        # Setup the dashboard routes.
+        build_dir = setup_static_dir(self.app)
+        logger.info("Setup static dir for dashboard: %s", build_dir)
+        dashboard_utils.ClassMethodRouteTable.bind(self)
+
+    @routes.get("/")
+    async def get_index(self, req) -> aiohttp.web.FileResponse:
+        return aiohttp.web.FileResponse(
+            os.path.join(
+                os.path.dirname(os.path.abspath(__file__)),
+                "client/build/index.html"))
+
+    @routes.get("/favicon.ico")
+    async def get_favicon(self, req) -> aiohttp.web.FileResponse:
+        return aiohttp.web.FileResponse(
+            os.path.join(
+                os.path.dirname(os.path.abspath(__file__)),
+                "client/build/favicon.ico"))
+
+    async def run(self):
+        coroutines = [
+            self.dashboard_head.run(),
+            aiohttp.web._run_app(self.app, host=self.host, port=self.port)
+        ]
+        ip = ray.services.get_node_ip_address()
+        aioredis_client = await aioredis.create_redis_pool(
+            address=self.dashboard_head.redis_address,
+            password=self.dashboard_head.redis_password)
+        await aioredis_client.set(dashboard_consts.REDIS_KEY_DASHBOARD,
+                                  ip + ":" + str(self.port))
+        await asyncio.gather(*coroutines)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(
+        description=("Parse the Redis address for the "
+                     "dashboard to connect to."))
+    parser.add_argument(
+        "--host",
+        required=True,
+        type=str,
+        help="The host to use for the HTTP server.")
+    parser.add_argument(
+        "--port",
+        required=True,
+        type=int,
+        help="The port to use for the HTTP server.")
+    parser.add_argument(
+        "--redis-address",
+        required=True,
+        type=str,
+        help="The address to use for Redis.")
+    parser.add_argument(
+        "--redis-password",
+        required=False,
+        type=str,
+        default=None,
+        help="The password to use for Redis.")
+    parser.add_argument(
+        "--logging-level",
+        required=False,
+        type=lambda s: logging.getLevelName(s.upper()),
+        default=ray_constants.LOGGER_LEVEL,
+        choices=ray_constants.LOGGER_LEVEL_CHOICES,
+        help=ray_constants.LOGGER_LEVEL_HELP)
+    parser.add_argument(
+        "--logging-format",
+        required=False,
+        type=str,
+        default=ray_constants.LOGGER_FORMAT,
+        help=ray_constants.LOGGER_FORMAT_HELP)
+    parser.add_argument(
+        "--logging-filename",
+        required=False,
+        type=str,
+        default="",
+        help="Specify the name of the log file; logs go to stdout if this "
+        "is empty. Default is \"\".")
+    parser.add_argument(
+        "--logging-rotate-bytes",
+        required=False,
+        type=int,
+        default=dashboard_consts.LOGGING_ROTATE_BYTES,
+        help="Specify the max bytes before rotating the log file; "
+        "default is {} bytes.".format(dashboard_consts.LOGGING_ROTATE_BYTES))
+    parser.add_argument(
+        "--logging-rotate-backup-count",
+        required=False,
+        type=int,
+        default=dashboard_consts.LOGGING_ROTATE_BACKUP_COUNT,
+        help="Specify the backup count of rotated log files; default is {}.".
+        format(dashboard_consts.LOGGING_ROTATE_BACKUP_COUNT))
+    parser.add_argument(
+        "--log-dir",
+        required=False,
+        type=str,
+        default=None,
+        help="Specify the path of the log directory.")
+    parser.add_argument(
+        "--temp-dir",
+        required=False,
+        type=str,
+        default=None,
+        help="Specify the temporary directory used by Ray processes.")
+
+    args = parser.parse_args()
+    try:
+        if args.temp_dir:
+            temp_dir = "/" + args.temp_dir.strip("/")
+        else:
+            temp_dir = "/tmp/ray"
+        os.makedirs(temp_dir, exist_ok=True)
+
+        if args.log_dir:
+            log_dir = args.log_dir
+        else:
+            log_dir = os.path.join(temp_dir, "session_latest/logs")
+        os.makedirs(log_dir, exist_ok=True)
+
+        if args.logging_filename:
+            logging_handlers = [
+                logging.handlers.RotatingFileHandler(
+                    os.path.join(log_dir, args.logging_filename),
+                    maxBytes=args.logging_rotate_bytes,
+                    backupCount=args.logging_rotate_backup_count)
+            ]
+        else:
+            logging_handlers = None
+        logging.basicConfig(
+            level=args.logging_level,
+            format=args.logging_format,
+            handlers=logging_handlers)
+
+        dashboard = Dashboard(
+            args.host,
+            args.port,
+            args.redis_address,
+            temp_dir,
+            redis_password=args.redis_password)
+        loop = asyncio.get_event_loop()
+        loop.run_until_complete(dashboard.run())
+    except Exception as e:
+        # Something went wrong, so push an error to all drivers.
+        redis_client = ray.services.create_redis_client(
+            args.redis_address, password=args.redis_password)
+        traceback_str = ray.utils.format_error_message(traceback.format_exc())
+        message = ("The dashboard on node {} failed with the following "
+                   "error:\n{}".format(os.uname()[1], traceback_str))
+        ray.utils.push_error_to_driver_through_redis(
+            redis_client, ray_constants.DASHBOARD_DIED_ERROR, message)
+        if isinstance(e, OSError) and e.errno == errno.ENOENT:
+            logger.warning(message)
+        else:
+            raise e
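Note: `Dashboard.run()` above publishes its own address under `REDIS_KEY_DASHBOARD`, which is how other processes find it. A minimal sketch of the client side, assuming an aioredis 1.x pool and placeholder connection details:

```python
# A sketch of locating the dashboard via the key written in run() above.
# The Redis address/password here are placeholders.
import asyncio

import aioredis

import ray.new_dashboard.consts as dashboard_consts


async def get_dashboard_url(redis_address, redis_password=None):
    client = await aioredis.create_redis_pool(
        address=redis_address, password=redis_password)
    address = await client.get(dashboard_consts.REDIS_KEY_DASHBOARD)
    client.close()
    await client.wait_closed()
    # The value is stored as b"ip:port".
    return address.decode() if address else None


# e.g. asyncio.get_event_loop().run_until_complete(
#     get_dashboard_url(("127.0.0.1", 6379)))
```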
diff --git a/dashboard/datacenter.py b/dashboard/datacenter.py
new file mode 100644
index 0000000000000..65e3bb449628b
--- /dev/null
+++ b/dashboard/datacenter.py
@@ -0,0 +1,108 @@
+import logging
+
+import ray.new_dashboard.consts as dashboard_consts
+from ray.new_dashboard.utils import Dict, Signal
+
+logger = logging.getLogger(__name__)
+
+
+class GlobalSignals:
+    node_info_fetched = Signal(dashboard_consts.SIGNAL_NODE_INFO_FETCHED)
+
+
+class DataSource:
+    # {ip address(str): node stats(dict of GetNodeStatsReply
+    # in node_manager.proto)}
+    node_stats = Dict()
+    # {ip address(str): node physical stats(dict from reporter_agent.py)}
+    node_physical_stats = Dict()
+    # {actor id hex(str): actor table data(dict of ActorTableData
+    # in gcs.proto)}
+    actors = Dict()
+    # {ip address(str): dashboard agent grpc server port(int)}
+    agents = Dict()
+    # {ip address(str): gcs node info(dict of GcsNodeInfo in gcs.proto)}
+    nodes = Dict()
+    # {hostname(str): ip address(str)}
+    hostname_to_ip = Dict()
+    # {ip address(str): hostname(str)}
+    ip_to_hostname = Dict()
+
+
+class DataOrganizer:
+    @staticmethod
+    async def purge():
+        # Purge data that is out of date.
+        # These data sources are maintained by DashboardHead,
+        # so we do not need to purge them:
+        # * agents
+        # * nodes
+        # * hostname_to_ip
+        # * ip_to_hostname
+        logger.info("Purging data.")
+        valid_keys = DataSource.ip_to_hostname.keys()
+        for key in DataSource.node_stats.keys() - valid_keys:
+            DataSource.node_stats.pop(key)
+
+        for key in DataSource.node_physical_stats.keys() - valid_keys:
+            DataSource.node_physical_stats.pop(key)
+
+    @classmethod
+    async def get_node_actors(cls, hostname):
+        ip = DataSource.hostname_to_ip[hostname]
+        node_stats = DataSource.node_stats.get(ip, {})
+        node_worker_id_set = set()
+        for worker_stats in node_stats.get("workersStats", []):
+            node_worker_id_set.add(worker_stats["workerId"])
+        node_actors = {}
+        for actor_id, actor_table_data in DataSource.actors.items():
+            if actor_table_data["workerId"] in node_worker_id_set:
+                node_actors[actor_id] = actor_table_data
+        return node_actors
+
+    @classmethod
+    async def get_node_info(cls, hostname):
+        ip = DataSource.hostname_to_ip[hostname]
+        node_physical_stats = DataSource.node_physical_stats.get(ip, {})
+        node_stats = DataSource.node_stats.get(ip, {})
+
+        # Merge coreWorkerStats (node stats) into workers (node physical
+        # stats).
+        workers_stats = node_stats.pop("workersStats", {})
+        pid_to_worker_stats = {}
+        pid_to_language = {}
+        pid_to_job_id = {}
+        for stats in workers_stats:
+            d = pid_to_worker_stats.setdefault(stats["pid"], {}).setdefault(
+                stats["workerId"], stats["coreWorkerStats"])
+            d["workerId"] = stats["workerId"]
+            pid_to_language.setdefault(stats["pid"],
+                                       stats.get("language", "PYTHON"))
+            pid_to_job_id.setdefault(stats["pid"],
+                                     stats["coreWorkerStats"]["jobId"])
+
+        for worker in node_physical_stats.get("workers", []):
+            worker_stats = pid_to_worker_stats.get(worker["pid"], {})
+            worker["coreWorkerStats"] = list(worker_stats.values())
+            worker["language"] = pid_to_language.get(worker["pid"], "")
+            worker["jobId"] = pid_to_job_id.get(worker["pid"], "ffff")
+
+        # Merge node stats into node physical stats.
+        node_info = node_physical_stats
+        node_info["raylet"] = node_stats
+        node_info["actors"] = await cls.get_node_actors(hostname)
+        node_info["state"] = DataSource.nodes.get(ip, {}).get("state", "DEAD")
+
+        await GlobalSignals.node_info_fetched.send(node_info)
+
+        return node_info
+
+    @classmethod
+    async def get_all_node_summary(cls):
+        all_nodes_summary = []
+        for hostname in DataSource.hostname_to_ip.keys():
+            node_info = await cls.get_node_info(hostname)
+            node_info.pop("workers", None)
+            node_info["raylet"].pop("workersStats", None)
+            node_info["raylet"].pop("viewData", None)
+            all_nodes_summary.append(node_info)
+        return all_nodes_summary
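Note: `GlobalSignals.node_info_fetched` is an aiohttp-style signal, so consumers attach plain coroutines to it, mirroring how `reporter_head.py` below attaches `_update_stubs` to `DataSource.agents.signal`. A hedged sketch of a consumer (receivers must be appended before `SignalManager.freeze()` runs in `head.py`; the receiver itself is hypothetical):

```python
# A sketch of consuming the node_info_fetched signal defined above.
from ray.new_dashboard.datacenter import GlobalSignals


async def on_node_info(node_info):
    # Called each time DataOrganizer.get_node_info finishes assembling a
    # node's merged stats dict.
    print("node info fetched for", node_info.get("hostname"))


GlobalSignals.node_info_fetched.append(on_node_info)
```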
diff --git a/dashboard/head.py b/dashboard/head.py
new file mode 100644
index 0000000000000..cb7cb21304bfc
--- /dev/null
+++ b/dashboard/head.py
@@ -0,0 +1,170 @@
+import sys
+import asyncio
+import logging
+
+import aiohttp
+import aioredis
+from grpc.experimental import aio as aiogrpc
+
+import ray.services
+import ray.new_dashboard.consts as dashboard_consts
+import ray.new_dashboard.utils as dashboard_utils
+from ray.core.generated import gcs_service_pb2
+from ray.core.generated import gcs_service_pb2_grpc
+from ray.new_dashboard.datacenter import DataSource, DataOrganizer
+
+logger = logging.getLogger(__name__)
+routes = dashboard_utils.ClassMethodRouteTable
+
+aiogrpc.init_grpc_aio()
+
+
+def gcs_node_info_to_dict(message):
+    return dashboard_utils.message_to_dict(
+        message, {"nodeId"}, including_default_value_fields=True)
+
+
+class DashboardHead:
+    def __init__(self, redis_address, redis_password):
+        # Scan and import head modules for collecting http routes.
+        self._head_cls_list = dashboard_utils.get_all_modules(
+            dashboard_utils.DashboardHeadModule)
+        ip, port = redis_address.split(":")
+        # NodeInfoGcsService
+        self._gcs_node_info_stub = None
+        self._gcs_rpc_error_counter = 0
+        # Public attributes are accessible to all head modules.
+        self.redis_address = (ip, int(port))
+        self.redis_password = redis_password
+        self.aioredis_client = None
+        self.aiogrpc_gcs_channel = None
+        self.http_session = aiohttp.ClientSession(
+            loop=asyncio.get_event_loop())
+        self.ip = ray.services.get_node_ip_address()
+
+    async def _get_nodes(self):
+        """Read the client table.
+
+        Returns:
+            A list of information about the nodes in the cluster.
+        """
+        request = gcs_service_pb2.GetAllNodeInfoRequest()
+        reply = await self._gcs_node_info_stub.GetAllNodeInfo(
+            request, timeout=2)
+        if reply.status.code == 0:
+            results = []
+            node_id_set = set()
+            for node_info in reply.node_info_list:
+                if node_info.node_id in node_id_set:
+                    continue
+                node_id_set.add(node_info.node_id)
+                node_info_dict = gcs_node_info_to_dict(node_info)
+                results.append(node_info_dict)
+            return results
+        else:
+            logger.error("Failed to GetAllNodeInfo: %s", reply.status.message)
+
+    async def _update_nodes(self):
+        while True:
+            try:
+                nodes = await self._get_nodes()
+                self._gcs_rpc_error_counter = 0
+                node_ips = [node["nodeManagerAddress"] for node in nodes]
+                node_hostnames = [
+                    node["nodeManagerHostname"] for node in nodes
+                ]
+
+                agents = dict(DataSource.agents)
+                for node in nodes:
+                    node_ip = node["nodeManagerAddress"]
+                    if node_ip not in agents:
+                        key = "{}{}".format(
+                            dashboard_consts.DASHBOARD_AGENT_PORT_PREFIX,
+                            node_ip)
+                        agent_port = await self.aioredis_client.get(key)
+                        if agent_port:
+                            agents[node_ip] = agent_port
+                for ip in agents.keys() - set(node_ips):
+                    agents.pop(ip, None)
+
+                DataSource.agents.reset(agents)
+                DataSource.nodes.reset(dict(zip(node_ips, nodes)))
+                DataSource.hostname_to_ip.reset(
+                    dict(zip(node_hostnames, node_ips)))
+                DataSource.ip_to_hostname.reset(
+                    dict(zip(node_ips, node_hostnames)))
+            except aiogrpc.AioRpcError as ex:
+                logger.exception(ex)
+                self._gcs_rpc_error_counter += 1
+                if self._gcs_rpc_error_counter > \
+                        dashboard_consts.MAX_COUNT_OF_GCS_RPC_ERROR:
+                    logger.error(
+                        "Dashboard exiting because the GCS RPC error "
+                        "count %s exceeds %s",
+                        self._gcs_rpc_error_counter,
+                        dashboard_consts.MAX_COUNT_OF_GCS_RPC_ERROR)
+                    sys.exit(-1)
+            except Exception as ex:
+                logger.exception(ex)
+            finally:
+                await asyncio.sleep(
+                    dashboard_consts.UPDATE_NODES_INTERVAL_SECONDS)
+
+    def _load_modules(self):
+        """Load dashboard head modules."""
+        modules = []
+        for cls in self._head_cls_list:
+            logger.info("Loading %s: %s",
+                        dashboard_utils.DashboardHeadModule.__name__, cls)
+            c = cls(self)
+            dashboard_utils.ClassMethodRouteTable.bind(c)
+            modules.append(c)
+        return modules
+
+    async def run(self):
+        # Create an aioredis client for all modules.
+        self.aioredis_client = await aioredis.create_redis_pool(
+            address=self.redis_address, password=self.redis_password)
+        # Wait until the GCS is ready.
+        while True:
+            try:
+                gcs_address = await self.aioredis_client.get(
+                    dashboard_consts.REDIS_KEY_GCS_SERVER_ADDRESS)
+                if not gcs_address:
+                    raise Exception("GCS address not found.")
+                logger.info("Connecting to GCS at %s", gcs_address)
+                channel = aiogrpc.insecure_channel(gcs_address)
+            except Exception as ex:
+                logger.error("Failed to connect to GCS: %s, retrying...", ex)
+                await asyncio.sleep(
+                    dashboard_consts.CONNECT_GCS_INTERVAL_SECONDS)
+            else:
+                self.aiogrpc_gcs_channel = channel
+                break
+        # Create a NodeInfoGcsServiceStub.
+        self._gcs_node_info_stub = gcs_service_pb2_grpc.NodeInfoGcsServiceStub(
+            self.aiogrpc_gcs_channel)
+
+        async def _async_notify():
+            """Notify signals from the queue."""
+            while True:
+                co = await dashboard_utils.NotifyQueue.get()
+                try:
+                    await co
+                except Exception as e:
+                    logger.exception(e)
+
+        async def _purge_data():
+            """Purge data in the datacenter."""
+            while True:
+                await asyncio.sleep(
+                    dashboard_consts.PURGE_DATA_INTERVAL_SECONDS)
+                try:
+                    await DataOrganizer.purge()
+                except Exception as e:
+                    logger.exception(e)
+
+        modules = self._load_modules()
+        # Freeze the signals after all modules are loaded.
+        dashboard_utils.SignalManager.freeze()
+        await asyncio.gather(self._update_nodes(), _async_notify(),
+                             _purge_data(), *(m.run() for m in modules))
diff --git a/dashboard/modules/__init__.py b/dashboard/modules/__init__.py
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/dashboard/modules/reporter/__init__.py b/dashboard/modules/reporter/__init__.py
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/dashboard/modules/reporter/reporter_agent.py b/dashboard/modules/reporter/reporter_agent.py
new file mode 100644
index 0000000000000..40d2d24910641
--- /dev/null
+++ b/dashboard/modules/reporter/reporter_agent.py
@@ -0,0 +1,199 @@
+import asyncio
+import datetime
+import json
+import logging
+import os
+import socket
+import subprocess
+import sys
+
+import aioredis
+
+import ray
+import ray.gcs_utils
+import ray.new_dashboard.modules.reporter.reporter_consts as reporter_consts
+import ray.new_dashboard.utils as dashboard_utils
+import ray.services
+import ray.utils
+from ray.core.generated import reporter_pb2
+from ray.core.generated import reporter_pb2_grpc
+import psutil
+
+logger = logging.getLogger(__name__)
+
+
+def recursive_asdict(o):
+    if isinstance(o, tuple) and hasattr(o, "_asdict"):
+        return recursive_asdict(o._asdict())
+
+    if isinstance(o, (tuple, list)):
+        L = []
+        for k in o:
+            L.append(recursive_asdict(k))
+        return L
+
+    if isinstance(o, dict):
+        D = {k: recursive_asdict(v) for k, v in o.items()}
+        return D
+
+    return o
+
+
+def jsonify_asdict(o):
+    return json.dumps(dashboard_utils.to_google_style(recursive_asdict(o)))
+
+
+class ReporterAgent(dashboard_utils.DashboardAgentModule,
+                    reporter_pb2_grpc.ReporterServiceServicer):
+    """An agent module that monitors a Ray node and reports its stats.
+
+    Attributes:
+        dashboard_agent: The DashboardAgent object that contains the
+            global config.
+    """
+
+    def __init__(self, dashboard_agent):
+        """Initialize the reporter object."""
+        super().__init__(dashboard_agent)
+        self._cpu_counts = (psutil.cpu_count(),
+                            psutil.cpu_count(logical=False))
+        self._ip = ray.services.get_node_ip_address()
+        self._hostname = socket.gethostname()
+        self._workers = set()
+        self._network_stats_hist = [(0, (0.0, 0.0))]  # time, (sent, recv)
+
+    async def GetProfilingStats(self, request, context):
+        pid = request.pid
+        duration = request.duration
+        profiling_file_path = os.path.join(ray.utils.get_ray_temp_dir(),
+                                           "{}_profiling.txt".format(pid))
+        process = subprocess.Popen(
+            "sudo $(which py-spy) record -o {} -p {} -d {} -f speedscope"
+            .format(profiling_file_path, pid, duration),
+            stdout=subprocess.PIPE,
+            stderr=subprocess.PIPE,
+            shell=True)
+        stdout, stderr = process.communicate()
+        if process.returncode != 0:
+            profiling_stats = ""
+        else:
+            with open(profiling_file_path, "r") as f:
+                profiling_stats = f.read()
+        return reporter_pb2.GetProfilingStatsReply(
+            profiling_stats=profiling_stats, stdout=stdout, stderr=stderr)
+
+    @staticmethod
+    def _get_cpu_percent():
+        return psutil.cpu_percent()
+
+    @staticmethod
+    def _get_boot_time():
+        return psutil.boot_time()
+
+    @staticmethod
+    def _get_network_stats():
+        # Only count ethernet-like interfaces (names starting with "e").
+        ifaces = [
+            v for k, v in psutil.net_io_counters(pernic=True).items()
+            if k[0] == "e"
+        ]
+
+        sent = sum((iface.bytes_sent for iface in ifaces))
+        recv = sum((iface.bytes_recv for iface in ifaces))
+        return sent, recv
+
+    @staticmethod
+    def _get_mem_usage():
+        vm = psutil.virtual_memory()
+        return vm.total, vm.available, vm.percent
+
+    @staticmethod
+    def _get_disk_usage():
+        dirs = [
+            os.environ["USERPROFILE"] if sys.platform == "win32" else os.sep,
+            ray.utils.get_user_temp_dir(),
+        ]
+        return {x: psutil.disk_usage(x) for x in dirs}
+
+    def _get_workers(self):
+        curr_proc = psutil.Process()
+        parent = curr_proc.parent()
+        if parent is None or parent.pid == 1:
+            return []
+        else:
+            workers = set(parent.children())
+            self._workers.intersection_update(workers)
+            self._workers.update(workers)
+            self._workers.discard(curr_proc)
+            return [
+                w.as_dict(attrs=[
+                    "pid",
+                    "create_time",
+                    "cpu_percent",
+                    "cpu_times",
+                    "cmdline",
+                    "memory_info",
+                ]) for w in self._workers if w.status() != psutil.STATUS_ZOMBIE
+            ]
+
+    @staticmethod
+    def _get_raylet_cmdline():
+        curr_proc = psutil.Process()
+        parent = curr_proc.parent()
+        # Guard against the raylet being gone, like _get_workers does.
+        if parent is None or parent.pid == 1:
+            return ""
+        else:
+            return parent.cmdline()
+
+    def _get_load_avg(self):
+        if sys.platform == "win32":
+            cpu_percent = psutil.cpu_percent()
+            load = (cpu_percent, cpu_percent, cpu_percent)
+        else:
+            load = os.getloadavg()
+        per_cpu_load = tuple((round(x / self._cpu_counts[0], 2) for x in load))
+        return load, per_cpu_load
+
+    def _get_all_stats(self):
+        now = dashboard_utils.to_posix_time(datetime.datetime.utcnow())
+        network_stats = self._get_network_stats()
+
+        self._network_stats_hist.append((now, network_stats))
+        self._network_stats_hist = self._network_stats_hist[-7:]
+        then, prev_network_stats = self._network_stats_hist[0]
+        netstats = ((network_stats[0] - prev_network_stats[0]) / (now - then),
+                    (network_stats[1] - prev_network_stats[1]) / (now - then))
+
+        return {
+            "now": now,
+            "hostname": self._hostname,
+            "ip": self._ip,
+            "cpu": self._get_cpu_percent(),
+            "cpus": self._cpu_counts,
+            "mem": self._get_mem_usage(),
+            "workers": self._get_workers(),
+            "bootTime": self._get_boot_time(),
+            "loadAvg": self._get_load_avg(),
+            "disk": self._get_disk_usage(),
+            "net": netstats,
+            "cmdline": self._get_raylet_cmdline(),
+        }
+
+    async def _perform_iteration(self):
+        """Get the latest node stats and publish them to Redis periodically."""
+        aioredis_client = await aioredis.create_redis_pool(
+            address=self._dashboard_agent.redis_address,
+            password=self._dashboard_agent.redis_password)
+
+        while True:
+            try:
+                stats = self._get_all_stats()
+                await aioredis_client.publish(
+                    "{}{}".format(reporter_consts.REPORTER_PREFIX,
+                                  self._hostname), jsonify_asdict(stats))
+            except Exception as ex:
+                logger.exception(ex)
+            await asyncio.sleep(
+                reporter_consts.REPORTER_UPDATE_INTERVAL_MS / 1000)
+
+    async def run(self, server):
+        reporter_pb2_grpc.add_ReporterServiceServicer_to_server(self, server)
+        await self._perform_iteration()
diff --git a/dashboard/modules/reporter/reporter_consts.py b/dashboard/modules/reporter/reporter_consts.py
new file mode 100644
index 0000000000000..92f5fd71495c5
--- /dev/null
+++ b/dashboard/modules/reporter/reporter_consts.py
@@ -0,0 +1,6 @@
+import ray.ray_constants as ray_constants
+
+REPORTER_PREFIX = "RAY_REPORTER:"
+# The reporter will report its statistics this often (milliseconds).
+REPORTER_UPDATE_INTERVAL_MS = ray_constants.env_integer(
+    "REPORTER_UPDATE_INTERVAL_MS", 2500)
diff --git a/dashboard/modules/reporter/reporter_head.py b/dashboard/modules/reporter/reporter_head.py
new file mode 100644
index 0000000000000..959aadb411f69
--- /dev/null
+++ b/dashboard/modules/reporter/reporter_head.py
@@ -0,0 +1,94 @@
+import json
+import logging
+import uuid
+
+import aiohttp.web
+from aioredis.pubsub import Receiver
+from grpc.experimental import aio as aiogrpc
+
+import ray
+import ray.gcs_utils
+import ray.new_dashboard.modules.reporter.reporter_consts as reporter_consts
+import ray.new_dashboard.utils as dashboard_utils
+import ray.services
+import ray.utils
+from ray.core.generated import reporter_pb2
+from ray.core.generated import reporter_pb2_grpc
+from ray.new_dashboard.datacenter import DataSource
+
+logger = logging.getLogger(__name__)
+routes = dashboard_utils.ClassMethodRouteTable
+
+
+class ReportHead(dashboard_utils.DashboardHeadModule):
+    def __init__(self, dashboard_head):
+        super().__init__(dashboard_head)
+        self._stubs = {}
+        self._profiling_stats = {}
+        DataSource.agents.signal.append(self._update_stubs)
+
+    async def _update_stubs(self, change):
+        if change.new:
+            ip, port = next(iter(change.new.items()))
+            channel = aiogrpc.insecure_channel("{}:{}".format(ip, int(port)))
+            stub = reporter_pb2_grpc.ReporterServiceStub(channel)
+            self._stubs[ip] = stub
+        if change.old:
+            ip, port = next(iter(change.old.items()))
+            self._stubs.pop(ip)
+
+    @routes.get("/api/launch_profiling")
+    async def launch_profiling(self, req) -> aiohttp.web.Response:
+        node_id = req.query.get("node_id")
+        pid = int(req.query.get("pid"))
+        duration = int(req.query.get("duration"))
+        profiling_id = str(uuid.uuid4())
+        reporter_stub = self._stubs[node_id]
+        reply = await reporter_stub.GetProfilingStats(
+            reporter_pb2.GetProfilingStatsRequest(pid=pid, duration=duration))
+        self._profiling_stats[profiling_id] = reply
+        return await dashboard_utils.rest_response(
+            success=True,
+            message="Profiling launched.",
+            profiling_id=profiling_id)
+
+    @routes.get("/api/check_profiling_status")
+    async def check_profiling_status(self, req) -> aiohttp.web.Response:
+        profiling_id = req.query.get("profiling_id")
+        is_present = profiling_id in self._profiling_stats
+        if not is_present:
+            status = {"status": "pending"}
+        else:
+            reply = self._profiling_stats[profiling_id]
+            if reply.stderr:
+                status = {"status": "error", "error": reply.stderr}
+            else:
+                status = {"status": "finished"}
+        return await dashboard_utils.rest_response(
+            success=True, message="Profiling status fetched.", status=status)
+
+    @routes.get("/api/get_profiling_info")
+    async def get_profiling_info(self, req) -> aiohttp.web.Response:
+        profiling_id = req.query.get("profiling_id")
+        profiling_stats = self._profiling_stats.get(profiling_id)
+        assert profiling_stats, "profiling not finished"
+        return await dashboard_utils.rest_response(
+            success=True,
+            message="Profiling info fetched.",
+            profiling_info=json.loads(profiling_stats.profiling_stats))
+
+    async def run(self):
+        p = self._dashboard_head.aioredis_client
+        mpsc = Receiver()
+
+        reporter_key = "{}*".format(reporter_consts.REPORTER_PREFIX)
+        await p.psubscribe(mpsc.pattern(reporter_key))
+        logger.info("Subscribed to %s", reporter_key)
+
+        async for sender, msg in mpsc.iter():
+            try:
+                _, data = msg
+                data = json.loads(ray.utils.decode(data))
+                DataSource.node_physical_stats[data["ip"]] = data
+            except Exception as ex:
+                logger.exception(ex)
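Note: the three profiling endpoints above form a launch / poll / fetch flow. A hedged client-side sketch of that flow (host, node_id, and pid are placeholders; the `data` keys follow `rest_response` plus `to_google_style` camelCasing):

```python
# A sketch of driving the profiling endpoints in ReportHead above.
import asyncio

import aiohttp


async def profile(host, node_id, pid, duration=5):
    async with aiohttp.ClientSession() as session:
        async with session.get(
                "http://{}/api/launch_profiling".format(host),
                params={"node_id": node_id, "pid": pid,
                        "duration": duration}) as resp:
            profiling_id = (await resp.json())["data"]["profilingId"]
        while True:
            async with session.get(
                    "http://{}/api/check_profiling_status".format(host),
                    params={"profiling_id": profiling_id}) as resp:
                status = (await resp.json())["data"]["status"]["status"]
            if status != "pending":
                break
            await asyncio.sleep(1)
        async with session.get(
                "http://{}/api/get_profiling_info".format(host),
                params={"profiling_id": profiling_id}) as resp:
            return (await resp.json())["data"]["profilingInfo"]
```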
diff --git a/dashboard/utils.py b/dashboard/utils.py
new file mode 100644
index 0000000000000..e8e546f6d107e
--- /dev/null
+++ b/dashboard/utils.py
@@ -0,0 +1,340 @@
+import abc
+import asyncio
+import collections
+import copy
+import json
+import datetime
+import functools
+import importlib
+import inspect
+import logging
+import pkgutil
+import traceback
+from base64 import b64decode
+from collections.abc import MutableMapping, Mapping
+
+import aiohttp.web
+from aiohttp import hdrs
+from aiohttp.frozenlist import FrozenList
+import aiohttp.signals
+from google.protobuf.json_format import MessageToDict
+from ray.utils import binary_to_hex
+
+logger = logging.getLogger(__name__)
+
+
+class DashboardAgentModule(abc.ABC):
+    def __init__(self, dashboard_agent):
+        """
+        Initialize the current module when the DashboardAgent loads modules.
+
+        :param dashboard_agent: The DashboardAgent instance.
+        """
+        self._dashboard_agent = dashboard_agent
+
+    @abc.abstractmethod
+    async def run(self, server):
+        """
+        Run the module in an asyncio loop. An agent module can provide
+        servicers to the server.
+
+        :param server: Asyncio GRPC server.
+        """
+
+
+class DashboardHeadModule(abc.ABC):
+    def __init__(self, dashboard_head):
+        """
+        Initialize the current module when the DashboardHead loads modules.
+
+        :param dashboard_head: The DashboardHead instance.
+        """
+        self._dashboard_head = dashboard_head
+
+    @abc.abstractmethod
+    async def run(self):
+        """
+        Run the module in an asyncio loop.
+        """
+
+
+class ClassMethodRouteTable:
+    """A helper class to bind http routes to class methods."""
+
+    _bind_map = collections.defaultdict(dict)
+    _routes = aiohttp.web.RouteTableDef()
+
+    class _BindInfo:
+        def __init__(self, filename, lineno, instance):
+            self.filename = filename
+            self.lineno = lineno
+            self.instance = instance
+
+    @classmethod
+    def routes(cls):
+        return cls._routes
+
+    @classmethod
+    def _register_route(cls, method, path, **kwargs):
+        def _wrapper(handler):
+            if path in cls._bind_map[method]:
+                bind_info = cls._bind_map[method][path]
+                raise Exception("Duplicated route path: {}, "
+                                "previous one registered at {}:{}".format(
+                                    path, bind_info.filename,
+                                    bind_info.lineno))
+
+            bind_info = cls._BindInfo(handler.__code__.co_filename,
+                                      handler.__code__.co_firstlineno, None)
+
+            @functools.wraps(handler)
+            async def _handler_route(*args, **kwargs):
+                if len(args) and args[0] == bind_info.instance:
+                    args = args[1:]
+                try:
+                    return await handler(bind_info.instance, *args, **kwargs)
+                except Exception:
+                    return await rest_response(
+                        success=False, message=traceback.format_exc())
+
+            cls._bind_map[method][path] = bind_info
+            _handler_route.__route_method__ = method
+            _handler_route.__route_path__ = path
+            return cls._routes.route(method, path, **kwargs)(_handler_route)
+
+        return _wrapper
+
+    @classmethod
+    def head(cls, path, **kwargs):
+        return cls._register_route(hdrs.METH_HEAD, path, **kwargs)
+
+    @classmethod
+    def get(cls, path, **kwargs):
+        return cls._register_route(hdrs.METH_GET, path, **kwargs)
+
+    @classmethod
+    def post(cls, path, **kwargs):
+        return cls._register_route(hdrs.METH_POST, path, **kwargs)
+
+    @classmethod
+    def put(cls, path, **kwargs):
+        return cls._register_route(hdrs.METH_PUT, path, **kwargs)
+
+    @classmethod
+    def patch(cls, path, **kwargs):
+        return cls._register_route(hdrs.METH_PATCH, path, **kwargs)
+
+    @classmethod
+    def delete(cls, path, **kwargs):
+        return cls._register_route(hdrs.METH_DELETE, path, **kwargs)
+
+    @classmethod
+    def view(cls, path, **kwargs):
+        return cls._register_route(hdrs.METH_ANY, path, **kwargs)
+
+    @classmethod
+    def bind(cls, instance):
+        def predicate(o):
+            if inspect.ismethod(o):
+                return hasattr(o, "__route_method__") and hasattr(
+                    o, "__route_path__")
+            return False
+
+        handler_routes = inspect.getmembers(instance, predicate)
+        for _, h in handler_routes:
+            cls._bind_map[h.__func__.__route_method__][
+                h.__func__.__route_path__].instance = instance
+
+
+def get_all_modules(module_type):
+    logger.info("Get all modules by type: %s", module_type.__name__)
+    import ray.new_dashboard.modules
+
+    for module_loader, name, ispkg in pkgutil.walk_packages(
+            ray.new_dashboard.modules.__path__,
+            ray.new_dashboard.modules.__name__ + "."):
+        importlib.import_module(name)
+    return module_type.__subclasses__()
+
+
+def to_posix_time(dt):
+    return (dt - datetime.datetime(1970, 1, 1)).total_seconds()
+
+
+class CustomEncoder(json.JSONEncoder):
+    def default(self, obj):
+        if isinstance(obj, bytes):
+            return binary_to_hex(obj)
+        # Let the base class default method raise the TypeError.
+        return json.JSONEncoder.default(self, obj)
+
+
+async def rest_response(success, message, **kwargs) -> aiohttp.web.Response:
+    return aiohttp.web.json_response(
+        {
+            "result": success,
+            "msg": message,
+            "data": to_google_style(kwargs)
+        },
+        dumps=functools.partial(json.dumps, cls=CustomEncoder))
+
+
+def to_camel_case(snake_str):
+    """Convert a snake_case string to camelCase."""
+    components = snake_str.split("_")
+    # We capitalize the first letter of each component except the first one
+    # with the 'title' method and join them together.
+    return components[0] + "".join(x.title() for x in components[1:])
+
+
+def to_google_style(d):
+    """Recursively convert all keys in a dict to Google style (camelCase)."""
+    new_dict = {}
+    for k, v in d.items():
+        if isinstance(v, dict):
+            new_dict[to_camel_case(k)] = to_google_style(v)
+        elif isinstance(v, list):
+            new_list = []
+            for i in v:
+                if isinstance(i, dict):
+                    new_list.append(to_google_style(i))
+                else:
+                    new_list.append(i)
+            new_dict[to_camel_case(k)] = new_list
+        else:
+            new_dict[to_camel_case(k)] = v
+    return new_dict
+
+
+def message_to_dict(message, decode_keys=None, **kwargs):
+    """Convert a protobuf message to a Python dict."""
+
+    def _decode_keys(d):
+        for k, v in d.items():
+            if isinstance(v, dict):
+                d[k] = _decode_keys(v)
+            elif isinstance(v, list):
+                new_list = []
+                for i in v:
+                    if isinstance(i, dict):
+                        new_list.append(_decode_keys(i))
+                    else:
+                        new_list.append(i)
+                d[k] = new_list
+            else:
+                if k in decode_keys:
+                    d[k] = binary_to_hex(b64decode(v))
+                else:
+                    d[k] = v
+        return d
+
+    if decode_keys:
+        return _decode_keys(
+            MessageToDict(message, use_integers_for_enums=False, **kwargs))
+    else:
+        return MessageToDict(message, use_integers_for_enums=False, **kwargs)
+
+
+class SignalManager:
+    _signals = FrozenList()
+
+    @classmethod
+    def register(cls, sig):
+        cls._signals.append(sig)
+
+    @classmethod
+    def freeze(cls):
+        cls._signals.freeze()
+        for sig in cls._signals:
+            sig.freeze()
+
+
+class Signal(aiohttp.signals.Signal):
+    __slots__ = ()
+
+    def __init__(self, owner):
+        super().__init__(owner)
+        SignalManager.register(self)
+
+
+class Bunch(dict):
+    """A dict with attribute-access."""
+
+    def __getattr__(self, key):
+        try:
+            return self.__getitem__(key)
+        except KeyError:
+            raise AttributeError(key)
+
+    def __setattr__(self, key, value):
+        self.__setitem__(key, value)
+
+
+class Change:
+    """Notify change object."""
+
+    def __init__(self, owner=None, old=None, new=None):
+        self.owner = owner
+        self.old = old
+        self.new = new
+
+    def __str__(self):
+        return "Change(owner: {}, old: {}, new: {})".format(
+            self.owner, self.old, self.new)
+
+
+class NotifyQueue:
+    """Asyncio notify queue for Dict signals."""
+
+    _queue = asyncio.Queue()
+
+    @classmethod
+    def put(cls, co):
+        cls._queue.put_nowait(co)
+
+    @classmethod
+    async def get(cls):
+        return await cls._queue.get()
+
+
+class Dict(MutableMapping):
+    """A dict-like container that notifies subscribers of data changes.
+
+    :note: Only first-level changes are reported.
+    """
+
+    def __init__(self, *args, **kwargs):
+        self._data = dict(*args, **kwargs)
+        self.signal = Signal(self)
+
+    def __setitem__(self, key, value):
+        old = self._data.pop(key, None)
+        self._data[key] = value
+        if len(self.signal) and old != value:
+            if old is None:
+                co = self.signal.send(Change(owner=self, new={key: value}))
+            else:
+                co = self.signal.send(
+                    Change(owner=self, old={key: old}, new={key: value}))
+            NotifyQueue.put(co)
+
+    def __getitem__(self, item):
+        return copy.deepcopy(self._data[item])
+
+    def __delitem__(self, key):
+        old = self._data.pop(key, None)
+        if len(self.signal) and old is not None:
+            co = self.signal.send(Change(owner=self, old={key: old}))
+            NotifyQueue.put(co)
+
+    def __len__(self):
+        return len(self._data)
+
+    def __iter__(self):
+        return iter(copy.deepcopy(self._data))
+
+    def reset(self, d):
+        assert isinstance(d, Mapping)
+        for key in self._data.keys() - d.keys():
+            self.pop(key)
+        self.update(d)
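Note: the `Dict` / `Change` / `NotifyQueue` trio above is the backbone of the head's data flow: mutating a `Dict` enqueues a `Change` coroutine that `_async_notify` in `head.py` later awaits. A standalone toy demonstration of that flow, assuming aiohttp-era default-event-loop behavior (signals must be frozen before `send()` is awaited):

```python
# A toy demo of Dict change notifications, standing in for _async_notify.
import asyncio

from ray.new_dashboard.utils import Dict, NotifyQueue, SignalManager


async def main():
    agents = Dict()

    async def on_change(change):
        print(change)  # e.g. Change(owner: ..., old: None, new: {...})

    agents.signal.append(on_change)
    SignalManager.freeze()  # signals must be frozen before they can send

    agents["10.0.0.1"] = 42  # queues Change(new=...)
    agents["10.0.0.1"] = 43  # queues Change(old=..., new=...)
    del agents["10.0.0.1"]   # queues Change(old=...)

    for _ in range(3):
        await (await NotifyQueue.get())  # drain and await each notification


asyncio.get_event_loop().run_until_complete(main())
```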
diff --git a/python/ray/new_dashboard b/python/ray/new_dashboard
new file mode 120000
index 0000000000000..2551d65a9ca67
--- /dev/null
+++ b/python/ray/new_dashboard
@@ -0,0 +1 @@
+../../dashboard
\ No newline at end of file
diff --git a/python/ray/ray_constants.py b/python/ray/ray_constants.py
index 93a1249234452..4988a19277598 100644
--- a/python/ray/ray_constants.py
+++ b/python/ray/ray_constants.py
@@ -123,6 +123,7 @@ def to_memory_units(memory_bytes, round_up):
 MONITOR_DIED_ERROR = "monitor_died"
 LOG_MONITOR_DIED_ERROR = "log_monitor_died"
 REPORTER_DIED_ERROR = "reporter_died"
+DASHBOARD_AGENT_DIED_ERROR = "dashboard_agent_died"
 DASHBOARD_DIED_ERROR = "dashboard_died"
 RAYLET_CONNECTION_ERROR = "raylet_connection_error"
diff --git a/python/setup.py b/python/setup.py
index bace8d4f1c0d7..60f66d02f4d49 100644
--- a/python/setup.py
+++ b/python/setup.py
@@ -301,13 +301,14 @@ def find_version(*filepath):
 install_requires = [
     "aiohttp",
+    "aioredis",
     "click >= 7.0",
     "colorama",
     "colorful",
     "filelock",
     "google",
     "gpustat",
-    "grpcio",
+    "grpcio >= 1.28.1",
     "jsonschema",
     "msgpack >= 0.6.0, < 2.0.0",
     "numpy >= 1.16",